mirror of
https://github.com/chidiwilliams/buzz.git
synced 2024-05-22 02:26:33 +02:00
Add black formatting (#571)
This commit is contained in:
parent
f5f77b3908
commit
c498e60949
2
build.py
2
build.py
|
@ -2,7 +2,7 @@ import subprocess
|
|||
|
||||
|
||||
def build(setup_kwargs):
|
||||
subprocess.call(['make', 'buzz/whisper_cpp.py'])
|
||||
subprocess.call(["make", "buzz/whisper_cpp.py"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -4,7 +4,10 @@ from PyQt6.QtGui import QAction, QKeySequence
|
|||
|
||||
|
||||
class Action(QAction):
|
||||
def setShortcut(self, shortcut: typing.Union['QKeySequence', 'QKeySequence.StandardKey', str, int]) -> None:
|
||||
def setShortcut(
|
||||
self,
|
||||
shortcut: typing.Union["QKeySequence", "QKeySequence.StandardKey", str, int],
|
||||
) -> None:
|
||||
super().setShortcut(shortcut)
|
||||
self.setToolTip(Action.get_tooltip(self))
|
||||
|
||||
|
|
|
@ -3,6 +3,6 @@ import sys
|
|||
|
||||
|
||||
def get_asset_path(path: str):
|
||||
if getattr(sys, 'frozen', False):
|
||||
if getattr(sys, "frozen", False):
|
||||
return os.path.join(os.path.dirname(sys.executable), path)
|
||||
return os.path.join(os.path.dirname(__file__), '..', path)
|
||||
return os.path.join(os.path.dirname(__file__), "..", path)
|
||||
|
|
25
buzz/buzz.py
25
buzz/buzz.py
|
@ -9,7 +9,7 @@ from typing import TextIO
|
|||
from appdirs import user_log_dir
|
||||
|
||||
# Check for segfaults if not running in frozen mode
|
||||
if getattr(sys, 'frozen', False) is False:
|
||||
if getattr(sys, "frozen", False) is False:
|
||||
faulthandler.enable()
|
||||
|
||||
# Sets stderr to no-op TextIO when None (run as Windows GUI).
|
||||
|
@ -19,30 +19,35 @@ if sys.stderr is None:
|
|||
|
||||
# Adds the current directory to the PATH, so the ffmpeg binary get picked up:
|
||||
# https://stackoverflow.com/a/44352931/9830227
|
||||
app_dir = getattr(sys, '_MEIPASS', os.path.dirname(
|
||||
os.path.abspath(__file__)))
|
||||
app_dir = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
|
||||
os.environ["PATH"] += os.pathsep + app_dir
|
||||
|
||||
# Add the app directory to the DLL list: https://stackoverflow.com/a/64303856
|
||||
if platform.system() == 'Windows':
|
||||
if platform.system() == "Windows":
|
||||
os.add_dll_directory(app_dir)
|
||||
|
||||
|
||||
def main():
|
||||
if platform.system() == 'Linux':
|
||||
multiprocessing.set_start_method('spawn')
|
||||
if platform.system() == "Linux":
|
||||
multiprocessing.set_start_method("spawn")
|
||||
|
||||
# Fixes opening new window when app has been frozen on Windows:
|
||||
# https://stackoverflow.com/a/33979091
|
||||
multiprocessing.freeze_support()
|
||||
|
||||
log_dir = user_log_dir(appname='Buzz')
|
||||
log_dir = user_log_dir(appname="Buzz")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
log_format = "[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
|
||||
logging.basicConfig(filename=os.path.join(log_dir, 'logs.txt'), level=logging.DEBUG, format=log_format)
|
||||
log_format = (
|
||||
"[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
|
||||
)
|
||||
logging.basicConfig(
|
||||
filename=os.path.join(log_dir, "logs.txt"),
|
||||
level=logging.DEBUG,
|
||||
format=log_format,
|
||||
)
|
||||
|
||||
if getattr(sys, 'frozen', False) is False:
|
||||
if getattr(sys, "frozen", False) is False:
|
||||
stdout_handler = logging.StreamHandler(sys.stdout)
|
||||
stdout_handler.setLevel(logging.DEBUG)
|
||||
stdout_handler.setFormatter(logging.Formatter(log_format))
|
||||
|
|
|
@ -9,11 +9,11 @@ from .transcriber import FileTranscriptionTask
|
|||
|
||||
|
||||
class TasksCache:
|
||||
def __init__(self, cache_dir=user_cache_dir('Buzz')):
|
||||
def __init__(self, cache_dir=user_cache_dir("Buzz")):
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.cache_dir = cache_dir
|
||||
self.pickle_cache_file_path = os.path.join(cache_dir, 'tasks')
|
||||
self.tasks_list_file_path = os.path.join(cache_dir, 'tasks.json')
|
||||
self.pickle_cache_file_path = os.path.join(cache_dir, "tasks")
|
||||
self.tasks_list_file_path = os.path.join(cache_dir, "tasks.json")
|
||||
|
||||
def save(self, tasks: List[FileTranscriptionTask]):
|
||||
self.save_json_tasks(tasks=tasks)
|
||||
|
@ -23,16 +23,20 @@ class TasksCache:
|
|||
return self.load_json_tasks()
|
||||
|
||||
try:
|
||||
with open(self.pickle_cache_file_path, 'rb') as file:
|
||||
with open(self.pickle_cache_file_path, "rb") as file:
|
||||
return pickle.load(file)
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
except (pickle.UnpicklingError, AttributeError, ValueError): # delete corrupted cache
|
||||
except (
|
||||
pickle.UnpicklingError,
|
||||
AttributeError,
|
||||
ValueError,
|
||||
): # delete corrupted cache
|
||||
os.remove(self.pickle_cache_file_path)
|
||||
return []
|
||||
|
||||
def load_json_tasks(self) -> List[FileTranscriptionTask]:
|
||||
with open(self.tasks_list_file_path, 'r') as file:
|
||||
with open(self.tasks_list_file_path, "r") as file:
|
||||
task_ids = json.load(file)
|
||||
|
||||
tasks = []
|
||||
|
@ -57,7 +61,7 @@ class TasksCache:
|
|||
file.write(json_str)
|
||||
|
||||
def get_task_path(self, task_id: int):
|
||||
path = os.path.join(self.cache_dir, 'transcriptions', f'{task_id}.json')
|
||||
path = os.path.join(self.cache_dir, "transcriptions", f"{task_id}.json")
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
return path
|
||||
|
||||
|
|
173
buzz/cli.py
173
buzz/cli.py
|
@ -7,8 +7,14 @@ from PyQt6.QtCore import QCommandLineParser, QCommandLineOption
|
|||
from buzz.gui import Application
|
||||
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel
|
||||
from buzz.store.keyring_store import KeyringStore
|
||||
from buzz.transcriber import Task, FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, LANGUAGES, \
|
||||
OutputFormat
|
||||
from buzz.transcriber import (
|
||||
Task,
|
||||
FileTranscriptionTask,
|
||||
FileTranscriptionOptions,
|
||||
TranscriptionOptions,
|
||||
LANGUAGES,
|
||||
OutputFormat,
|
||||
)
|
||||
|
||||
|
||||
class CommandLineError(Exception):
|
||||
|
@ -17,11 +23,11 @@ class CommandLineError(Exception):
|
|||
|
||||
|
||||
class CommandLineModelType(enum.Enum):
|
||||
WHISPER = 'whisper'
|
||||
WHISPER_CPP = 'whispercpp'
|
||||
HUGGING_FACE = 'huggingface'
|
||||
FASTER_WHISPER = 'fasterwhisper'
|
||||
OPEN_AI_WHISPER_API = 'openaiapi'
|
||||
WHISPER = "whisper"
|
||||
WHISPER_CPP = "whispercpp"
|
||||
HUGGING_FACE = "huggingface"
|
||||
FASTER_WHISPER = "fasterwhisper"
|
||||
OPEN_AI_WHISPER_API = "openaiapi"
|
||||
|
||||
|
||||
def parse_command_line(app: Application):
|
||||
|
@ -29,13 +35,13 @@ def parse_command_line(app: Application):
|
|||
try:
|
||||
parse(app, parser)
|
||||
except CommandLineError as exc:
|
||||
print(f'Error: {str(exc)}\n', file=sys.stderr)
|
||||
print(f"Error: {str(exc)}\n", file=sys.stderr)
|
||||
print(parser.helpText())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def parse(app: Application, parser: QCommandLineParser):
|
||||
parser.addPositionalArgument('<command>', 'One of the following commands:\n- add')
|
||||
parser.addPositionalArgument("<command>", "One of the following commands:\n- add")
|
||||
parser.parse(app.arguments())
|
||||
|
||||
args = parser.positionalArguments()
|
||||
|
@ -50,36 +56,63 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
if command == "add":
|
||||
parser.clearPositionalArguments()
|
||||
|
||||
parser.addPositionalArgument('files', 'Input file paths', '[file file file...]')
|
||||
parser.addPositionalArgument("files", "Input file paths", "[file file file...]")
|
||||
|
||||
task_option = QCommandLineOption(['t', 'task'],
|
||||
f'The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.',
|
||||
'task',
|
||||
Task.TRANSCRIBE.value)
|
||||
model_type_option = QCommandLineOption(['m', 'model-type'],
|
||||
f'Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.',
|
||||
'model-type',
|
||||
CommandLineModelType.WHISPER.value)
|
||||
model_size_option = QCommandLineOption(['s', 'model-size'],
|
||||
f'Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.',
|
||||
'model-size', WhisperModelSize.TINY.value)
|
||||
hugging_face_model_id_option = QCommandLineOption(['hfid'],
|
||||
f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"',
|
||||
'id')
|
||||
language_option = QCommandLineOption(['l', 'language'],
|
||||
f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.',
|
||||
'code', '')
|
||||
initial_prompt_option = QCommandLineOption(['p', 'prompt'], f'Initial prompt', 'prompt', '')
|
||||
open_ai_access_token_option = QCommandLineOption('openai-token',
|
||||
f'OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.',
|
||||
'token')
|
||||
srt_option = QCommandLineOption(['srt'], 'Output result in an SRT file.')
|
||||
vtt_option = QCommandLineOption(['vtt'], 'Output result in a VTT file.')
|
||||
txt_option = QCommandLineOption('txt', 'Output result in a TXT file.')
|
||||
task_option = QCommandLineOption(
|
||||
["t", "task"],
|
||||
f"The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.",
|
||||
"task",
|
||||
Task.TRANSCRIBE.value,
|
||||
)
|
||||
model_type_option = QCommandLineOption(
|
||||
["m", "model-type"],
|
||||
f"Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.",
|
||||
"model-type",
|
||||
CommandLineModelType.WHISPER.value,
|
||||
)
|
||||
model_size_option = QCommandLineOption(
|
||||
["s", "model-size"],
|
||||
f"Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.",
|
||||
"model-size",
|
||||
WhisperModelSize.TINY.value,
|
||||
)
|
||||
hugging_face_model_id_option = QCommandLineOption(
|
||||
["hfid"],
|
||||
f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"',
|
||||
"id",
|
||||
)
|
||||
language_option = QCommandLineOption(
|
||||
["l", "language"],
|
||||
f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.',
|
||||
"code",
|
||||
"",
|
||||
)
|
||||
initial_prompt_option = QCommandLineOption(
|
||||
["p", "prompt"], f"Initial prompt", "prompt", ""
|
||||
)
|
||||
open_ai_access_token_option = QCommandLineOption(
|
||||
"openai-token",
|
||||
f"OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.",
|
||||
"token",
|
||||
)
|
||||
srt_option = QCommandLineOption(["srt"], "Output result in an SRT file.")
|
||||
vtt_option = QCommandLineOption(["vtt"], "Output result in a VTT file.")
|
||||
txt_option = QCommandLineOption("txt", "Output result in a TXT file.")
|
||||
|
||||
parser.addOptions(
|
||||
[task_option, model_type_option, model_size_option, hugging_face_model_id_option, language_option,
|
||||
initial_prompt_option, open_ai_access_token_option, srt_option, vtt_option, txt_option])
|
||||
[
|
||||
task_option,
|
||||
model_type_option,
|
||||
model_size_option,
|
||||
hugging_face_model_id_option,
|
||||
language_option,
|
||||
initial_prompt_option,
|
||||
open_ai_access_token_option,
|
||||
srt_option,
|
||||
vtt_option,
|
||||
txt_option,
|
||||
]
|
||||
)
|
||||
|
||||
parser.addHelpOption()
|
||||
parser.addVersionOption()
|
||||
|
@ -89,7 +122,7 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
# slice after first argument, the command
|
||||
file_paths = parser.positionalArguments()[1:]
|
||||
if len(file_paths) == 0:
|
||||
raise CommandLineError('No input files')
|
||||
raise CommandLineError("No input files")
|
||||
|
||||
task = parse_enum_option(task_option, parser, Task)
|
||||
|
||||
|
@ -98,21 +131,29 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
|
||||
hugging_face_model_id = parser.value(hugging_face_model_id_option)
|
||||
|
||||
if hugging_face_model_id == '' and model_type == CommandLineModelType.HUGGING_FACE:
|
||||
raise CommandLineError('--hfid is required when --model-type is huggingface')
|
||||
if (
|
||||
hugging_face_model_id == ""
|
||||
and model_type == CommandLineModelType.HUGGING_FACE
|
||||
):
|
||||
raise CommandLineError(
|
||||
"--hfid is required when --model-type is huggingface"
|
||||
)
|
||||
|
||||
model = TranscriptionModel(model_type=ModelType[model_type.name], whisper_model_size=model_size,
|
||||
hugging_face_model_id=hugging_face_model_id)
|
||||
model = TranscriptionModel(
|
||||
model_type=ModelType[model_type.name],
|
||||
whisper_model_size=model_size,
|
||||
hugging_face_model_id=hugging_face_model_id,
|
||||
)
|
||||
model_path = model.get_local_model_path()
|
||||
|
||||
if model_path is None:
|
||||
raise CommandLineError('Model not found')
|
||||
raise CommandLineError("Model not found")
|
||||
|
||||
language = parser.value(language_option)
|
||||
if language == '':
|
||||
if language == "":
|
||||
language = None
|
||||
elif LANGUAGES.get(language) is None:
|
||||
raise CommandLineError('Invalid language option')
|
||||
raise CommandLineError("Invalid language option")
|
||||
|
||||
initial_prompt = parser.value(initial_prompt_option)
|
||||
|
||||
|
@ -125,33 +166,49 @@ def parse(app: Application, parser: QCommandLineParser):
|
|||
output_formats.add(OutputFormat.TXT)
|
||||
|
||||
openai_access_token = parser.value(open_ai_access_token_option)
|
||||
if model.model_type == ModelType.OPEN_AI_WHISPER_API and openai_access_token == '':
|
||||
openai_access_token = KeyringStore().get_password(key=KeyringStore.Key.OPENAI_API_KEY)
|
||||
if (
|
||||
model.model_type == ModelType.OPEN_AI_WHISPER_API
|
||||
and openai_access_token == ""
|
||||
):
|
||||
openai_access_token = KeyringStore().get_password(
|
||||
key=KeyringStore.Key.OPENAI_API_KEY
|
||||
)
|
||||
|
||||
if openai_access_token == '':
|
||||
raise CommandLineError('No OpenAI access token found')
|
||||
if openai_access_token == "":
|
||||
raise CommandLineError("No OpenAI access token found")
|
||||
|
||||
transcription_options = TranscriptionOptions(model=model, task=task, language=language,
|
||||
initial_prompt=initial_prompt,
|
||||
openai_access_token=openai_access_token)
|
||||
file_transcription_options = FileTranscriptionOptions(file_paths=file_paths, output_formats=output_formats)
|
||||
transcription_options = TranscriptionOptions(
|
||||
model=model,
|
||||
task=task,
|
||||
language=language,
|
||||
initial_prompt=initial_prompt,
|
||||
openai_access_token=openai_access_token,
|
||||
)
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=file_paths, output_formats=output_formats
|
||||
)
|
||||
|
||||
for file_path in file_paths:
|
||||
transcription_task = FileTranscriptionTask(file_path=file_path, model_path=model_path,
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options)
|
||||
transcription_task = FileTranscriptionTask(
|
||||
file_path=file_path,
|
||||
model_path=model_path,
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
)
|
||||
app.add_task(transcription_task)
|
||||
|
||||
|
||||
T = typing.TypeVar("T", bound=enum.Enum)
|
||||
|
||||
|
||||
def parse_enum_option(option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T]) -> T:
|
||||
def parse_enum_option(
|
||||
option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T]
|
||||
) -> T:
|
||||
try:
|
||||
return enum_class(parser.value(option))
|
||||
except ValueError:
|
||||
raise CommandLineError(f'Invalid value for --{option.names()[-1]} option.')
|
||||
raise CommandLineError(f"Invalid value for --{option.names()[-1]} option.")
|
||||
|
||||
|
||||
def join_values(enum_class: typing.Type[enum.Enum]) -> str:
|
||||
return ', '.join([v.value for v in enum_class])
|
||||
return ", ".join([v.value for v in enum_class])
|
||||
|
|
|
@ -2,9 +2,10 @@ from PyQt6.QtWidgets import QWidget, QMessageBox
|
|||
|
||||
|
||||
def show_model_download_error_dialog(parent: QWidget, error: str):
|
||||
message = parent.tr(
|
||||
'An error occurred while loading the Whisper model') + \
|
||||
f": {error}{'' if error.endswith('.') else '.'}" + \
|
||||
parent.tr("Please retry or check the application logs for more information.")
|
||||
message = (
|
||||
parent.tr("An error occurred while loading the Whisper model")
|
||||
+ f": {error}{'' if error.endswith('.') else '.'}"
|
||||
+ parent.tr("Please retry or check the application logs for more information.")
|
||||
)
|
||||
|
||||
QMessageBox.critical(parent, '', message)
|
||||
QMessageBox.critical(parent, "", message)
|
||||
|
|
|
@ -7,8 +7,14 @@ from typing import Optional, Tuple, List
|
|||
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
|
||||
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.transcriber import FileTranscriptionTask, FileTranscriber, WhisperCppFileTranscriber, \
|
||||
OpenAIWhisperAPIFileTranscriber, WhisperFileTranscriber, Segment
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
FileTranscriber,
|
||||
WhisperCppFileTranscriber,
|
||||
OpenAIWhisperAPIFileTranscriber,
|
||||
WhisperFileTranscriber,
|
||||
Segment,
|
||||
)
|
||||
|
||||
|
||||
class FileTranscriberQueueWorker(QObject):
|
||||
|
@ -26,7 +32,7 @@ class FileTranscriberQueueWorker(QObject):
|
|||
|
||||
@pyqtSlot()
|
||||
def run(self):
|
||||
logging.debug('Waiting for next transcription task')
|
||||
logging.debug("Waiting for next transcription task")
|
||||
|
||||
# Get next non-canceled task from queue
|
||||
while True:
|
||||
|
@ -42,38 +48,37 @@ class FileTranscriberQueueWorker(QObject):
|
|||
|
||||
break
|
||||
|
||||
logging.debug('Starting next transcription task')
|
||||
logging.debug("Starting next transcription task")
|
||||
|
||||
model_type = self.current_task.transcription_options.model.model_type
|
||||
if model_type == ModelType.WHISPER_CPP:
|
||||
self.current_transcriber = WhisperCppFileTranscriber(
|
||||
task=self.current_task)
|
||||
self.current_transcriber = WhisperCppFileTranscriber(task=self.current_task)
|
||||
elif model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
self.current_transcriber = OpenAIWhisperAPIFileTranscriber(task=self.current_task)
|
||||
elif model_type == ModelType.HUGGING_FACE or \
|
||||
model_type == ModelType.WHISPER or \
|
||||
model_type == ModelType.FASTER_WHISPER:
|
||||
self.current_transcriber = OpenAIWhisperAPIFileTranscriber(
|
||||
task=self.current_task
|
||||
)
|
||||
elif (
|
||||
model_type == ModelType.HUGGING_FACE
|
||||
or model_type == ModelType.WHISPER
|
||||
or model_type == ModelType.FASTER_WHISPER
|
||||
):
|
||||
self.current_transcriber = WhisperFileTranscriber(task=self.current_task)
|
||||
else:
|
||||
raise Exception(f'Unknown model type: {model_type}')
|
||||
raise Exception(f"Unknown model type: {model_type}")
|
||||
|
||||
self.current_transcriber_thread = QThread(self)
|
||||
|
||||
self.current_transcriber.moveToThread(self.current_transcriber_thread)
|
||||
|
||||
self.current_transcriber_thread.started.connect(
|
||||
self.current_transcriber.run)
|
||||
self.current_transcriber.completed.connect(
|
||||
self.current_transcriber_thread.quit)
|
||||
self.current_transcriber.error.connect(
|
||||
self.current_transcriber_thread.quit)
|
||||
self.current_transcriber_thread.started.connect(self.current_transcriber.run)
|
||||
self.current_transcriber.completed.connect(self.current_transcriber_thread.quit)
|
||||
self.current_transcriber.error.connect(self.current_transcriber_thread.quit)
|
||||
|
||||
self.current_transcriber.completed.connect(
|
||||
self.current_transcriber.deleteLater)
|
||||
self.current_transcriber.error.connect(
|
||||
self.current_transcriber.deleteLater)
|
||||
self.current_transcriber.completed.connect(self.current_transcriber.deleteLater)
|
||||
self.current_transcriber.error.connect(self.current_transcriber.deleteLater)
|
||||
self.current_transcriber_thread.finished.connect(
|
||||
self.current_transcriber_thread.deleteLater)
|
||||
self.current_transcriber_thread.deleteLater
|
||||
)
|
||||
|
||||
self.current_transcriber.progress.connect(self.on_task_progress)
|
||||
self.current_transcriber.error.connect(self.on_task_error)
|
||||
|
@ -104,7 +109,10 @@ class FileTranscriberQueueWorker(QObject):
|
|||
|
||||
@pyqtSlot(Exception)
|
||||
def on_task_error(self, error: Exception):
|
||||
if self.current_task is not None and self.current_task.id not in self.canceled_tasks:
|
||||
if (
|
||||
self.current_task is not None
|
||||
and self.current_task.id not in self.canceled_tasks
|
||||
):
|
||||
self.current_task.status = FileTranscriptionTask.Status.FAILED
|
||||
self.current_task.error = str(error)
|
||||
self.task_updated.emit(self.current_task)
|
||||
|
|
494
buzz/gui.py
494
buzz/gui.py
|
@ -5,13 +5,23 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
import sounddevice
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import (Qt, QThread,
|
||||
pyqtSignal, QModelIndex, QThreadPool)
|
||||
from PyQt6.QtGui import (QCloseEvent, QIcon,
|
||||
QKeySequence, QTextCursor, QPainter, QColor)
|
||||
from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
|
||||
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QFormLayout,
|
||||
QSizePolicy)
|
||||
from PyQt6.QtCore import Qt, QThread, pyqtSignal, QModelIndex, QThreadPool
|
||||
from PyQt6.QtGui import QCloseEvent, QIcon, QKeySequence, QTextCursor, QPainter, QColor
|
||||
from PyQt6.QtWidgets import (
|
||||
QApplication,
|
||||
QComboBox,
|
||||
QFileDialog,
|
||||
QLabel,
|
||||
QMainWindow,
|
||||
QMessageBox,
|
||||
QPlainTextEdit,
|
||||
QPushButton,
|
||||
QVBoxLayout,
|
||||
QHBoxLayout,
|
||||
QWidget,
|
||||
QFormLayout,
|
||||
QSizePolicy,
|
||||
)
|
||||
|
||||
from buzz.cache import TasksCache
|
||||
from .__version__ import VERSION
|
||||
|
@ -20,25 +30,35 @@ from .assets import get_asset_path
|
|||
from .dialogs import show_model_download_error_dialog
|
||||
from .widgets.icon import Icon, BUZZ_ICON_PATH
|
||||
from .locale import _
|
||||
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, \
|
||||
ModelDownloader
|
||||
from .model_loader import (
|
||||
WhisperModelSize,
|
||||
ModelType,
|
||||
TranscriptionModel,
|
||||
ModelDownloader,
|
||||
)
|
||||
from .recording import RecordingAmplitudeListener
|
||||
from .settings.settings import Settings, APP_NAME
|
||||
from .settings.shortcut import Shortcut
|
||||
from .settings.shortcut_settings import ShortcutSettings
|
||||
from .store.keyring_store import KeyringStore
|
||||
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Task,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionTask, LOADED_WHISPER_DLL,
|
||||
DEFAULT_WHISPER_TEMPERATURE)
|
||||
from .transcriber import (
|
||||
SUPPORTED_OUTPUT_FORMATS,
|
||||
FileTranscriptionOptions,
|
||||
Task,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionTask,
|
||||
LOADED_WHISPER_DLL,
|
||||
DEFAULT_WHISPER_TEMPERATURE,
|
||||
)
|
||||
from .recording_transcriber import RecordingTranscriber
|
||||
from .file_transcriber_queue_worker import FileTranscriberQueueWorker
|
||||
from .widgets.menu_bar import MenuBar
|
||||
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from .widgets.toolbar import ToolBar
|
||||
from .widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
|
||||
from .widgets.transcriber.transcription_options_group_box import \
|
||||
TranscriptionOptionsGroupBox
|
||||
from .widgets.transcriber.transcription_options_group_box import (
|
||||
TranscriptionOptionsGroupBox,
|
||||
)
|
||||
from .widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
|
||||
from .widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
||||
|
||||
|
@ -46,13 +66,17 @@ from .widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
|||
class FormLabel(QLabel):
|
||||
def __init__(self, name: str, parent: Optional[QWidget], *args) -> None:
|
||||
super().__init__(name, parent, *args)
|
||||
self.setStyleSheet('QLabel { text-align: right; }')
|
||||
self.setAlignment(Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight))
|
||||
self.setStyleSheet("QLabel { text-align: right; }")
|
||||
self.setAlignment(
|
||||
Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class AudioDevicesComboBox(QComboBox):
|
||||
"""AudioDevicesComboBox displays a list of available audio input devices"""
|
||||
|
||||
device_changed = pyqtSignal(int)
|
||||
audio_devices: List[Tuple[int, str]]
|
||||
|
||||
|
@ -71,13 +95,18 @@ class AudioDevicesComboBox(QComboBox):
|
|||
def get_audio_devices(self) -> List[Tuple[int, str]]:
|
||||
try:
|
||||
devices: sounddevice.DeviceList = sounddevice.query_devices()
|
||||
return [(device.get('index'), device.get('name'))
|
||||
for device in devices if device.get('max_input_channels') > 0]
|
||||
return [
|
||||
(device.get("index"), device.get("name"))
|
||||
for device in devices
|
||||
if device.get("max_input_channels") > 0
|
||||
]
|
||||
except UnicodeDecodeError:
|
||||
QMessageBox.critical(
|
||||
self, '',
|
||||
'An error occurred while loading your audio devices. Please check the application logs for more '
|
||||
'information.')
|
||||
self,
|
||||
"",
|
||||
"An error occurred while loading your audio devices. Please check the application logs for more "
|
||||
"information.",
|
||||
)
|
||||
return []
|
||||
|
||||
def on_index_changed(self, index: int):
|
||||
|
@ -107,14 +136,16 @@ class RecordButton(QPushButton):
|
|||
def __init__(self, parent: Optional[QWidget]) -> None:
|
||||
super().__init__(_("Record"), parent)
|
||||
self.setDefault(True)
|
||||
self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
|
||||
self.setSizePolicy(
|
||||
QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed)
|
||||
)
|
||||
|
||||
def set_stopped(self):
|
||||
self.setText(_('Record'))
|
||||
self.setText(_("Record"))
|
||||
self.setDefault(True)
|
||||
|
||||
def set_recording(self):
|
||||
self.setText(_('Stop'))
|
||||
self.setText(_("Stop"))
|
||||
self.setDefault(False)
|
||||
|
||||
|
||||
|
@ -143,11 +174,11 @@ class AudioMeterWidget(QWidget):
|
|||
self.AMPLITUDE_SCALE_FACTOR = 15 # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars
|
||||
|
||||
if self.palette().window().color().black() > 127:
|
||||
self.BAR_INACTIVE_COLOR = QColor('#555')
|
||||
self.BAR_ACTIVE_COLOR = QColor('#999')
|
||||
self.BAR_INACTIVE_COLOR = QColor("#555")
|
||||
self.BAR_ACTIVE_COLOR = QColor("#999")
|
||||
else:
|
||||
self.BAR_INACTIVE_COLOR = QColor('#BBB')
|
||||
self.BAR_ACTIVE_COLOR = QColor('#555')
|
||||
self.BAR_INACTIVE_COLOR = QColor("#BBB")
|
||||
self.BAR_ACTIVE_COLOR = QColor("#555")
|
||||
|
||||
def paintEvent(self, event: QtGui.QPaintEvent) -> None:
|
||||
painter = QPainter(self)
|
||||
|
@ -157,26 +188,38 @@ class AudioMeterWidget(QWidget):
|
|||
center_x = rect.center().x()
|
||||
num_bars_in_half = int((rect.width() / 2) / (self.BAR_MARGIN + self.BAR_WIDTH))
|
||||
for i in range(num_bars_in_half):
|
||||
is_bar_active = ((self.current_amplitude - self.MINIMUM_AMPLITUDE) * self.AMPLITUDE_SCALE_FACTOR) > (
|
||||
i / num_bars_in_half)
|
||||
painter.setBrush(self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR)
|
||||
is_bar_active = (
|
||||
(self.current_amplitude - self.MINIMUM_AMPLITUDE)
|
||||
* self.AMPLITUDE_SCALE_FACTOR
|
||||
) > (i / num_bars_in_half)
|
||||
painter.setBrush(
|
||||
self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR
|
||||
)
|
||||
|
||||
# draw to left
|
||||
painter.drawRect(center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)), rect.top() + self.PADDING_TOP,
|
||||
self.BAR_WIDTH,
|
||||
rect.height() - self.PADDING_TOP)
|
||||
painter.drawRect(
|
||||
center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)),
|
||||
rect.top() + self.PADDING_TOP,
|
||||
self.BAR_WIDTH,
|
||||
rect.height() - self.PADDING_TOP,
|
||||
)
|
||||
# draw to right
|
||||
painter.drawRect(center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))),
|
||||
rect.top() + self.PADDING_TOP,
|
||||
self.BAR_WIDTH, rect.height() - self.PADDING_TOP)
|
||||
painter.drawRect(
|
||||
center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))),
|
||||
rect.top() + self.PADDING_TOP,
|
||||
self.BAR_WIDTH,
|
||||
rect.height() - self.PADDING_TOP,
|
||||
)
|
||||
|
||||
def update_amplitude(self, amplitude: float):
|
||||
self.current_amplitude = max(amplitude, self.current_amplitude * self.SMOOTHING_FACTOR)
|
||||
self.current_amplitude = max(
|
||||
amplitude, self.current_amplitude * self.SMOOTHING_FACTOR
|
||||
)
|
||||
self.repaint()
|
||||
|
||||
|
||||
class RecordingTranscriberWidget(QWidget):
|
||||
current_status: 'RecordingStatus'
|
||||
current_status: "RecordingStatus"
|
||||
transcription_options: TranscriptionOptions
|
||||
selected_device_id: Optional[int]
|
||||
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
|
||||
|
@ -190,7 +233,9 @@ class RecordingTranscriberWidget(QWidget):
|
|||
STOPPED = auto()
|
||||
RECORDING = auto()
|
||||
|
||||
def __init__(self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None) -> None:
|
||||
def __init__(
|
||||
self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None
|
||||
) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
if flags is not None:
|
||||
|
@ -199,42 +244,63 @@ class RecordingTranscriberWidget(QWidget):
|
|||
layout = QVBoxLayout(self)
|
||||
|
||||
self.current_status = self.RecordingStatus.STOPPED
|
||||
self.setWindowTitle(_('Live Recording'))
|
||||
self.setWindowTitle(_("Live Recording"))
|
||||
|
||||
self.settings = Settings()
|
||||
default_language = self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, default_value='')
|
||||
default_language = self.settings.value(
|
||||
key=Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, default_value=""
|
||||
)
|
||||
self.transcription_options = TranscriptionOptions(
|
||||
model=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_MODEL, default_value=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER_CPP if LOADED_WHISPER_DLL else ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY)),
|
||||
task=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE),
|
||||
language=default_language if default_language != '' else None,
|
||||
initial_prompt=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
|
||||
temperature=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE,
|
||||
default_value=DEFAULT_WHISPER_TEMPERATURE), word_level_timings=False)
|
||||
model=self.settings.value(
|
||||
key=Settings.Key.RECORDING_TRANSCRIBER_MODEL,
|
||||
default_value=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER_CPP
|
||||
if LOADED_WHISPER_DLL
|
||||
else ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
),
|
||||
task=self.settings.value(
|
||||
key=Settings.Key.RECORDING_TRANSCRIBER_TASK,
|
||||
default_value=Task.TRANSCRIBE,
|
||||
),
|
||||
language=default_language if default_language != "" else None,
|
||||
initial_prompt=self.settings.value(
|
||||
key=Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, default_value=""
|
||||
),
|
||||
temperature=self.settings.value(
|
||||
key=Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE,
|
||||
default_value=DEFAULT_WHISPER_TEMPERATURE,
|
||||
),
|
||||
word_level_timings=False,
|
||||
)
|
||||
|
||||
self.audio_devices_combo_box = AudioDevicesComboBox(self)
|
||||
self.audio_devices_combo_box.device_changed.connect(
|
||||
self.on_device_changed)
|
||||
self.audio_devices_combo_box.device_changed.connect(self.on_device_changed)
|
||||
self.selected_device_id = self.audio_devices_combo_box.get_default_device_id()
|
||||
|
||||
self.record_button = RecordButton(self)
|
||||
self.record_button.clicked.connect(self.on_record_button_clicked)
|
||||
|
||||
self.text_box = TextDisplayBox(self)
|
||||
self.text_box.setPlaceholderText(_('Click Record to begin...'))
|
||||
self.text_box.setPlaceholderText(_("Click Record to begin..."))
|
||||
|
||||
transcription_options_group_box = TranscriptionOptionsGroupBox(
|
||||
default_transcription_options=self.transcription_options,
|
||||
# Live transcription with OpenAI Whisper API not implemented
|
||||
model_types=[model_type for model_type in ModelType if model_type is not ModelType.OPEN_AI_WHISPER_API],
|
||||
parent=self)
|
||||
model_types=[
|
||||
model_type
|
||||
for model_type in ModelType
|
||||
if model_type is not ModelType.OPEN_AI_WHISPER_API
|
||||
],
|
||||
parent=self,
|
||||
)
|
||||
transcription_options_group_box.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
self.on_transcription_options_changed
|
||||
)
|
||||
|
||||
recording_options_layout = QFormLayout()
|
||||
recording_options_layout.addRow(
|
||||
_('Microphone:'), self.audio_devices_combo_box)
|
||||
recording_options_layout.addRow(_("Microphone:"), self.audio_devices_combo_box)
|
||||
|
||||
self.audio_meter_widget = AudioMeterWidget(self)
|
||||
|
||||
|
@ -252,7 +318,9 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
self.reset_recording_amplitude_listener()
|
||||
|
||||
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
|
||||
def on_transcription_options_changed(
|
||||
self, transcription_options: TranscriptionOptions
|
||||
):
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
def on_device_changed(self, device_id: int):
|
||||
|
@ -269,11 +337,16 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
# Get the device sample rate before starting the listener as the PortAudio function
|
||||
# fails if you try to get the device's settings while recording is in progress.
|
||||
self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(self.selected_device_id)
|
||||
self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(
|
||||
self.selected_device_id
|
||||
)
|
||||
|
||||
self.recording_amplitude_listener = RecordingAmplitudeListener(input_device_index=self.selected_device_id,
|
||||
parent=self)
|
||||
self.recording_amplitude_listener.amplitude_changed.connect(self.on_recording_amplitude_changed)
|
||||
self.recording_amplitude_listener = RecordingAmplitudeListener(
|
||||
input_device_index=self.selected_device_id, parent=self
|
||||
)
|
||||
self.recording_amplitude_listener.amplitude_changed.connect(
|
||||
self.on_recording_amplitude_changed
|
||||
)
|
||||
self.recording_amplitude_listener.start_recording()
|
||||
|
||||
def on_record_button_clicked(self):
|
||||
|
@ -306,16 +379,19 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.transcription_thread = QThread()
|
||||
|
||||
# TODO: make runnable
|
||||
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
|
||||
sample_rate=self.device_sample_rate,
|
||||
transcription_options=self.transcription_options,
|
||||
model_path=model_path)
|
||||
self.transcriber = RecordingTranscriber(
|
||||
input_device_index=self.selected_device_id,
|
||||
sample_rate=self.device_sample_rate,
|
||||
transcription_options=self.transcription_options,
|
||||
model_path=model_path,
|
||||
)
|
||||
|
||||
self.transcriber.moveToThread(self.transcription_thread)
|
||||
|
||||
self.transcription_thread.started.connect(self.transcriber.start)
|
||||
self.transcription_thread.finished.connect(
|
||||
self.transcription_thread.deleteLater)
|
||||
self.transcription_thread.deleteLater
|
||||
)
|
||||
|
||||
self.transcriber.transcription.connect(self.on_next_transcription)
|
||||
|
||||
|
@ -334,12 +410,16 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
if self.model_download_progress_dialog is None:
|
||||
self.model_download_progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.transcription_options.model.model_type, parent=self)
|
||||
model_type=self.transcription_options.model.model_type, parent=self
|
||||
)
|
||||
self.model_download_progress_dialog.canceled.connect(
|
||||
self.on_cancel_model_progress_dialog)
|
||||
self.on_cancel_model_progress_dialog
|
||||
)
|
||||
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
|
||||
self.model_download_progress_dialog.set_value(
|
||||
fraction_completed=current_size / total_size
|
||||
)
|
||||
|
||||
def set_recording_status_stopped(self):
|
||||
self.record_button.set_stopped()
|
||||
|
@ -357,7 +437,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
if len(text) > 0:
|
||||
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
|
||||
if len(self.text_box.toPlainText()) > 0:
|
||||
self.text_box.insertPlainText('\n\n')
|
||||
self.text_box.insertPlainText("\n\n")
|
||||
self.text_box.insertPlainText(text)
|
||||
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
|
||||
|
||||
|
@ -374,9 +454,15 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.reset_record_button()
|
||||
self.set_recording_status_stopped()
|
||||
QMessageBox.critical(
|
||||
self, '',
|
||||
_('An error occurred while starting a new recording:') + error + '. ' +
|
||||
_('Please check your audio devices or check the application logs for more information.'))
|
||||
self,
|
||||
"",
|
||||
_("An error occurred while starting a new recording:")
|
||||
+ error
|
||||
+ ". "
|
||||
+ _(
|
||||
"Please check your audio devices or check the application logs for more information."
|
||||
),
|
||||
)
|
||||
|
||||
def on_cancel_model_progress_dialog(self):
|
||||
if self.model_loader is not None:
|
||||
|
@ -392,7 +478,7 @@ class RecordingTranscriberWidget(QWidget):
|
|||
|
||||
def reset_recording_controls(self):
|
||||
# Clear text box placeholder because the first chunk takes a while to process
|
||||
self.text_box.setPlaceholderText('')
|
||||
self.text_box.setPlaceholderText("")
|
||||
self.reset_record_button()
|
||||
self.reset_model_download()
|
||||
|
||||
|
@ -411,49 +497,72 @@ class RecordingTranscriberWidget(QWidget):
|
|||
self.recording_amplitude_listener.stop_recording()
|
||||
self.recording_amplitude_listener.deleteLater()
|
||||
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, self.transcription_options.language)
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_TASK, self.transcription_options.task)
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE, self.transcription_options.temperature)
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT,
|
||||
self.transcription_options.initial_prompt)
|
||||
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_MODEL, self.transcription_options.model)
|
||||
self.settings.set_value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE,
|
||||
self.transcription_options.language,
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_TASK, self.transcription_options.task
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE,
|
||||
self.transcription_options.temperature,
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT,
|
||||
self.transcription_options.initial_prompt,
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.RECORDING_TRANSCRIBER_MODEL, self.transcription_options.model
|
||||
)
|
||||
|
||||
return super().closeEvent(event)
|
||||
|
||||
|
||||
RECORD_ICON_PATH = get_asset_path('assets/mic_FILL0_wght700_GRAD0_opsz48.svg')
|
||||
EXPAND_ICON_PATH = get_asset_path('assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg')
|
||||
ADD_ICON_PATH = get_asset_path('assets/add_FILL0_wght700_GRAD0_opsz48.svg')
|
||||
TRASH_ICON_PATH = get_asset_path('assets/delete_FILL0_wght700_GRAD0_opsz48.svg')
|
||||
CANCEL_ICON_PATH = get_asset_path('assets/cancel_FILL0_wght700_GRAD0_opsz48.svg')
|
||||
RECORD_ICON_PATH = get_asset_path("assets/mic_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
EXPAND_ICON_PATH = get_asset_path("assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
ADD_ICON_PATH = get_asset_path("assets/add_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
TRASH_ICON_PATH = get_asset_path("assets/delete_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
CANCEL_ICON_PATH = get_asset_path("assets/cancel_FILL0_wght700_GRAD0_opsz48.svg")
|
||||
|
||||
|
||||
class MainWindowToolbar(ToolBar):
|
||||
new_transcription_action_triggered: pyqtSignal
|
||||
open_transcript_action_triggered: pyqtSignal
|
||||
clear_history_action_triggered: pyqtSignal
|
||||
ICON_LIGHT_THEME_BACKGROUND = '#555'
|
||||
ICON_DARK_THEME_BACKGROUND = '#AAA'
|
||||
ICON_LIGHT_THEME_BACKGROUND = "#555"
|
||||
ICON_DARK_THEME_BACKGROUND = "#AAA"
|
||||
|
||||
def __init__(self, shortcuts: Dict[str, str], parent: Optional[QWidget]):
|
||||
super().__init__(parent)
|
||||
|
||||
self.record_action = Action(Icon(RECORD_ICON_PATH, self), _('Record'), self)
|
||||
self.record_action = Action(Icon(RECORD_ICON_PATH, self), _("Record"), self)
|
||||
self.record_action.triggered.connect(self.on_record_action_triggered)
|
||||
|
||||
self.new_transcription_action = Action(Icon(ADD_ICON_PATH, self), _('New Transcription'), self)
|
||||
self.new_transcription_action_triggered = self.new_transcription_action.triggered
|
||||
self.new_transcription_action = Action(
|
||||
Icon(ADD_ICON_PATH, self), _("New Transcription"), self
|
||||
)
|
||||
self.new_transcription_action_triggered = (
|
||||
self.new_transcription_action.triggered
|
||||
)
|
||||
|
||||
self.open_transcript_action = Action(Icon(EXPAND_ICON_PATH, self),
|
||||
_('Open Transcript'), self)
|
||||
self.open_transcript_action = Action(
|
||||
Icon(EXPAND_ICON_PATH, self), _("Open Transcript"), self
|
||||
)
|
||||
self.open_transcript_action_triggered = self.open_transcript_action.triggered
|
||||
self.open_transcript_action.setDisabled(True)
|
||||
|
||||
self.stop_transcription_action = Action(Icon(CANCEL_ICON_PATH, self), _('Cancel Transcription'), self)
|
||||
self.stop_transcription_action_triggered = self.stop_transcription_action.triggered
|
||||
self.stop_transcription_action = Action(
|
||||
Icon(CANCEL_ICON_PATH, self), _("Cancel Transcription"), self
|
||||
)
|
||||
self.stop_transcription_action_triggered = (
|
||||
self.stop_transcription_action.triggered
|
||||
)
|
||||
self.stop_transcription_action.setDisabled(True)
|
||||
|
||||
self.clear_history_action = Action(Icon(TRASH_ICON_PATH, self), _('Clear History'), self)
|
||||
self.clear_history_action = Action(
|
||||
Icon(TRASH_ICON_PATH, self), _("Clear History"), self
|
||||
)
|
||||
self.clear_history_action_triggered = self.clear_history_action.triggered
|
||||
self.clear_history_action.setDisabled(True)
|
||||
|
||||
|
@ -461,21 +570,38 @@ class MainWindowToolbar(ToolBar):
|
|||
|
||||
self.addAction(self.record_action)
|
||||
self.addSeparator()
|
||||
self.addActions([self.new_transcription_action, self.open_transcript_action, self.stop_transcription_action,
|
||||
self.clear_history_action])
|
||||
self.addActions(
|
||||
[
|
||||
self.new_transcription_action,
|
||||
self.open_transcript_action,
|
||||
self.stop_transcription_action,
|
||||
self.clear_history_action,
|
||||
]
|
||||
)
|
||||
self.setMovable(False)
|
||||
self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
|
||||
|
||||
def set_shortcuts(self, shortcuts: Dict[str, str]):
|
||||
self.record_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_RECORD_WINDOW.name]))
|
||||
self.new_transcription_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name]))
|
||||
self.record_action.setShortcut(
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_RECORD_WINDOW.name])
|
||||
)
|
||||
self.new_transcription_action.setShortcut(
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name])
|
||||
)
|
||||
self.open_transcript_action.setShortcut(
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_TRANSCRIPT_EDITOR.name]))
|
||||
self.stop_transcription_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.STOP_TRANSCRIPTION.name]))
|
||||
self.clear_history_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.CLEAR_HISTORY.name]))
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_TRANSCRIPT_EDITOR.name])
|
||||
)
|
||||
self.stop_transcription_action.setShortcut(
|
||||
QKeySequence.fromString(shortcuts[Shortcut.STOP_TRANSCRIPTION.name])
|
||||
)
|
||||
self.clear_history_action.setShortcut(
|
||||
QKeySequence.fromString(shortcuts[Shortcut.CLEAR_HISTORY.name])
|
||||
)
|
||||
|
||||
def on_record_action_triggered(self):
|
||||
recording_transcriber_window = RecordingTranscriberWidget(self, flags=Qt.WindowType.Window)
|
||||
recording_transcriber_window = RecordingTranscriberWidget(
|
||||
self, flags=Qt.WindowType.Window
|
||||
)
|
||||
recording_transcriber_window.show()
|
||||
|
||||
def set_stop_transcription_action_enabled(self, enabled: bool):
|
||||
|
@ -490,7 +616,7 @@ class MainWindowToolbar(ToolBar):
|
|||
|
||||
class MainWindow(QMainWindow):
|
||||
table_widget: TranscriptionTasksTableWidget
|
||||
tasks: Dict[int, 'FileTranscriptionTask']
|
||||
tasks: Dict[int, "FileTranscriptionTask"]
|
||||
tasks_changed = pyqtSignal()
|
||||
openai_access_token: Optional[str]
|
||||
|
||||
|
@ -511,34 +637,49 @@ class MainWindow(QMainWindow):
|
|||
self.shortcuts = self.shortcut_settings.load()
|
||||
self.default_export_file_name = self.settings.value(
|
||||
Settings.Key.DEFAULT_EXPORT_FILE_NAME,
|
||||
'{{ input_file_name }} ({{ task }}d on {{ date_time }})')
|
||||
"{{ input_file_name }} ({{ task }}d on {{ date_time }})",
|
||||
)
|
||||
|
||||
self.tasks = {}
|
||||
self.tasks_changed.connect(self.on_tasks_changed)
|
||||
|
||||
self.toolbar = MainWindowToolbar(shortcuts=self.shortcuts, parent=self)
|
||||
self.toolbar.new_transcription_action_triggered.connect(self.on_new_transcription_action_triggered)
|
||||
self.toolbar.open_transcript_action_triggered.connect(self.open_transcript_viewer)
|
||||
self.toolbar.clear_history_action_triggered.connect(self.on_clear_history_action_triggered)
|
||||
self.toolbar.stop_transcription_action_triggered.connect(self.on_stop_transcription_action_triggered)
|
||||
self.toolbar.new_transcription_action_triggered.connect(
|
||||
self.on_new_transcription_action_triggered
|
||||
)
|
||||
self.toolbar.open_transcript_action_triggered.connect(
|
||||
self.open_transcript_viewer
|
||||
)
|
||||
self.toolbar.clear_history_action_triggered.connect(
|
||||
self.on_clear_history_action_triggered
|
||||
)
|
||||
self.toolbar.stop_transcription_action_triggered.connect(
|
||||
self.on_stop_transcription_action_triggered
|
||||
)
|
||||
self.addToolBar(self.toolbar)
|
||||
self.setUnifiedTitleAndToolBarOnMac(True)
|
||||
|
||||
self.menu_bar = MenuBar(shortcuts=self.shortcuts,
|
||||
default_export_file_name=self.default_export_file_name,
|
||||
parent=self)
|
||||
self.menu_bar = MenuBar(
|
||||
shortcuts=self.shortcuts,
|
||||
default_export_file_name=self.default_export_file_name,
|
||||
parent=self,
|
||||
)
|
||||
self.menu_bar.import_action_triggered.connect(
|
||||
self.on_new_transcription_action_triggered)
|
||||
self.on_new_transcription_action_triggered
|
||||
)
|
||||
self.menu_bar.shortcuts_changed.connect(self.on_shortcuts_changed)
|
||||
self.menu_bar.openai_api_key_changed.connect(self.on_openai_access_token_changed)
|
||||
self.menu_bar.default_export_file_name_changed.connect(self.default_export_file_name_changed)
|
||||
self.menu_bar.openai_api_key_changed.connect(
|
||||
self.on_openai_access_token_changed
|
||||
)
|
||||
self.menu_bar.default_export_file_name_changed.connect(
|
||||
self.default_export_file_name_changed
|
||||
)
|
||||
self.setMenuBar(self.menu_bar)
|
||||
|
||||
self.table_widget = TranscriptionTasksTableWidget(self)
|
||||
self.table_widget.doubleClicked.connect(self.on_table_double_clicked)
|
||||
self.table_widget.return_clicked.connect(self.open_transcript_viewer)
|
||||
self.table_widget.itemSelectionChanged.connect(
|
||||
self.on_table_selection_changed)
|
||||
self.table_widget.itemSelectionChanged.connect(self.on_table_selection_changed)
|
||||
|
||||
self.setCentralWidget(self.table_widget)
|
||||
|
||||
|
@ -548,8 +689,7 @@ class MainWindow(QMainWindow):
|
|||
self.transcriber_worker = FileTranscriberQueueWorker()
|
||||
self.transcriber_worker.moveToThread(self.transcriber_thread)
|
||||
|
||||
self.transcriber_worker.task_updated.connect(
|
||||
self.update_task_table_row)
|
||||
self.transcriber_worker.task_updated.connect(self.update_task_table_row)
|
||||
self.transcriber_worker.completed.connect(self.transcriber_thread.quit)
|
||||
|
||||
self.transcriber_thread.started.connect(self.transcriber_worker.run)
|
||||
|
@ -569,11 +709,14 @@ class MainWindow(QMainWindow):
|
|||
file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
|
||||
self.open_file_transcriber_widget(file_paths=file_paths)
|
||||
|
||||
def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str]):
|
||||
def on_file_transcriber_triggered(
|
||||
self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str]
|
||||
):
|
||||
transcription_options, file_transcription_options, model_path = options
|
||||
for file_path in file_transcription_options.file_paths:
|
||||
task = FileTranscriptionTask(
|
||||
file_path, transcription_options, file_transcription_options, model_path)
|
||||
file_path, transcription_options, file_transcription_options, model_path
|
||||
)
|
||||
self.add_task(task)
|
||||
|
||||
def load_task(self, task: FileTranscriptionTask):
|
||||
|
@ -586,8 +729,10 @@ class MainWindow(QMainWindow):
|
|||
|
||||
@staticmethod
|
||||
def task_completed_or_errored(task: FileTranscriptionTask):
|
||||
return task.status == FileTranscriptionTask.Status.COMPLETED or \
|
||||
task.status == FileTranscriptionTask.Status.FAILED
|
||||
return (
|
||||
task.status == FileTranscriptionTask.Status.COMPLETED
|
||||
or task.status == FileTranscriptionTask.Status.FAILED
|
||||
)
|
||||
|
||||
def on_clear_history_action_triggered(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
|
@ -595,10 +740,17 @@ class MainWindow(QMainWindow):
|
|||
return
|
||||
|
||||
reply = QMessageBox.question(
|
||||
self, _('Clear History'),
|
||||
_('Are you sure you want to delete the selected transcription(s)? This action cannot be undone.'))
|
||||
self,
|
||||
_("Clear History"),
|
||||
_(
|
||||
"Are you sure you want to delete the selected transcription(s)? This action cannot be undone."
|
||||
),
|
||||
)
|
||||
if reply == QMessageBox.StandardButton.Yes:
|
||||
task_ids = [TranscriptionTasksTableWidget.find_task_id(selected_row) for selected_row in selected_rows]
|
||||
task_ids = [
|
||||
TranscriptionTasksTableWidget.find_task_id(selected_row)
|
||||
for selected_row in selected_rows
|
||||
]
|
||||
for task_id in task_ids:
|
||||
self.table_widget.clear_task(task_id)
|
||||
self.tasks.pop(task_id)
|
||||
|
@ -617,20 +769,24 @@ class MainWindow(QMainWindow):
|
|||
|
||||
def on_new_transcription_action_triggered(self):
|
||||
(file_paths, __) = QFileDialog.getOpenFileNames(
|
||||
self, _('Select audio file'), '', SUPPORTED_OUTPUT_FORMATS)
|
||||
self, _("Select audio file"), "", SUPPORTED_OUTPUT_FORMATS
|
||||
)
|
||||
if len(file_paths) == 0:
|
||||
return
|
||||
|
||||
self.open_file_transcriber_widget(file_paths)
|
||||
|
||||
def open_file_transcriber_widget(self, file_paths: List[str]):
|
||||
file_transcriber_window = FileTranscriberWidget(file_paths=file_paths,
|
||||
default_output_file_name=self.default_export_file_name,
|
||||
parent=self,
|
||||
flags=Qt.WindowType.Window)
|
||||
file_transcriber_window.triggered.connect(
|
||||
self.on_file_transcriber_triggered)
|
||||
file_transcriber_window.openai_access_token_changed.connect(self.on_openai_access_token_changed)
|
||||
file_transcriber_window = FileTranscriberWidget(
|
||||
file_paths=file_paths,
|
||||
default_output_file_name=self.default_export_file_name,
|
||||
parent=self,
|
||||
flags=Qt.WindowType.Window,
|
||||
)
|
||||
file_transcriber_window.triggered.connect(self.on_file_transcriber_triggered)
|
||||
file_transcriber_window.openai_access_token_changed.connect(
|
||||
self.on_openai_access_token_changed
|
||||
)
|
||||
file_transcriber_window.show()
|
||||
|
||||
@staticmethod
|
||||
|
@ -639,7 +795,9 @@ class MainWindow(QMainWindow):
|
|||
|
||||
def default_export_file_name_changed(self, default_export_file_name: str):
|
||||
self.default_export_file_name = default_export_file_name
|
||||
self.settings.set_value(Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name)
|
||||
self.settings.set_value(
|
||||
Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name
|
||||
)
|
||||
|
||||
def open_transcript_viewer(self):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
|
@ -648,29 +806,49 @@ class MainWindow(QMainWindow):
|
|||
self.open_transcription_viewer(task_id)
|
||||
|
||||
def on_table_selection_changed(self):
|
||||
self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action())
|
||||
self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action())
|
||||
self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action())
|
||||
self.toolbar.set_open_transcript_action_enabled(
|
||||
self.should_enable_open_transcript_action()
|
||||
)
|
||||
self.toolbar.set_stop_transcription_action_enabled(
|
||||
self.should_enable_stop_transcription_action()
|
||||
)
|
||||
self.toolbar.set_clear_history_action_enabled(
|
||||
self.should_enable_clear_history_action()
|
||||
)
|
||||
|
||||
def should_enable_open_transcript_action(self):
|
||||
return self.selected_tasks_have_status([FileTranscriptionTask.Status.COMPLETED])
|
||||
|
||||
def should_enable_stop_transcription_action(self):
|
||||
return self.selected_tasks_have_status(
|
||||
[FileTranscriptionTask.Status.IN_PROGRESS, FileTranscriptionTask.Status.QUEUED])
|
||||
[
|
||||
FileTranscriptionTask.Status.IN_PROGRESS,
|
||||
FileTranscriptionTask.Status.QUEUED,
|
||||
]
|
||||
)
|
||||
|
||||
def should_enable_clear_history_action(self):
|
||||
return self.selected_tasks_have_status(
|
||||
[FileTranscriptionTask.Status.COMPLETED, FileTranscriptionTask.Status.FAILED,
|
||||
FileTranscriptionTask.Status.CANCELED])
|
||||
[
|
||||
FileTranscriptionTask.Status.COMPLETED,
|
||||
FileTranscriptionTask.Status.FAILED,
|
||||
FileTranscriptionTask.Status.CANCELED,
|
||||
]
|
||||
)
|
||||
|
||||
def selected_tasks_have_status(self, statuses: List[FileTranscriptionTask.Status]):
|
||||
selected_rows = self.table_widget.selectionModel().selectedRows()
|
||||
if len(selected_rows) == 0:
|
||||
return False
|
||||
return all(
|
||||
[self.tasks[TranscriptionTasksTableWidget.find_task_id(selected_row)].status in statuses for selected_row in
|
||||
selected_rows])
|
||||
[
|
||||
self.tasks[
|
||||
TranscriptionTasksTableWidget.find_task_id(selected_row)
|
||||
].status
|
||||
in statuses
|
||||
for selected_row in selected_rows
|
||||
]
|
||||
)
|
||||
|
||||
def on_table_double_clicked(self, index: QModelIndex):
|
||||
task_id = TranscriptionTasksTableWidget.find_task_id(index)
|
||||
|
@ -682,7 +860,8 @@ class MainWindow(QMainWindow):
|
|||
return
|
||||
|
||||
transcription_viewer_widget = TranscriptionViewerWidget(
|
||||
transcription_task=task, parent=self, flags=Qt.WindowType.Window)
|
||||
transcription_task=task, parent=self, flags=Qt.WindowType.Window
|
||||
)
|
||||
transcription_viewer_widget.task_changed.connect(self.on_tasks_changed)
|
||||
transcription_viewer_widget.show()
|
||||
|
||||
|
@ -692,8 +871,10 @@ class MainWindow(QMainWindow):
|
|||
def load_tasks_from_cache(self):
|
||||
tasks = self.tasks_cache.load()
|
||||
for task in tasks:
|
||||
if task.status == FileTranscriptionTask.Status.QUEUED or \
|
||||
task.status == FileTranscriptionTask.Status.IN_PROGRESS:
|
||||
if (
|
||||
task.status == FileTranscriptionTask.Status.QUEUED
|
||||
or task.status == FileTranscriptionTask.Status.IN_PROGRESS
|
||||
):
|
||||
task.status = None
|
||||
self.transcriber_worker.add_task(task)
|
||||
else:
|
||||
|
@ -703,9 +884,15 @@ class MainWindow(QMainWindow):
|
|||
self.tasks_cache.save(list(self.tasks.values()))
|
||||
|
||||
def on_tasks_changed(self):
|
||||
self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action())
|
||||
self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action())
|
||||
self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action())
|
||||
self.toolbar.set_open_transcript_action_enabled(
|
||||
self.should_enable_open_transcript_action()
|
||||
)
|
||||
self.toolbar.set_stop_transcription_action_enabled(
|
||||
self.should_enable_stop_transcription_action()
|
||||
)
|
||||
self.toolbar.set_clear_history_action_enabled(
|
||||
self.should_enable_clear_history_action()
|
||||
)
|
||||
self.save_tasks_to_cache()
|
||||
|
||||
def on_shortcuts_changed(self, shortcuts: dict):
|
||||
|
@ -723,7 +910,6 @@ class MainWindow(QMainWindow):
|
|||
super().closeEvent(event)
|
||||
|
||||
|
||||
|
||||
class Application(QApplication):
|
||||
window: MainWindow
|
||||
|
||||
|
|
|
@ -6,12 +6,12 @@ from PyQt6.QtCore import QLocale
|
|||
from buzz.assets import get_asset_path
|
||||
from buzz.settings.settings import APP_NAME
|
||||
|
||||
if 'LANG' not in os.environ:
|
||||
if "LANG" not in os.environ:
|
||||
language = str(QLocale().uiLanguages()[0]).replace("-", "_")
|
||||
os.environ['LANG'] = language
|
||||
os.environ["LANG"] = language
|
||||
|
||||
locale_dir = get_asset_path('locale')
|
||||
gettext.bindtextdomain('buzz', locale_dir)
|
||||
locale_dir = get_asset_path("locale")
|
||||
gettext.bindtextdomain("buzz", locale_dir)
|
||||
|
||||
translate = gettext.translation(APP_NAME, locale_dir, fallback=True)
|
||||
|
||||
|
|
|
@ -20,11 +20,11 @@ from tqdm.auto import tqdm
|
|||
|
||||
|
||||
class WhisperModelSize(str, enum.Enum):
|
||||
TINY = 'tiny'
|
||||
BASE = 'base'
|
||||
SMALL = 'small'
|
||||
MEDIUM = 'medium'
|
||||
LARGE = 'large'
|
||||
TINY = "tiny"
|
||||
BASE = "base"
|
||||
SMALL = "small"
|
||||
MEDIUM = "medium"
|
||||
LARGE = "large"
|
||||
|
||||
def to_faster_whisper_model_size(self) -> str:
|
||||
if self == WhisperModelSize.LARGE:
|
||||
|
@ -33,11 +33,11 @@ class WhisperModelSize(str, enum.Enum):
|
|||
|
||||
|
||||
class ModelType(enum.Enum):
|
||||
WHISPER = 'Whisper'
|
||||
WHISPER_CPP = 'Whisper.cpp'
|
||||
HUGGING_FACE = 'Hugging Face'
|
||||
FASTER_WHISPER = 'Faster Whisper'
|
||||
OPEN_AI_WHISPER_API = 'OpenAI Whisper API'
|
||||
WHISPER = "Whisper"
|
||||
WHISPER_CPP = "Whisper.cpp"
|
||||
HUGGING_FACE = "Hugging Face"
|
||||
FASTER_WHISPER = "Faster Whisper"
|
||||
OPEN_AI_WHISPER_API = "OpenAI Whisper API"
|
||||
|
||||
|
||||
@dataclass()
|
||||
|
@ -47,9 +47,10 @@ class TranscriptionModel:
|
|||
hugging_face_model_id: Optional[str] = None
|
||||
|
||||
def is_deletable(self):
|
||||
return ((self.model_type == ModelType.WHISPER or
|
||||
self.model_type == ModelType.WHISPER_CPP) and
|
||||
self.get_local_model_path() is not None)
|
||||
return (
|
||||
self.model_type == ModelType.WHISPER
|
||||
or self.model_type == ModelType.WHISPER_CPP
|
||||
) and self.get_local_model_path() is not None
|
||||
|
||||
def open_file_location(self):
|
||||
model_path = self.get_local_model_path()
|
||||
|
@ -84,18 +85,20 @@ class TranscriptionModel:
|
|||
|
||||
if self.model_type == ModelType.FASTER_WHISPER:
|
||||
try:
|
||||
return download_faster_whisper_model(size=self.whisper_model_size.value,
|
||||
local_files_only=True)
|
||||
return download_faster_whisper_model(
|
||||
size=self.whisper_model_size.value, local_files_only=True
|
||||
)
|
||||
except (ValueError, FileNotFoundError):
|
||||
return None
|
||||
|
||||
if self.model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
if self.model_type == ModelType.HUGGING_FACE:
|
||||
try:
|
||||
return huggingface_hub.snapshot_download(self.hugging_face_model_id,
|
||||
local_files_only=True)
|
||||
return huggingface_hub.snapshot_download(
|
||||
self.hugging_face_model_id, local_files_only=True
|
||||
)
|
||||
except (ValueError, FileNotFoundError):
|
||||
return None
|
||||
|
||||
|
@ -103,36 +106,38 @@ class TranscriptionModel:
|
|||
|
||||
|
||||
WHISPER_CPP_MODELS_SHA256 = {
|
||||
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
|
||||
'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
|
||||
'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
|
||||
'medium': '6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208',
|
||||
'large': '9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487'
|
||||
"tiny": "be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21",
|
||||
"base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe",
|
||||
"small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b",
|
||||
"medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208",
|
||||
"large": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
|
||||
}
|
||||
|
||||
|
||||
def get_hugging_face_file_url(author: str, repository_name: str, filename: str):
|
||||
return f'https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}'
|
||||
return f"https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}"
|
||||
|
||||
|
||||
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
|
||||
root_dir = user_cache_dir('Buzz')
|
||||
return os.path.join(root_dir, f'ggml-model-whisper-{size.value}.bin')
|
||||
root_dir = user_cache_dir("Buzz")
|
||||
return os.path.join(root_dir, f"ggml-model-whisper-{size.value}.bin")
|
||||
|
||||
|
||||
def get_whisper_file_path(size: WhisperModelSize) -> str:
|
||||
root_dir = os.getenv("XDG_CACHE_HOME", os.path.join(
|
||||
os.path.expanduser("~"), ".cache", "whisper"))
|
||||
root_dir = os.getenv(
|
||||
"XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper")
|
||||
)
|
||||
url = whisper._MODELS[size.value]
|
||||
return os.path.join(root_dir, os.path.basename(url))
|
||||
|
||||
|
||||
def download_faster_whisper_model(size: str, local_files_only=False,
|
||||
tqdm_class: Optional[tqdm] = None):
|
||||
def download_faster_whisper_model(
|
||||
size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None
|
||||
):
|
||||
if size not in faster_whisper.utils._MODELS:
|
||||
raise ValueError(
|
||||
"Invalid model size '%s', expected one of: %s" % (
|
||||
size, ", ".join(faster_whisper.utils._MODELS))
|
||||
"Invalid model size '%s', expected one of: %s"
|
||||
% (size, ", ".join(faster_whisper.utils._MODELS))
|
||||
)
|
||||
|
||||
repo_id = "guillaumekln/faster-whisper-%s" % size
|
||||
|
@ -144,9 +149,12 @@ def download_faster_whisper_model(size: str, local_files_only=False,
|
|||
"vocabulary.txt",
|
||||
]
|
||||
|
||||
return huggingface_hub.snapshot_download(repo_id, allow_patterns=allow_patterns,
|
||||
local_files_only=local_files_only,
|
||||
tqdm_class=tqdm_class)
|
||||
return huggingface_hub.snapshot_download(
|
||||
repo_id,
|
||||
allow_patterns=allow_patterns,
|
||||
local_files_only=local_files_only,
|
||||
tqdm_class=tqdm_class,
|
||||
)
|
||||
|
||||
|
||||
class ModelDownloader(QRunnable):
|
||||
|
@ -165,22 +173,24 @@ class ModelDownloader(QRunnable):
|
|||
def run(self) -> None:
|
||||
if self.model.model_type == ModelType.WHISPER_CPP:
|
||||
model_name = self.model.whisper_model_size.value
|
||||
url = get_hugging_face_file_url(author='ggerganov',
|
||||
repository_name='whisper.cpp',
|
||||
filename=f'ggml-{model_name}.bin')
|
||||
file_path = get_whisper_cpp_file_path(
|
||||
size=self.model.whisper_model_size)
|
||||
url = get_hugging_face_file_url(
|
||||
author="ggerganov",
|
||||
repository_name="whisper.cpp",
|
||||
filename=f"ggml-{model_name}.bin",
|
||||
)
|
||||
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
|
||||
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
|
||||
return self.download_model_to_path(url=url, file_path=file_path,
|
||||
expected_sha256=expected_sha256)
|
||||
return self.download_model_to_path(
|
||||
url=url, file_path=file_path, expected_sha256=expected_sha256
|
||||
)
|
||||
|
||||
if self.model.model_type == ModelType.WHISPER:
|
||||
url = whisper._MODELS[self.model.whisper_model_size.value]
|
||||
file_path = get_whisper_file_path(
|
||||
size=self.model.whisper_model_size)
|
||||
expected_sha256 = url.split('/')[-2]
|
||||
return self.download_model_to_path(url=url, file_path=file_path,
|
||||
expected_sha256=expected_sha256)
|
||||
file_path = get_whisper_file_path(size=self.model.whisper_model_size)
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
return self.download_model_to_path(
|
||||
url=url, file_path=file_path, expected_sha256=expected_sha256
|
||||
)
|
||||
|
||||
progress = self.signals.progress
|
||||
|
||||
|
@ -197,44 +207,47 @@ class ModelDownloader(QRunnable):
|
|||
if self.model.model_type == ModelType.FASTER_WHISPER:
|
||||
model_path = download_faster_whisper_model(
|
||||
size=self.model.whisper_model_size.to_faster_whisper_model_size(),
|
||||
tqdm_class=_tqdm)
|
||||
tqdm_class=_tqdm,
|
||||
)
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
||||
if self.model.model_type == ModelType.HUGGING_FACE:
|
||||
model_path = huggingface_hub.snapshot_download(
|
||||
self.model.hugging_face_model_id, tqdm_class=_tqdm)
|
||||
self.model.hugging_face_model_id, tqdm_class=_tqdm
|
||||
)
|
||||
self.signals.finished.emit(model_path)
|
||||
return
|
||||
|
||||
if self.model.model_type == ModelType.OPEN_AI_WHISPER_API:
|
||||
self.signals.finished.emit('')
|
||||
self.signals.finished.emit("")
|
||||
return
|
||||
|
||||
raise Exception("Invalid model type: " + self.model.model_type.value)
|
||||
|
||||
def download_model_to_path(self, url: str, file_path: str,
|
||||
expected_sha256: Optional[str]):
|
||||
def download_model_to_path(
|
||||
self, url: str, file_path: str, expected_sha256: Optional[str]
|
||||
):
|
||||
try:
|
||||
downloaded = self.download_model(url, file_path, expected_sha256)
|
||||
if downloaded:
|
||||
self.signals.finished.emit(file_path)
|
||||
except requests.RequestException:
|
||||
self.signals.error.emit('A connection error occurred')
|
||||
logging.exception('')
|
||||
self.signals.error.emit("A connection error occurred")
|
||||
logging.exception("")
|
||||
except Exception as exc:
|
||||
self.signals.error.emit(str(exc))
|
||||
logging.exception(exc)
|
||||
|
||||
def download_model(self, url: str, file_path: str,
|
||||
expected_sha256: Optional[str]) -> bool:
|
||||
logging.debug(f'Downloading model from {url} to {file_path}')
|
||||
def download_model(
|
||||
self, url: str, file_path: str, expected_sha256: Optional[str]
|
||||
) -> bool:
|
||||
logging.debug(f"Downloading model from {url} to {file_path}")
|
||||
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
if os.path.exists(file_path) and not os.path.isfile(file_path):
|
||||
raise RuntimeError(
|
||||
f"{file_path} exists and is not a regular file")
|
||||
raise RuntimeError(f"{file_path} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
if expected_sha256 is None:
|
||||
|
@ -246,17 +259,19 @@ class ModelDownloader(QRunnable):
|
|||
return True
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
tmp_file = tempfile.mktemp()
|
||||
logging.debug('Downloading to temporary file = %s', tmp_file)
|
||||
logging.debug("Downloading to temporary file = %s", tmp_file)
|
||||
|
||||
# Downloads the model using the requests module instead of urllib to
|
||||
# use the certs from certifi when the app is running in frozen mode
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(tmp_file,
|
||||
'wb') as output:
|
||||
with requests.get(url, stream=True, timeout=15) as source, open(
|
||||
tmp_file, "wb"
|
||||
) as output:
|
||||
source.raise_for_status()
|
||||
total_size = float(source.headers.get('Content-Length', 0))
|
||||
total_size = float(source.headers.get("Content-Length", 0))
|
||||
current = 0.0
|
||||
self.signals.progress.emit((current, total_size))
|
||||
for chunk in source.iter_content(chunk_size=8192):
|
||||
|
@ -271,13 +286,14 @@ class ModelDownloader(QRunnable):
|
|||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
|
||||
"model.")
|
||||
"model."
|
||||
)
|
||||
|
||||
logging.debug('Downloaded model')
|
||||
logging.debug("Downloaded model")
|
||||
|
||||
# https://github.com/chidiwilliams/buzz/issues/454
|
||||
shutil.move(tmp_file, file_path)
|
||||
logging.debug('Moved file from %s to %s', tmp_file, file_path)
|
||||
logging.debug("Moved file from %s to %s", tmp_file, file_path)
|
||||
return True
|
||||
|
||||
def cancel(self):
|
||||
|
|
|
@ -7,4 +7,4 @@ def file_path_as_title(file_path: str):
|
|||
|
||||
|
||||
def file_paths_as_title(file_paths: List[str]):
|
||||
return ', '.join([file_path_as_title(path) for path in file_paths])
|
||||
return ", ".join([file_path_as_title(path) for path in file_paths])
|
||||
|
|
|
@ -10,19 +10,25 @@ class RecordingAmplitudeListener(QObject):
|
|||
stream: Optional[sounddevice.InputStream] = None
|
||||
amplitude_changed = pyqtSignal(float)
|
||||
|
||||
def __init__(self, input_device_index: Optional[int] = None,
|
||||
parent: Optional[QObject] = None,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
input_device_index: Optional[int] = None,
|
||||
parent: Optional[QObject] = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
self.input_device_index = input_device_index
|
||||
|
||||
def start_recording(self):
|
||||
try:
|
||||
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
|
||||
channels=1, callback=self.stream_callback)
|
||||
self.stream = sounddevice.InputStream(
|
||||
device=self.input_device_index,
|
||||
dtype="float32",
|
||||
channels=1,
|
||||
callback=self.stream_callback,
|
||||
)
|
||||
self.stream.start()
|
||||
except sounddevice.PortAudioError:
|
||||
logging.exception('')
|
||||
logging.exception("")
|
||||
|
||||
def stop_recording(self):
|
||||
if self.stream is not None:
|
||||
|
@ -31,5 +37,5 @@ class RecordingAmplitudeListener(QObject):
|
|||
|
||||
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
|
||||
chunk = in_data.ravel()
|
||||
amplitude = np.sqrt(np.mean(chunk ** 2)) # root-mean-square
|
||||
amplitude = np.sqrt(np.mean(chunk**2)) # root-mean-square
|
||||
self.amplitude_changed.emit(amplitude)
|
||||
|
|
|
@ -22,9 +22,14 @@ class RecordingTranscriber(QObject):
|
|||
is_running = False
|
||||
MAX_QUEUE_SIZE = 10
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions,
|
||||
input_device_index: Optional[int], sample_rate: int, model_path: str,
|
||||
parent: Optional[QObject] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
transcription_options: TranscriptionOptions,
|
||||
input_device_index: Optional[int],
|
||||
sample_rate: int,
|
||||
model_path: str,
|
||||
parent: Optional[QObject] = None,
|
||||
) -> None:
|
||||
super().__init__(parent)
|
||||
self.transcription_options = transcription_options
|
||||
self.current_stream = None
|
||||
|
@ -49,60 +54,91 @@ class RecordingTranscriber(QObject):
|
|||
|
||||
initial_prompt = self.transcription_options.initial_prompt
|
||||
|
||||
logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s',
|
||||
self.transcription_options, model_path, self.sample_rate, self.input_device_index)
|
||||
logging.debug(
|
||||
"Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s",
|
||||
self.transcription_options,
|
||||
model_path,
|
||||
self.sample_rate,
|
||||
self.input_device_index,
|
||||
)
|
||||
|
||||
self.is_running = True
|
||||
try:
|
||||
with sounddevice.InputStream(samplerate=self.sample_rate,
|
||||
device=self.input_device_index, dtype="float32",
|
||||
channels=1, callback=self.stream_callback):
|
||||
with sounddevice.InputStream(
|
||||
samplerate=self.sample_rate,
|
||||
device=self.input_device_index,
|
||||
dtype="float32",
|
||||
channels=1,
|
||||
callback=self.stream_callback,
|
||||
):
|
||||
while self.is_running:
|
||||
self.mutex.acquire()
|
||||
if self.queue.size >= self.n_batch_samples:
|
||||
samples = self.queue[:self.n_batch_samples]
|
||||
self.queue = self.queue[self.n_batch_samples:]
|
||||
samples = self.queue[: self.n_batch_samples]
|
||||
self.queue = self.queue[self.n_batch_samples :]
|
||||
self.mutex.release()
|
||||
|
||||
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
|
||||
samples.size, self.queue.size, self.amplitude(samples))
|
||||
logging.debug(
|
||||
"Processing next frame, sample size = %s, queue size = %s, amplitude = %s",
|
||||
samples.size,
|
||||
self.queue.size,
|
||||
self.amplitude(samples),
|
||||
)
|
||||
time_started = datetime.datetime.now()
|
||||
|
||||
if self.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
if (
|
||||
self.transcription_options.model.model_type
|
||||
== ModelType.WHISPER
|
||||
):
|
||||
assert isinstance(model, whisper.Whisper)
|
||||
result = model.transcribe(
|
||||
audio=samples, language=self.transcription_options.language,
|
||||
audio=samples,
|
||||
language=self.transcription_options.language,
|
||||
task=self.transcription_options.task.value,
|
||||
initial_prompt=initial_prompt,
|
||||
temperature=self.transcription_options.temperature)
|
||||
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
|
||||
temperature=self.transcription_options.temperature,
|
||||
)
|
||||
elif (
|
||||
self.transcription_options.model.model_type
|
||||
== ModelType.WHISPER_CPP
|
||||
):
|
||||
assert isinstance(model, WhisperCpp)
|
||||
result = model.transcribe(
|
||||
audio=samples,
|
||||
params=whisper_cpp_params(
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None else 'en',
|
||||
task=self.transcription_options.task.value, word_level_timings=False))
|
||||
if self.transcription_options.language is not None
|
||||
else "en",
|
||||
task=self.transcription_options.task.value,
|
||||
word_level_timings=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
assert isinstance(model, TransformersWhisper)
|
||||
result = model.transcribe(audio=samples,
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None else 'en',
|
||||
task=self.transcription_options.task.value)
|
||||
result = model.transcribe(
|
||||
audio=samples,
|
||||
language=self.transcription_options.language
|
||||
if self.transcription_options.language is not None
|
||||
else "en",
|
||||
task=self.transcription_options.task.value,
|
||||
)
|
||||
|
||||
next_text: str = result.get('text')
|
||||
next_text: str = result.get("text")
|
||||
|
||||
# Update initial prompt between successive recording chunks
|
||||
initial_prompt += next_text
|
||||
|
||||
logging.debug('Received next result, length = %s, time taken = %s',
|
||||
len(next_text), datetime.datetime.now() - time_started)
|
||||
logging.debug(
|
||||
"Received next result, length = %s, time taken = %s",
|
||||
len(next_text),
|
||||
datetime.datetime.now() - time_started,
|
||||
)
|
||||
self.transcription.emit(next_text)
|
||||
else:
|
||||
self.mutex.release()
|
||||
except PortAudioError as exc:
|
||||
self.error.emit(str(exc))
|
||||
logging.exception('')
|
||||
logging.exception("")
|
||||
return
|
||||
|
||||
self.finished.emit()
|
||||
|
@ -116,12 +152,13 @@ class RecordingTranscriber(QObject):
|
|||
whisper_sample_rate = whisper.audio.SAMPLE_RATE
|
||||
try:
|
||||
sounddevice.check_input_settings(
|
||||
device=device_id, samplerate=whisper_sample_rate)
|
||||
device=device_id, samplerate=whisper_sample_rate
|
||||
)
|
||||
return whisper_sample_rate
|
||||
except PortAudioError:
|
||||
device_info = sounddevice.query_devices(device=device_id)
|
||||
if isinstance(device_info, dict):
|
||||
return int(device_info.get('default_samplerate', whisper_sample_rate))
|
||||
return int(device_info.get("default_samplerate", whisper_sample_rate))
|
||||
return whisper_sample_rate
|
||||
|
||||
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
|
||||
|
|
|
@ -3,7 +3,7 @@ import typing
|
|||
|
||||
from PyQt6.QtCore import QSettings
|
||||
|
||||
APP_NAME = 'Buzz'
|
||||
APP_NAME = "Buzz"
|
||||
|
||||
|
||||
class Settings:
|
||||
|
@ -11,32 +11,38 @@ class Settings:
|
|||
self.settings = QSettings(APP_NAME)
|
||||
|
||||
class Key(enum.Enum):
|
||||
RECORDING_TRANSCRIBER_TASK = 'recording-transcriber/task'
|
||||
RECORDING_TRANSCRIBER_MODEL = 'recording-transcriber/model'
|
||||
RECORDING_TRANSCRIBER_LANGUAGE = 'recording-transcriber/language'
|
||||
RECORDING_TRANSCRIBER_TEMPERATURE = 'recording-transcriber/temperature'
|
||||
RECORDING_TRANSCRIBER_INITIAL_PROMPT = 'recording-transcriber/initial-prompt'
|
||||
RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task"
|
||||
RECORDING_TRANSCRIBER_MODEL = "recording-transcriber/model"
|
||||
RECORDING_TRANSCRIBER_LANGUAGE = "recording-transcriber/language"
|
||||
RECORDING_TRANSCRIBER_TEMPERATURE = "recording-transcriber/temperature"
|
||||
RECORDING_TRANSCRIBER_INITIAL_PROMPT = "recording-transcriber/initial-prompt"
|
||||
|
||||
FILE_TRANSCRIBER_TASK = 'file-transcriber/task'
|
||||
FILE_TRANSCRIBER_MODEL = 'file-transcriber/model'
|
||||
FILE_TRANSCRIBER_LANGUAGE = 'file-transcriber/language'
|
||||
FILE_TRANSCRIBER_TEMPERATURE = 'file-transcriber/temperature'
|
||||
FILE_TRANSCRIBER_INITIAL_PROMPT = 'file-transcriber/initial-prompt'
|
||||
FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = 'file-transcriber/word-level-timings'
|
||||
FILE_TRANSCRIBER_EXPORT_FORMATS = 'file-transcriber/export-formats'
|
||||
FILE_TRANSCRIBER_TASK = "file-transcriber/task"
|
||||
FILE_TRANSCRIBER_MODEL = "file-transcriber/model"
|
||||
FILE_TRANSCRIBER_LANGUAGE = "file-transcriber/language"
|
||||
FILE_TRANSCRIBER_TEMPERATURE = "file-transcriber/temperature"
|
||||
FILE_TRANSCRIBER_INITIAL_PROMPT = "file-transcriber/initial-prompt"
|
||||
FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = "file-transcriber/word-level-timings"
|
||||
FILE_TRANSCRIBER_EXPORT_FORMATS = "file-transcriber/export-formats"
|
||||
|
||||
DEFAULT_EXPORT_FILE_NAME = 'transcriber/default-export-file-name'
|
||||
DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name"
|
||||
|
||||
SHORTCUTS = 'shortcuts'
|
||||
SHORTCUTS = "shortcuts"
|
||||
|
||||
def set_value(self, key: Key, value: typing.Any) -> None:
|
||||
self.settings.setValue(key.value, value)
|
||||
|
||||
def value(self, key: Key, default_value: typing.Any,
|
||||
value_type: typing.Optional[type] = None) -> typing.Any:
|
||||
return self.settings.value(key.value, default_value,
|
||||
value_type if value_type is not None else type(
|
||||
default_value))
|
||||
def value(
|
||||
self,
|
||||
key: Key,
|
||||
default_value: typing.Any,
|
||||
value_type: typing.Optional[type] = None,
|
||||
) -> typing.Any:
|
||||
return self.settings.value(
|
||||
key.value,
|
||||
default_value,
|
||||
value_type if value_type is not None else type(default_value),
|
||||
)
|
||||
|
||||
def clear(self):
|
||||
self.settings.clear()
|
||||
|
|
|
@ -13,13 +13,13 @@ class Shortcut(str, enum.Enum):
|
|||
obj.description = description
|
||||
return obj
|
||||
|
||||
OPEN_RECORD_WINDOW = ('Ctrl+R', "Open Record Window")
|
||||
OPEN_IMPORT_WINDOW = ('Ctrl+O', "Import File")
|
||||
OPEN_PREFERENCES_WINDOW = ('Ctrl+,', 'Open Preferences Window')
|
||||
OPEN_RECORD_WINDOW = ("Ctrl+R", "Open Record Window")
|
||||
OPEN_IMPORT_WINDOW = ("Ctrl+O", "Import File")
|
||||
OPEN_PREFERENCES_WINDOW = ("Ctrl+,", "Open Preferences Window")
|
||||
|
||||
OPEN_TRANSCRIPT_EDITOR = ('Ctrl+E', "Open Transcript Viewer")
|
||||
CLEAR_HISTORY = ('Ctrl+S', "Clear History")
|
||||
STOP_TRANSCRIPTION = ('Ctrl+X', "Cancel Transcription")
|
||||
OPEN_TRANSCRIPT_EDITOR = ("Ctrl+E", "Open Transcript Viewer")
|
||||
CLEAR_HISTORY = ("Ctrl+S", "Clear History")
|
||||
STOP_TRANSCRIPTION = ("Ctrl+X", "Cancel Transcription")
|
||||
|
||||
@staticmethod
|
||||
def get_default_shortcuts() -> typing.Dict[str, str]:
|
||||
|
|
|
@ -10,7 +10,9 @@ class ShortcutSettings:
|
|||
|
||||
def load(self) -> typing.Dict[str, str]:
|
||||
shortcuts = Shortcut.get_default_shortcuts()
|
||||
custom_shortcuts: typing.Dict[str, str] = self.settings.value(Settings.Key.SHORTCUTS, {})
|
||||
custom_shortcuts: typing.Dict[str, str] = self.settings.value(
|
||||
Settings.Key.SHORTCUTS, {}
|
||||
)
|
||||
for shortcut_name in custom_shortcuts:
|
||||
shortcuts[shortcut_name] = custom_shortcuts[shortcut_name]
|
||||
return shortcuts
|
||||
|
|
|
@ -9,20 +9,20 @@ from buzz.settings.settings import APP_NAME
|
|||
|
||||
class KeyringStore:
|
||||
class Key(enum.Enum):
|
||||
OPENAI_API_KEY = 'OpenAI API key'
|
||||
OPENAI_API_KEY = "OpenAI API key"
|
||||
|
||||
def get_password(self, key: Key) -> str:
|
||||
try:
|
||||
password = keyring.get_password(APP_NAME, username=key.value)
|
||||
if password is None:
|
||||
return ''
|
||||
return ""
|
||||
return password
|
||||
except (KeyringLocked, KeyringError) as exc:
|
||||
logging.error('Unable to read from keyring: %s', exc)
|
||||
return ''
|
||||
logging.error("Unable to read from keyring: %s", exc)
|
||||
return ""
|
||||
|
||||
def set_password(self, username: Key, password: str) -> None:
|
||||
try:
|
||||
keyring.set_password(APP_NAME, username.value, password)
|
||||
except (KeyringLocked, PasswordSetError) as exc:
|
||||
logging.error('Unable to write to keyring: %s', exc)
|
||||
logging.error("Unable to write to keyring: %s", exc)
|
||||
|
|
|
@ -38,7 +38,7 @@ try:
|
|||
|
||||
LOADED_WHISPER_DLL = True
|
||||
except ImportError:
|
||||
logging.exception('')
|
||||
logging.exception("")
|
||||
|
||||
DEFAULT_WHISPER_TEMPERATURE = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
|
||||
|
||||
|
@ -65,27 +65,28 @@ class TranscriptionOptions:
|
|||
model: TranscriptionModel = field(default_factory=TranscriptionModel)
|
||||
word_level_timings: bool = False
|
||||
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
|
||||
initial_prompt: str = ''
|
||||
openai_access_token: str = field(default='',
|
||||
metadata=config(exclude=Exclude.ALWAYS))
|
||||
initial_prompt: str = ""
|
||||
openai_access_token: str = field(
|
||||
default="", metadata=config(exclude=Exclude.ALWAYS)
|
||||
)
|
||||
|
||||
|
||||
@dataclass()
|
||||
class FileTranscriptionOptions:
|
||||
file_paths: List[str]
|
||||
output_formats: Set['OutputFormat'] = field(default_factory=set)
|
||||
default_output_file_name: str = ''
|
||||
output_formats: Set["OutputFormat"] = field(default_factory=set)
|
||||
default_output_file_name: str = ""
|
||||
|
||||
|
||||
@dataclass_json
|
||||
@dataclass
|
||||
class FileTranscriptionTask:
|
||||
class Status(enum.Enum):
|
||||
QUEUED = 'queued'
|
||||
IN_PROGRESS = 'in_progress'
|
||||
COMPLETED = 'completed'
|
||||
FAILED = 'failed'
|
||||
CANCELED = 'canceled'
|
||||
QUEUED = "queued"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELED = "canceled"
|
||||
|
||||
file_path: str
|
||||
transcription_options: TranscriptionOptions
|
||||
|
@ -102,9 +103,9 @@ class FileTranscriptionTask:
|
|||
|
||||
|
||||
class OutputFormat(enum.Enum):
|
||||
TXT = 'txt'
|
||||
SRT = 'srt'
|
||||
VTT = 'vtt'
|
||||
TXT = "txt"
|
||||
SRT = "srt"
|
||||
VTT = "vtt"
|
||||
|
||||
|
||||
class FileTranscriber(QObject):
|
||||
|
@ -113,8 +114,7 @@ class FileTranscriber(QObject):
|
|||
completed = pyqtSignal(list) # List[Segment]
|
||||
error = pyqtSignal(Exception)
|
||||
|
||||
def __init__(self, task: FileTranscriptionTask,
|
||||
parent: Optional['QObject'] = None):
|
||||
def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None):
|
||||
super().__init__(parent)
|
||||
self.transcription_task = task
|
||||
|
||||
|
@ -128,12 +128,16 @@ class FileTranscriber(QObject):
|
|||
|
||||
self.completed.emit(segments)
|
||||
|
||||
for output_format in self.transcription_task.file_transcription_options.output_formats:
|
||||
default_path = get_default_output_file_path(task=self.transcription_task,
|
||||
output_format=output_format)
|
||||
for (
|
||||
output_format
|
||||
) in self.transcription_task.file_transcription_options.output_formats:
|
||||
default_path = get_default_output_file_path(
|
||||
task=self.transcription_task, output_format=output_format
|
||||
)
|
||||
|
||||
write_output(path=default_path, segments=segments,
|
||||
output_format=output_format)
|
||||
write_output(
|
||||
path=default_path, segments=segments, output_format=output_format
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def transcribe(self) -> List[Segment]:
|
||||
|
@ -150,13 +154,14 @@ class Stopped(Exception):
|
|||
|
||||
class WhisperCppFileTranscriber(FileTranscriber):
|
||||
duration_audio_ms = sys.maxsize # max int
|
||||
state: 'WhisperCppFileTranscriber.State'
|
||||
state: "WhisperCppFileTranscriber.State"
|
||||
|
||||
class State:
|
||||
running = True
|
||||
|
||||
def __init__(self, task: FileTranscriptionTask,
|
||||
parent: Optional['QObject'] = None) -> None:
|
||||
def __init__(
|
||||
self, task: FileTranscriptionTask, parent: Optional["QObject"] = None
|
||||
) -> None:
|
||||
super().__init__(task, parent)
|
||||
|
||||
self.file_path = task.file_path
|
||||
|
@ -171,24 +176,33 @@ class WhisperCppFileTranscriber(FileTranscriber):
|
|||
model_path = self.model_path
|
||||
|
||||
logging.debug(
|
||||
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, '
|
||||
'word level timings = %s',
|
||||
self.file_path, self.language, self.task, model_path,
|
||||
self.word_level_timings)
|
||||
"Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, "
|
||||
"word level timings = %s",
|
||||
self.file_path,
|
||||
self.language,
|
||||
self.task,
|
||||
model_path,
|
||||
self.word_level_timings,
|
||||
)
|
||||
|
||||
audio = whisper.audio.load_audio(self.file_path)
|
||||
self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE
|
||||
|
||||
whisper_params = whisper_cpp_params(
|
||||
language=self.language if self.language is not None else '', task=self.task,
|
||||
word_level_timings=self.word_level_timings)
|
||||
language=self.language if self.language is not None else "",
|
||||
task=self.task,
|
||||
word_level_timings=self.word_level_timings,
|
||||
)
|
||||
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(
|
||||
id(self.state))
|
||||
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(
|
||||
self.encoder_begin_callback)
|
||||
id(self.state)
|
||||
)
|
||||
whisper_params.encoder_begin_callback = (
|
||||
whisper_cpp.whisper_encoder_begin_callback(self.encoder_begin_callback)
|
||||
)
|
||||
whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state))
|
||||
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(
|
||||
self.new_segment_callback)
|
||||
self.new_segment_callback
|
||||
)
|
||||
|
||||
model = WhisperCpp(model=model_path)
|
||||
result = model.transcribe(audio=self.file_path, params=whisper_params)
|
||||
|
@ -197,7 +211,7 @@ class WhisperCppFileTranscriber(FileTranscriber):
|
|||
raise Stopped
|
||||
|
||||
self.state.running = False
|
||||
return result['segments']
|
||||
return result["segments"]
|
||||
|
||||
def new_segment_callback(self, ctx, _state, _n_new, user_data):
|
||||
n_segments = whisper_cpp.whisper_full_n_segments(ctx)
|
||||
|
@ -205,15 +219,17 @@ class WhisperCppFileTranscriber(FileTranscriber):
|
|||
# t1 seems to sometimes be larger than the duration when the
|
||||
# audio ends in silence. Trim to fix the displayed progress.
|
||||
progress = min(t1 * 10, self.duration_audio_ms)
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
|
||||
ctypes.py_object).value
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(
|
||||
user_data, ctypes.py_object
|
||||
).value
|
||||
if state.running:
|
||||
self.progress.emit((progress, self.duration_audio_ms))
|
||||
|
||||
@staticmethod
|
||||
def encoder_begin_callback(_ctx, _state, user_data):
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
|
||||
ctypes.py_object).value
|
||||
state: WhisperCppFileTranscriber.State = ctypes.cast(
|
||||
user_data, ctypes.py_object
|
||||
).value
|
||||
return state.running == 1
|
||||
|
||||
def stop(self):
|
||||
|
@ -221,18 +237,19 @@ class WhisperCppFileTranscriber(FileTranscriber):
|
|||
|
||||
|
||||
class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
||||
def __init__(self, task: FileTranscriptionTask, parent: Optional['QObject'] = None):
|
||||
def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None):
|
||||
super().__init__(task=task, parent=parent)
|
||||
self.file_path = task.file_path
|
||||
self.task = task.transcription_options.task
|
||||
|
||||
def transcribe(self) -> List[Segment]:
|
||||
logging.debug(
|
||||
'Starting OpenAI Whisper API file transcription, file path = %s, task = %s',
|
||||
"Starting OpenAI Whisper API file transcription, file path = %s, task = %s",
|
||||
self.file_path,
|
||||
self.task)
|
||||
self.task,
|
||||
)
|
||||
|
||||
wav_file = tempfile.mktemp() + '.wav'
|
||||
wav_file = tempfile.mktemp() + ".wav"
|
||||
(
|
||||
ffmpeg.input(self.file_path)
|
||||
.output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
|
||||
|
@ -241,22 +258,30 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
|
||||
# TODO: Check if file size is more than 25MB (2.5 minutes), then chunk
|
||||
audio_file = open(wav_file, "rb")
|
||||
openai.api_key = self.transcription_task.transcription_options.openai_access_token
|
||||
openai.api_key = (
|
||||
self.transcription_task.transcription_options.openai_access_token
|
||||
)
|
||||
language = self.transcription_task.transcription_options.language
|
||||
response_format = "verbose_json"
|
||||
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
|
||||
transcript = openai.Audio.translate("whisper-1", audio_file,
|
||||
response_format=response_format,
|
||||
language=language)
|
||||
transcript = openai.Audio.translate(
|
||||
"whisper-1",
|
||||
audio_file,
|
||||
response_format=response_format,
|
||||
language=language,
|
||||
)
|
||||
else:
|
||||
transcript = openai.Audio.transcribe("whisper-1", audio_file,
|
||||
response_format=response_format,
|
||||
language=language)
|
||||
transcript = openai.Audio.transcribe(
|
||||
"whisper-1",
|
||||
audio_file,
|
||||
response_format=response_format,
|
||||
language=language,
|
||||
)
|
||||
|
||||
segments = [
|
||||
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for
|
||||
segment in
|
||||
transcript["segments"]]
|
||||
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"])
|
||||
for segment in transcript["segments"]
|
||||
]
|
||||
return segments
|
||||
|
||||
def stop(self):
|
||||
|
@ -265,15 +290,16 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
|
|||
|
||||
class WhisperFileTranscriber(FileTranscriber):
|
||||
"""WhisperFileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file
|
||||
using the default program for opening txt files. """
|
||||
using the default program for opening txt files."""
|
||||
|
||||
current_process: multiprocessing.Process
|
||||
running = False
|
||||
read_line_thread: Optional[Thread] = None
|
||||
READ_LINE_THREAD_STOP_TOKEN = '--STOP--'
|
||||
READ_LINE_THREAD_STOP_TOKEN = "--STOP--"
|
||||
|
||||
def __init__(self, task: FileTranscriptionTask,
|
||||
parent: Optional['QObject'] = None) -> None:
|
||||
def __init__(
|
||||
self, task: FileTranscriptionTask, parent: Optional["QObject"] = None
|
||||
) -> None:
|
||||
super().__init__(task, parent)
|
||||
self.segments = []
|
||||
self.started_process = False
|
||||
|
@ -282,19 +308,19 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
def transcribe(self) -> List[Segment]:
|
||||
time_started = datetime.datetime.now()
|
||||
logging.debug(
|
||||
'Starting whisper file transcription, task = %s', self.transcription_task)
|
||||
"Starting whisper file transcription, task = %s", self.transcription_task
|
||||
)
|
||||
|
||||
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
|
||||
|
||||
self.current_process = multiprocessing.Process(target=self.transcribe_whisper,
|
||||
args=(send_pipe,
|
||||
self.transcription_task))
|
||||
self.current_process = multiprocessing.Process(
|
||||
target=self.transcribe_whisper, args=(send_pipe, self.transcription_task)
|
||||
)
|
||||
if not self.stopped:
|
||||
self.current_process.start()
|
||||
self.started_process = True
|
||||
|
||||
self.read_line_thread = Thread(
|
||||
target=self.read_line, args=(recv_pipe,))
|
||||
self.read_line_thread = Thread(target=self.read_line, args=(recv_pipe,))
|
||||
self.read_line_thread.start()
|
||||
|
||||
self.current_process.join()
|
||||
|
@ -305,76 +331,96 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
self.read_line_thread.join()
|
||||
|
||||
logging.debug(
|
||||
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
|
||||
self.current_process.exitcode, datetime.datetime.now() - time_started,
|
||||
len(self.segments))
|
||||
"whisper process completed with code = %s, time taken = %s, number of segments = %s",
|
||||
self.current_process.exitcode,
|
||||
datetime.datetime.now() - time_started,
|
||||
len(self.segments),
|
||||
)
|
||||
|
||||
if self.current_process.exitcode != 0:
|
||||
raise Exception('Unknown error')
|
||||
raise Exception("Unknown error")
|
||||
|
||||
return self.segments
|
||||
|
||||
@classmethod
|
||||
def transcribe_whisper(cls, stderr_conn: Connection,
|
||||
task: FileTranscriptionTask) -> None:
|
||||
def transcribe_whisper(
|
||||
cls, stderr_conn: Connection, task: FileTranscriptionTask
|
||||
) -> None:
|
||||
with pipe_stderr(stderr_conn):
|
||||
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
|
||||
segments = cls.transcribe_hugging_face(task)
|
||||
elif task.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
|
||||
elif (
|
||||
task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
|
||||
):
|
||||
segments = cls.transcribe_faster_whisper(task)
|
||||
elif task.transcription_options.model.model_type == ModelType.WHISPER:
|
||||
segments = cls.transcribe_openai_whisper(task)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Invalid model type: {task.transcription_options.model.model_type}")
|
||||
f"Invalid model type: {task.transcription_options.model.model_type}"
|
||||
)
|
||||
|
||||
segments_json = json.dumps(
|
||||
segments, ensure_ascii=True, default=vars)
|
||||
sys.stderr.write(f'segments = {segments_json}\n')
|
||||
sys.stderr.write(
|
||||
WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + '\n')
|
||||
segments_json = json.dumps(segments, ensure_ascii=True, default=vars)
|
||||
sys.stderr.write(f"segments = {segments_json}\n")
|
||||
sys.stderr.write(WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + "\n")
|
||||
|
||||
@classmethod
|
||||
def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
|
||||
model = transformers_whisper.load_model(task.model_path)
|
||||
language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
|
||||
result = model.transcribe(audio=task.file_path, language=language,
|
||||
task=task.transcription_options.task.value,
|
||||
verbose=False)
|
||||
language = (
|
||||
task.transcription_options.language
|
||||
if task.transcription_options.language is not None
|
||||
else "en"
|
||||
)
|
||||
result = model.transcribe(
|
||||
audio=task.file_path,
|
||||
language=language,
|
||||
task=task.transcription_options.task.value,
|
||||
verbose=False,
|
||||
)
|
||||
return [
|
||||
Segment(
|
||||
start=int(segment.get('start') * 1000),
|
||||
end=int(segment.get('end') * 1000),
|
||||
text=segment.get('text'),
|
||||
) for segment in result.get('segments')]
|
||||
start=int(segment.get("start") * 1000),
|
||||
end=int(segment.get("end") * 1000),
|
||||
text=segment.get("text"),
|
||||
)
|
||||
for segment in result.get("segments")
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
|
||||
model = faster_whisper.WhisperModel(
|
||||
model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size())
|
||||
whisper_segments, info = model.transcribe(audio=task.file_path,
|
||||
language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt,
|
||||
word_timestamps=task.transcription_options.word_level_timings)
|
||||
model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
|
||||
)
|
||||
whisper_segments, info = model.transcribe(
|
||||
audio=task.file_path,
|
||||
language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt,
|
||||
word_timestamps=task.transcription_options.word_level_timings,
|
||||
)
|
||||
segments = []
|
||||
with tqdm.tqdm(total=round(info.duration, 2), unit=' seconds') as pbar:
|
||||
with tqdm.tqdm(total=round(info.duration, 2), unit=" seconds") as pbar:
|
||||
for segment in list(whisper_segments):
|
||||
# Segment will contain words if word-level timings is True
|
||||
if segment.words:
|
||||
for word in segment.words:
|
||||
segments.append(Segment(
|
||||
start=int(word.start * 1000),
|
||||
end=int(word.end * 1000),
|
||||
text=word.word
|
||||
))
|
||||
segments.append(
|
||||
Segment(
|
||||
start=int(word.start * 1000),
|
||||
end=int(word.end * 1000),
|
||||
text=word.word,
|
||||
)
|
||||
)
|
||||
else:
|
||||
segments.append(Segment(
|
||||
start=int(segment.start * 1000),
|
||||
end=int(segment.end * 1000),
|
||||
text=segment.text
|
||||
))
|
||||
segments.append(
|
||||
Segment(
|
||||
start=int(segment.start * 1000),
|
||||
end=int(segment.end * 1000),
|
||||
text=segment.text,
|
||||
)
|
||||
)
|
||||
|
||||
pbar.update(segment.end - segment.start)
|
||||
return segments
|
||||
|
@ -386,28 +432,40 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
if task.transcription_options.word_level_timings:
|
||||
stable_whisper.modify_model(model)
|
||||
result = model.transcribe(
|
||||
audio=task.file_path, language=task.transcription_options.language,
|
||||
audio=task.file_path,
|
||||
language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt, pbar=True)
|
||||
initial_prompt=task.transcription_options.initial_prompt,
|
||||
pbar=True,
|
||||
)
|
||||
segments = stable_whisper.group_word_timestamps(result)
|
||||
return [Segment(
|
||||
start=int(segment.get('start') * 1000),
|
||||
end=int(segment.get('end') * 1000),
|
||||
text=segment.get('text'),
|
||||
) for segment in segments]
|
||||
return [
|
||||
Segment(
|
||||
start=int(segment.get("start") * 1000),
|
||||
end=int(segment.get("end") * 1000),
|
||||
text=segment.get("text"),
|
||||
)
|
||||
for segment in segments
|
||||
]
|
||||
|
||||
result = model.transcribe(
|
||||
audio=task.file_path, language=task.transcription_options.language,
|
||||
audio=task.file_path,
|
||||
language=task.transcription_options.language,
|
||||
task=task.transcription_options.task.value,
|
||||
temperature=task.transcription_options.temperature,
|
||||
initial_prompt=task.transcription_options.initial_prompt, verbose=False)
|
||||
segments = result.get('segments')
|
||||
return [Segment(
|
||||
start=int(segment.get('start') * 1000),
|
||||
end=int(segment.get('end') * 1000),
|
||||
text=segment.get('text'),
|
||||
) for segment in segments]
|
||||
initial_prompt=task.transcription_options.initial_prompt,
|
||||
verbose=False,
|
||||
)
|
||||
segments = result.get("segments")
|
||||
return [
|
||||
Segment(
|
||||
start=int(segment.get("start") * 1000),
|
||||
end=int(segment.get("end") * 1000),
|
||||
text=segment.get("text"),
|
||||
)
|
||||
for segment in segments
|
||||
]
|
||||
|
||||
def stop(self):
|
||||
self.stopped = True
|
||||
|
@ -424,102 +482,119 @@ class WhisperFileTranscriber(FileTranscriber):
|
|||
if line == self.READ_LINE_THREAD_STOP_TOKEN:
|
||||
return
|
||||
|
||||
if line.startswith('segments = '):
|
||||
if line.startswith("segments = "):
|
||||
segments_dict = json.loads(line[11:])
|
||||
segments = [Segment(
|
||||
start=segment.get('start'),
|
||||
end=segment.get('end'),
|
||||
text=segment.get('text'),
|
||||
) for segment in segments_dict]
|
||||
segments = [
|
||||
Segment(
|
||||
start=segment.get("start"),
|
||||
end=segment.get("end"),
|
||||
text=segment.get("text"),
|
||||
)
|
||||
for segment in segments_dict
|
||||
]
|
||||
self.segments = segments
|
||||
else:
|
||||
try:
|
||||
progress = int(line.split('|')[0].strip().strip('%'))
|
||||
progress = int(line.split("|")[0].strip().strip("%"))
|
||||
self.progress.emit((progress, 100))
|
||||
except ValueError:
|
||||
logging.debug('whisper (stderr): %s', line)
|
||||
logging.debug("whisper (stderr): %s", line)
|
||||
continue
|
||||
|
||||
|
||||
def write_output(path: str, segments: List[Segment], output_format: OutputFormat):
|
||||
logging.debug(
|
||||
'Writing transcription output, path = %s, output format = %s, number of segments = %s',
|
||||
path, output_format,
|
||||
len(segments))
|
||||
"Writing transcription output, path = %s, output format = %s, number of segments = %s",
|
||||
path,
|
||||
output_format,
|
||||
len(segments),
|
||||
)
|
||||
|
||||
with open(path, 'w', encoding='utf-8') as file:
|
||||
with open(path, "w", encoding="utf-8") as file:
|
||||
if output_format == OutputFormat.TXT:
|
||||
for (i, segment) in enumerate(segments):
|
||||
for i, segment in enumerate(segments):
|
||||
file.write(segment.text)
|
||||
file.write('\n')
|
||||
file.write("\n")
|
||||
|
||||
elif output_format == OutputFormat.VTT:
|
||||
file.write('WEBVTT\n\n')
|
||||
file.write("WEBVTT\n\n")
|
||||
for segment in segments:
|
||||
file.write(
|
||||
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
|
||||
file.write(f'{segment.text}\n\n')
|
||||
f"{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n"
|
||||
)
|
||||
file.write(f"{segment.text}\n\n")
|
||||
|
||||
elif output_format == OutputFormat.SRT:
|
||||
for (i, segment) in enumerate(segments):
|
||||
file.write(f'{i + 1}\n')
|
||||
for i, segment in enumerate(segments):
|
||||
file.write(f"{i + 1}\n")
|
||||
file.write(
|
||||
f'{to_timestamp(segment.start, ms_separator=",")} --> {to_timestamp(segment.end, ms_separator=",")}\n')
|
||||
file.write(f'{segment.text}\n\n')
|
||||
f'{to_timestamp(segment.start, ms_separator=",")} --> {to_timestamp(segment.end, ms_separator=",")}\n'
|
||||
)
|
||||
file.write(f"{segment.text}\n\n")
|
||||
|
||||
logging.debug('Written transcription output')
|
||||
logging.debug("Written transcription output")
|
||||
|
||||
|
||||
def segments_to_text(segments: List[Segment]) -> str:
|
||||
result = ''
|
||||
for (i, segment) in enumerate(segments):
|
||||
result += f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n'
|
||||
result += f'{segment.text}'
|
||||
result = ""
|
||||
for i, segment in enumerate(segments):
|
||||
result += f"{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n"
|
||||
result += f"{segment.text}"
|
||||
if i < len(segments) - 1:
|
||||
result += '\n\n'
|
||||
result += "\n\n"
|
||||
return result
|
||||
|
||||
|
||||
def to_timestamp(ms: float, ms_separator='.') -> str:
|
||||
def to_timestamp(ms: float, ms_separator=".") -> str:
|
||||
hr = int(ms / (1000 * 60 * 60))
|
||||
ms = ms - hr * (1000 * 60 * 60)
|
||||
min = int(ms / (1000 * 60))
|
||||
ms = ms - min * (1000 * 60)
|
||||
sec = int(ms / 1000)
|
||||
ms = int(ms - sec * 1000)
|
||||
return f'{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}'
|
||||
return f"{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}"
|
||||
|
||||
|
||||
SUPPORTED_OUTPUT_FORMATS = 'Audio files (*.mp3 *.wav *.m4a *.ogg);;\
|
||||
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)'
|
||||
SUPPORTED_OUTPUT_FORMATS = "Audio files (*.mp3 *.wav *.m4a *.ogg);;\
|
||||
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)"
|
||||
|
||||
|
||||
def get_default_output_file_path(task: FileTranscriptionTask,
|
||||
output_format: OutputFormat):
|
||||
def get_default_output_file_path(
|
||||
task: FileTranscriptionTask, output_format: OutputFormat
|
||||
):
|
||||
input_file_name = os.path.splitext(task.file_path)[0]
|
||||
date_time_now = datetime.datetime.now().strftime('%d-%b-%Y %H-%M-%S')
|
||||
return (task.file_transcription_options.default_output_file_name
|
||||
.replace('{{ input_file_name }}', input_file_name)
|
||||
.replace('{{ task }}', task.transcription_options.task.value)
|
||||
.replace('{{ language }}', task.transcription_options.language or '')
|
||||
.replace('{{ model_type }}',
|
||||
task.transcription_options.model.model_type.value)
|
||||
.replace('{{ model_size }}',
|
||||
task.transcription_options.model.whisper_model_size.value if
|
||||
task.transcription_options.model.whisper_model_size is not None else
|
||||
'')
|
||||
.replace('{{ date_time }}', date_time_now)
|
||||
+ f".{output_format.value}")
|
||||
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")
|
||||
return (
|
||||
task.file_transcription_options.default_output_file_name.replace(
|
||||
"{{ input_file_name }}", input_file_name
|
||||
)
|
||||
.replace("{{ task }}", task.transcription_options.task.value)
|
||||
.replace("{{ language }}", task.transcription_options.language or "")
|
||||
.replace("{{ model_type }}", task.transcription_options.model.model_type.value)
|
||||
.replace(
|
||||
"{{ model_size }}",
|
||||
task.transcription_options.model.whisper_model_size.value
|
||||
if task.transcription_options.model.whisper_model_size is not None
|
||||
else "",
|
||||
)
|
||||
.replace("{{ date_time }}", date_time_now)
|
||||
+ f".{output_format.value}"
|
||||
)
|
||||
|
||||
|
||||
def whisper_cpp_params(
|
||||
language: str, task: Task, word_level_timings: bool,
|
||||
print_realtime=False, print_progress=False, ):
|
||||
language: str,
|
||||
task: Task,
|
||||
word_level_timings: bool,
|
||||
print_realtime=False,
|
||||
print_progress=False,
|
||||
):
|
||||
params = whisper_cpp.whisper_full_default_params(
|
||||
whisper_cpp.WHISPER_SAMPLING_GREEDY)
|
||||
whisper_cpp.WHISPER_SAMPLING_GREEDY
|
||||
)
|
||||
params.print_realtime = print_realtime
|
||||
params.print_progress = print_progress
|
||||
params.language = whisper_cpp.String(language.encode('utf-8'))
|
||||
params.language = whisper_cpp.String(language.encode("utf-8"))
|
||||
params.translate = task == Task.TRANSLATE
|
||||
params.max_len = ctypes.c_int(1)
|
||||
params.max_len = 1 if word_level_timings else 0
|
||||
|
@ -529,20 +604,20 @@ def whisper_cpp_params(
|
|||
|
||||
class WhisperCpp:
|
||||
def __init__(self, model: str) -> None:
|
||||
self.ctx = whisper_cpp.whisper_init_from_file(model.encode('utf-8'))
|
||||
self.ctx = whisper_cpp.whisper_init_from_file(model.encode("utf-8"))
|
||||
|
||||
def transcribe(self, audio: Union[np.ndarray, str], params: Any):
|
||||
if isinstance(audio, str):
|
||||
audio = whisper.audio.load_audio(audio)
|
||||
|
||||
logging.debug('Loaded audio with length = %s', len(audio))
|
||||
logging.debug("Loaded audio with length = %s", len(audio))
|
||||
|
||||
whisper_cpp_audio = audio.ctypes.data_as(
|
||||
ctypes.POINTER(ctypes.c_float))
|
||||
whisper_cpp_audio = audio.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
|
||||
result = whisper_cpp.whisper_full(
|
||||
self.ctx, params, whisper_cpp_audio, len(audio))
|
||||
self.ctx, params, whisper_cpp_audio, len(audio)
|
||||
)
|
||||
if result != 0:
|
||||
raise Exception(f'Error from whisper.cpp: {result}')
|
||||
raise Exception(f"Error from whisper.cpp: {result}")
|
||||
|
||||
segments: List[Segment] = []
|
||||
|
||||
|
@ -553,13 +628,17 @@ class WhisperCpp:
|
|||
t1 = whisper_cpp.whisper_full_get_segment_t1((self.ctx), i)
|
||||
|
||||
segments.append(
|
||||
Segment(start=t0 * 10, # centisecond to ms
|
||||
end=t1 * 10, # centisecond to ms
|
||||
text=txt.decode('utf-8')))
|
||||
Segment(
|
||||
start=t0 * 10, # centisecond to ms
|
||||
end=t1 * 10, # centisecond to ms
|
||||
text=txt.decode("utf-8"),
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
'segments': segments,
|
||||
'text': ''.join([segment.text for segment in segments])}
|
||||
"segments": segments,
|
||||
"text": "".join([segment.text for segment in segments]),
|
||||
}
|
||||
|
||||
def __del__(self):
|
||||
whisper_cpp.whisper_free(self.ctx)
|
||||
|
|
|
@ -16,43 +16,65 @@ class TransformersWhisper:
|
|||
SAMPLE_RATE = whisper.audio.SAMPLE_RATE
|
||||
N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
|
||||
|
||||
def __init__(self, processor: WhisperProcessor, model: WhisperForConditionalGeneration):
|
||||
def __init__(
|
||||
self, processor: WhisperProcessor, model: WhisperForConditionalGeneration
|
||||
):
|
||||
self.processor = processor
|
||||
self.model = model
|
||||
|
||||
# Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
|
||||
# timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
|
||||
# https://github.com/huggingface/transformers/pull/20620.
|
||||
def transcribe(self, audio: Union[str, np.ndarray], language: str, task: str, verbose: Optional[bool] = None):
|
||||
def transcribe(
|
||||
self,
|
||||
audio: Union[str, np.ndarray],
|
||||
language: str,
|
||||
task: str,
|
||||
verbose: Optional[bool] = None,
|
||||
):
|
||||
if isinstance(audio, str):
|
||||
audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE)
|
||||
|
||||
self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(task=task, language=language)
|
||||
self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
|
||||
task=task, language=language
|
||||
)
|
||||
|
||||
segments = []
|
||||
all_predicted_ids = []
|
||||
|
||||
num_samples = audio.size
|
||||
seek = 0
|
||||
with tqdm(total=num_samples, unit='samples', disable=verbose is not False) as progress_bar:
|
||||
with tqdm(
|
||||
total=num_samples, unit="samples", disable=verbose is not False
|
||||
) as progress_bar:
|
||||
while seek < num_samples:
|
||||
chunk = audio[seek: seek + self.N_SAMPLES_IN_CHUNK]
|
||||
input_features = self.processor(chunk, return_tensors="pt",
|
||||
sampling_rate=self.SAMPLE_RATE).input_features
|
||||
chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK]
|
||||
input_features = self.processor(
|
||||
chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE
|
||||
).input_features
|
||||
predicted_ids = self.model.generate(input_features)
|
||||
all_predicted_ids.extend(predicted_ids)
|
||||
text: str = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
|
||||
if text.strip() != '':
|
||||
segments.append({
|
||||
'start': seek / self.SAMPLE_RATE,
|
||||
'end': min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) / self.SAMPLE_RATE,
|
||||
'text': text
|
||||
})
|
||||
text: str = self.processor.batch_decode(
|
||||
predicted_ids, skip_special_tokens=True
|
||||
)[0]
|
||||
if text.strip() != "":
|
||||
segments.append(
|
||||
{
|
||||
"start": seek / self.SAMPLE_RATE,
|
||||
"end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples)
|
||||
/ self.SAMPLE_RATE,
|
||||
"text": text,
|
||||
}
|
||||
)
|
||||
|
||||
progress_bar.update(min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek)
|
||||
progress_bar.update(
|
||||
min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek
|
||||
)
|
||||
seek += self.N_SAMPLES_IN_CHUNK
|
||||
|
||||
return {
|
||||
'text': self.processor.batch_decode(all_predicted_ids, skip_special_tokens=True)[0],
|
||||
'segments': segments
|
||||
"text": self.processor.batch_decode(
|
||||
all_predicted_ids, skip_special_tokens=True
|
||||
)[0],
|
||||
"segments": segments,
|
||||
}
|
||||
|
|
|
@ -5,8 +5,15 @@ from PyQt6 import QtGui
|
|||
from PyQt6.QtCore import Qt, QUrl
|
||||
from PyQt6.QtGui import QIcon, QPixmap, QDesktopServices
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
|
||||
from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QLabel, QPushButton, \
|
||||
QDialogButtonBox, QMessageBox
|
||||
from PyQt6.QtWidgets import (
|
||||
QDialog,
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QLabel,
|
||||
QPushButton,
|
||||
QDialogButtonBox,
|
||||
QMessageBox,
|
||||
)
|
||||
|
||||
from buzz.__version__ import VERSION
|
||||
from buzz.widgets.icon import BUZZ_ICON_PATH, BUZZ_LARGE_ICON_PATH
|
||||
|
@ -15,11 +22,16 @@ from buzz.settings.settings import APP_NAME
|
|||
|
||||
|
||||
class AboutDialog(QDialog):
|
||||
GITHUB_API_LATEST_RELEASE_URL = 'https://api.github.com/repos/chidiwilliams/buzz/releases/latest'
|
||||
GITHUB_LATEST_RELEASE_URL = 'https://github.com/chidiwilliams/buzz/releases/latest'
|
||||
GITHUB_API_LATEST_RELEASE_URL = (
|
||||
"https://api.github.com/repos/chidiwilliams/buzz/releases/latest"
|
||||
)
|
||||
GITHUB_LATEST_RELEASE_URL = "https://github.com/chidiwilliams/buzz/releases/latest"
|
||||
|
||||
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QWidget] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
network_access_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QWidget] = None,
|
||||
) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
|
||||
|
@ -35,28 +47,42 @@ class AboutDialog(QDialog):
|
|||
|
||||
image_label = QLabel()
|
||||
pixmap = QPixmap(BUZZ_LARGE_ICON_PATH).scaled(
|
||||
80, 80, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
|
||||
80,
|
||||
80,
|
||||
Qt.AspectRatioMode.KeepAspectRatio,
|
||||
Qt.TransformationMode.SmoothTransformation,
|
||||
)
|
||||
image_label.setPixmap(pixmap)
|
||||
image_label.setAlignment(Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
|
||||
image_label.setAlignment(
|
||||
Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter
|
||||
)
|
||||
)
|
||||
|
||||
buzz_label = QLabel(APP_NAME)
|
||||
buzz_label.setAlignment(Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
|
||||
buzz_label.setAlignment(
|
||||
Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter
|
||||
)
|
||||
)
|
||||
buzz_label_font = QtGui.QFont()
|
||||
buzz_label_font.setBold(True)
|
||||
buzz_label_font.setPointSize(20)
|
||||
buzz_label.setFont(buzz_label_font)
|
||||
|
||||
version_label = QLabel(f"{_('Version')} {VERSION}")
|
||||
version_label.setAlignment(Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
|
||||
version_label.setAlignment(
|
||||
Qt.AlignmentFlag(
|
||||
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter
|
||||
)
|
||||
)
|
||||
|
||||
self.check_updates_button = QPushButton(_('Check for updates'), self)
|
||||
self.check_updates_button = QPushButton(_("Check for updates"), self)
|
||||
self.check_updates_button.clicked.connect(self.on_click_check_for_updates)
|
||||
|
||||
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
|
||||
QDialogButtonBox.StandardButton.Close), self)
|
||||
button_box = QDialogButtonBox(
|
||||
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Close), self
|
||||
)
|
||||
button_box.accepted.connect(self.accept)
|
||||
button_box.rejected.connect(self.reject)
|
||||
|
||||
|
@ -76,13 +102,13 @@ class AboutDialog(QDialog):
|
|||
def on_latest_release_reply(self, reply: QNetworkReply):
|
||||
if reply.error() == QNetworkReply.NetworkError.NoError:
|
||||
response = json.loads(reply.readAll().data())
|
||||
tag_name = response.get('name')
|
||||
tag_name = response.get("name")
|
||||
if self.is_version_lower(VERSION, tag_name[1:]):
|
||||
QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL))
|
||||
else:
|
||||
QMessageBox.information(self, '', _("You're up to date!"))
|
||||
QMessageBox.information(self, "", _("You're up to date!"))
|
||||
self.check_updates_button.setEnabled(True)
|
||||
|
||||
@staticmethod
|
||||
def is_version_lower(version_a: str, version_b: str):
|
||||
return version_a.replace('.', '') < version_b.replace('.', '')
|
||||
return version_a.replace(".", "") < version_b.replace(".", "")
|
||||
|
|
|
@ -111,7 +111,7 @@ class AudioPlayer(QWidget):
|
|||
def update_time_label(self):
|
||||
position_time = QTime(0, 0).addMSecs(self.position).toString()
|
||||
duration_time = QTime(0, 0).addMSecs(self.duration).toString()
|
||||
self.time_label.setText(f'{position_time} / {duration_time}')
|
||||
self.time_label.setText(f"{position_time} / {duration_time}")
|
||||
|
||||
def stop(self):
|
||||
self.media_player.stop()
|
||||
|
|
|
@ -6,8 +6,8 @@ from buzz.assets import get_asset_path
|
|||
|
||||
# TODO: move icons to Qt resources: https://stackoverflow.com/a/52341917/9830227
|
||||
class Icon(QIcon):
|
||||
LIGHT_THEME_BACKGROUND = '#555'
|
||||
DARK_THEME_BACKGROUND = '#EEE'
|
||||
LIGHT_THEME_BACKGROUND = "#555"
|
||||
DARK_THEME_BACKGROUND = "#EEE"
|
||||
|
||||
def __init__(self, path: str, parent: QWidget):
|
||||
# Adapted from https://stackoverflow.com/questions/15123544/change-the-color-of-an-svg-in-qt
|
||||
|
@ -23,18 +23,20 @@ class Icon(QIcon):
|
|||
super().__init__(pixmap)
|
||||
|
||||
def get_color(self, is_dark_theme):
|
||||
return self.DARK_THEME_BACKGROUND if is_dark_theme else self.LIGHT_THEME_BACKGROUND
|
||||
return (
|
||||
self.DARK_THEME_BACKGROUND if is_dark_theme else self.LIGHT_THEME_BACKGROUND
|
||||
)
|
||||
|
||||
|
||||
class PlayIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_asset_path('assets/play_arrow_black_24dp.svg'), parent)
|
||||
super().__init__(get_asset_path("assets/play_arrow_black_24dp.svg"), parent)
|
||||
|
||||
|
||||
class PauseIcon(Icon):
|
||||
def __init__(self, parent: QWidget):
|
||||
super().__init__(get_asset_path('assets/pause_black_24dp.svg'), parent)
|
||||
super().__init__(get_asset_path("assets/pause_black_24dp.svg"), parent)
|
||||
|
||||
|
||||
BUZZ_ICON_PATH = get_asset_path('assets/buzz.ico')
|
||||
BUZZ_LARGE_ICON_PATH = get_asset_path('assets/buzz-icon-1024.png')
|
||||
BUZZ_ICON_PATH = get_asset_path("assets/buzz.ico")
|
||||
BUZZ_LARGE_ICON_PATH = get_asset_path("assets/buzz-icon-1024.png")
|
||||
|
|
|
@ -5,7 +5,7 @@ from PyQt6.QtWidgets import QLineEdit, QWidget
|
|||
|
||||
|
||||
class LineEdit(QLineEdit):
|
||||
def __init__(self, default_text: str = '', parent: Optional[QWidget] = None):
|
||||
def __init__(self, default_text: str = "", parent: Optional[QWidget] = None):
|
||||
super().__init__(default_text, parent)
|
||||
if platform.system() == 'Darwin':
|
||||
self.setStyleSheet('QLineEdit { padding: 4px }')
|
||||
if platform.system() == "Darwin":
|
||||
self.setStyleSheet("QLineEdit { padding: 4px }")
|
||||
|
|
|
@ -18,16 +18,16 @@ class MenuBar(QMenuBar):
|
|||
openai_api_key_changed = pyqtSignal(str)
|
||||
default_export_file_name_changed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
|
||||
parent: QWidget):
|
||||
def __init__(
|
||||
self, shortcuts: Dict[str, str], default_export_file_name: str, parent: QWidget
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.shortcuts = shortcuts
|
||||
self.default_export_file_name = default_export_file_name
|
||||
|
||||
self.import_action = QAction(_("Import Media File..."), self)
|
||||
self.import_action.triggered.connect(
|
||||
self.on_import_action_triggered)
|
||||
self.import_action.triggered.connect(self.on_import_action_triggered)
|
||||
|
||||
about_action = QAction(f'{_("About")} {APP_NAME}', self)
|
||||
about_action.triggered.connect(self.on_about_action_triggered)
|
||||
|
@ -56,22 +56,27 @@ class MenuBar(QMenuBar):
|
|||
about_dialog.open()
|
||||
|
||||
def on_preferences_action_triggered(self):
|
||||
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts,
|
||||
default_export_file_name=self.default_export_file_name,
|
||||
parent=self)
|
||||
preferences_dialog = PreferencesDialog(
|
||||
shortcuts=self.shortcuts,
|
||||
default_export_file_name=self.default_export_file_name,
|
||||
parent=self,
|
||||
)
|
||||
preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed)
|
||||
preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed)
|
||||
preferences_dialog.default_export_file_name_changed.connect(
|
||||
self.default_export_file_name_changed)
|
||||
self.default_export_file_name_changed
|
||||
)
|
||||
preferences_dialog.open()
|
||||
|
||||
def on_help_action_triggered(self):
|
||||
webbrowser.open('https://chidiwilliams.github.io/buzz/docs')
|
||||
webbrowser.open("https://chidiwilliams.github.io/buzz/docs")
|
||||
|
||||
def set_shortcuts(self, shortcuts: Dict[str, str]):
|
||||
self.shortcuts = shortcuts
|
||||
|
||||
self.import_action.setShortcut(
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name]))
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name])
|
||||
)
|
||||
self.preferences_action.setShortcut(
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_PREFERENCES_WINDOW.name]))
|
||||
QKeySequence.fromString(shortcuts[Shortcut.OPEN_PREFERENCES_WINDOW.name])
|
||||
)
|
||||
|
|
|
@ -10,10 +10,17 @@ from buzz.model_loader import ModelType
|
|||
|
||||
|
||||
class ModelDownloadProgressDialog(QProgressDialog):
|
||||
def __init__(self, model_type: ModelType, parent: Optional[QWidget] = None, modality=Qt.WindowModality.WindowModal):
|
||||
def __init__(
|
||||
self,
|
||||
model_type: ModelType,
|
||||
parent: Optional[QWidget] = None,
|
||||
modality=Qt.WindowModality.WindowModal,
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.cancelable = model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP
|
||||
self.cancelable = (
|
||||
model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP
|
||||
)
|
||||
self.start_time = datetime.now()
|
||||
self.setRange(0, 100)
|
||||
self.setMinimumDuration(0)
|
||||
|
@ -21,7 +28,7 @@ class ModelDownloadProgressDialog(QProgressDialog):
|
|||
self.update_label_text(0)
|
||||
|
||||
if not self.cancelable:
|
||||
cancel_button = QPushButton('Cancel', self)
|
||||
cancel_button = QPushButton("Cancel", self)
|
||||
cancel_button.setEnabled(False)
|
||||
self.setCancelButton(cancel_button)
|
||||
|
||||
|
@ -30,8 +37,8 @@ class ModelDownloadProgressDialog(QProgressDialog):
|
|||
if fraction_completed > 0:
|
||||
time_spent = (datetime.now() - self.start_time).total_seconds()
|
||||
time_left = (time_spent / fraction_completed) - time_spent
|
||||
label_text += f', {humanize.naturaldelta(time_left)} remaining'
|
||||
label_text += ')'
|
||||
label_text += f", {humanize.naturaldelta(time_left)} remaining"
|
||||
label_text += ")"
|
||||
|
||||
self.setLabelText(label_text)
|
||||
|
||||
|
|
|
@ -10,8 +10,12 @@ from buzz.transcriber import LOADED_WHISPER_DLL
|
|||
class ModelTypeComboBox(QComboBox):
|
||||
changed = pyqtSignal(ModelType)
|
||||
|
||||
def __init__(self, model_types: Optional[List[ModelType]] = None, default_model: Optional[ModelType] = None,
|
||||
parent: Optional[QWidget] = None):
|
||||
def __init__(
|
||||
self,
|
||||
model_types: Optional[List[ModelType]] = None,
|
||||
default_model: Optional[ModelType] = None,
|
||||
parent: Optional[QWidget] = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
if model_types is None:
|
||||
|
|
|
@ -16,15 +16,22 @@ class OpenAIAPIKeyLineEdit(LineEdit):
|
|||
|
||||
self.key = key
|
||||
|
||||
self.visible_on_icon = Icon(get_asset_path('assets/visibility_FILL0_wght700_GRAD0_opsz48.svg'), self)
|
||||
self.visible_off_icon = Icon(get_asset_path('assets/visibility_off_FILL0_wght700_GRAD0_opsz48.svg'), self)
|
||||
self.visible_on_icon = Icon(
|
||||
get_asset_path("assets/visibility_FILL0_wght700_GRAD0_opsz48.svg"), self
|
||||
)
|
||||
self.visible_off_icon = Icon(
|
||||
get_asset_path("assets/visibility_off_FILL0_wght700_GRAD0_opsz48.svg"), self
|
||||
)
|
||||
|
||||
self.setPlaceholderText('sk-...')
|
||||
self.setPlaceholderText("sk-...")
|
||||
self.setEchoMode(QLineEdit.EchoMode.Password)
|
||||
self.textChanged.connect(self.on_openai_api_key_changed)
|
||||
self.toggle_show_openai_api_key_action = self.addAction(self.visible_on_icon,
|
||||
QLineEdit.ActionPosition.TrailingPosition)
|
||||
self.toggle_show_openai_api_key_action.triggered.connect(self.on_toggle_show_action_triggered)
|
||||
self.toggle_show_openai_api_key_action = self.addAction(
|
||||
self.visible_on_icon, QLineEdit.ActionPosition.TrailingPosition
|
||||
)
|
||||
self.toggle_show_openai_api_key_action.triggered.connect(
|
||||
self.on_toggle_show_action_triggered
|
||||
)
|
||||
|
||||
def on_toggle_show_action_triggered(self):
|
||||
if self.echoMode() == QLineEdit.EchoMode.Password:
|
||||
|
|
|
@ -15,32 +15,40 @@ class GeneralPreferencesWidget(QWidget):
|
|||
openai_api_key_changed = pyqtSignal(str)
|
||||
default_export_file_name_changed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, default_export_file_name: str, keyring_store=KeyringStore(),
|
||||
parent: Optional[QWidget] = None):
|
||||
def __init__(
|
||||
self,
|
||||
default_export_file_name: str,
|
||||
keyring_store=KeyringStore(),
|
||||
parent: Optional[QWidget] = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.openai_api_key = keyring_store.get_password(
|
||||
KeyringStore.Key.OPENAI_API_KEY)
|
||||
KeyringStore.Key.OPENAI_API_KEY
|
||||
)
|
||||
|
||||
layout = QFormLayout(self)
|
||||
|
||||
self.openai_api_key_line_edit = OpenAIAPIKeyLineEdit(self.openai_api_key, self)
|
||||
self.openai_api_key_line_edit.key_changed.connect(
|
||||
self.on_openai_api_key_changed)
|
||||
self.on_openai_api_key_changed
|
||||
)
|
||||
|
||||
self.test_openai_api_key_button = QPushButton('Test')
|
||||
self.test_openai_api_key_button = QPushButton("Test")
|
||||
self.test_openai_api_key_button.clicked.connect(
|
||||
self.on_click_test_openai_api_key_button)
|
||||
self.on_click_test_openai_api_key_button
|
||||
)
|
||||
self.update_test_openai_api_key_button()
|
||||
|
||||
layout.addRow('OpenAI API Key', self.openai_api_key_line_edit)
|
||||
layout.addRow('', self.test_openai_api_key_button)
|
||||
layout.addRow("OpenAI API Key", self.openai_api_key_line_edit)
|
||||
layout.addRow("", self.test_openai_api_key_button)
|
||||
|
||||
default_export_file_name_line_edit = LineEdit(default_export_file_name, self)
|
||||
default_export_file_name_line_edit.textChanged.connect(
|
||||
self.default_export_file_name_changed)
|
||||
self.default_export_file_name_changed
|
||||
)
|
||||
default_export_file_name_line_edit.setMinimumWidth(200)
|
||||
layout.addRow('Default export file name', default_export_file_name_line_edit)
|
||||
layout.addRow("Default export file name", default_export_file_name_line_edit)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
||||
|
@ -60,12 +68,15 @@ class GeneralPreferencesWidget(QWidget):
|
|||
|
||||
def on_test_openai_api_key_success(self):
|
||||
self.test_openai_api_key_button.setEnabled(True)
|
||||
QMessageBox.information(self, 'OpenAI API Key Test',
|
||||
'Your API key is valid. Buzz will use this key to perform Whisper API transcriptions.')
|
||||
QMessageBox.information(
|
||||
self,
|
||||
"OpenAI API Key Test",
|
||||
"Your API key is valid. Buzz will use this key to perform Whisper API transcriptions.",
|
||||
)
|
||||
|
||||
def on_test_openai_api_key_failure(self, error: str):
|
||||
self.test_openai_api_key_button.setEnabled(True)
|
||||
QMessageBox.warning(self, 'OpenAI API Key Test', error)
|
||||
QMessageBox.warning(self, "OpenAI API Key Test", error)
|
||||
|
||||
def on_openai_api_key_changed(self, key: str):
|
||||
self.openai_api_key = key
|
||||
|
|
|
@ -1,34 +1,55 @@
|
|||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import Qt, QThreadPool
|
||||
from PyQt6.QtWidgets import QWidget, QFormLayout, QTreeWidget, QTreeWidgetItem, \
|
||||
QPushButton, QMessageBox, QHBoxLayout
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QFormLayout,
|
||||
QTreeWidget,
|
||||
QTreeWidgetItem,
|
||||
QPushButton,
|
||||
QMessageBox,
|
||||
QHBoxLayout,
|
||||
)
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel, ModelDownloader
|
||||
from buzz.model_loader import (
|
||||
ModelType,
|
||||
WhisperModelSize,
|
||||
TranscriptionModel,
|
||||
ModelDownloader,
|
||||
)
|
||||
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
||||
|
||||
|
||||
class ModelsPreferencesWidget(QWidget):
|
||||
def __init__(self, progress_dialog_modality=Qt.WindowModality.WindowModal,
|
||||
parent: Optional[QWidget] = None):
|
||||
def __init__(
|
||||
self,
|
||||
progress_dialog_modality=Qt.WindowModality.WindowModal,
|
||||
parent: Optional[QWidget] = None,
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.model_downloader: Optional[ModelDownloader] = None
|
||||
self.model = TranscriptionModel(model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY)
|
||||
self.model = TranscriptionModel(
|
||||
model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY
|
||||
)
|
||||
self.progress_dialog_modality = progress_dialog_modality
|
||||
|
||||
self.progress_dialog: Optional[ModelDownloadProgressDialog] = None
|
||||
|
||||
layout = QFormLayout()
|
||||
model_type_combo_box = ModelTypeComboBox(
|
||||
model_types=[ModelType.WHISPER, ModelType.WHISPER_CPP,
|
||||
ModelType.FASTER_WHISPER],
|
||||
default_model=self.model.model_type, parent=self)
|
||||
model_types=[
|
||||
ModelType.WHISPER,
|
||||
ModelType.WHISPER_CPP,
|
||||
ModelType.FASTER_WHISPER,
|
||||
],
|
||||
default_model=self.model.model_type,
|
||||
parent=self,
|
||||
)
|
||||
model_type_combo_box.changed.connect(self.on_model_type_changed)
|
||||
layout.addRow('Group', model_type_combo_box)
|
||||
layout.addRow("Group", model_type_combo_box)
|
||||
|
||||
self.model_list_widget = QTreeWidget()
|
||||
self.model_list_widget.setColumnCount(1)
|
||||
|
@ -37,20 +58,21 @@ class ModelsPreferencesWidget(QWidget):
|
|||
|
||||
buttons_layout = QHBoxLayout()
|
||||
|
||||
self.download_button = QPushButton(_('Download'))
|
||||
self.download_button.setObjectName('DownloadButton')
|
||||
self.download_button = QPushButton(_("Download"))
|
||||
self.download_button.setObjectName("DownloadButton")
|
||||
self.download_button.clicked.connect(self.on_download_button_clicked)
|
||||
buttons_layout.addWidget(self.download_button)
|
||||
|
||||
self.show_file_location_button = QPushButton(_('Show file location'))
|
||||
self.show_file_location_button.setObjectName('ShowFileLocationButton')
|
||||
self.show_file_location_button = QPushButton(_("Show file location"))
|
||||
self.show_file_location_button.setObjectName("ShowFileLocationButton")
|
||||
self.show_file_location_button.clicked.connect(
|
||||
self.on_show_file_location_button_clicked)
|
||||
self.on_show_file_location_button_clicked
|
||||
)
|
||||
buttons_layout.addWidget(self.show_file_location_button)
|
||||
buttons_layout.addStretch(1)
|
||||
|
||||
self.delete_button = QPushButton(_('Delete'))
|
||||
self.delete_button.setObjectName('DeleteButton')
|
||||
self.delete_button = QPushButton(_("Delete"))
|
||||
self.delete_button.setObjectName("DeleteButton")
|
||||
self.delete_button.clicked.connect(self.on_delete_button_clicked)
|
||||
buttons_layout.addWidget(self.delete_button)
|
||||
|
||||
|
@ -71,9 +93,10 @@ class ModelsPreferencesWidget(QWidget):
|
|||
|
||||
@staticmethod
|
||||
def can_delete_model(model: TranscriptionModel):
|
||||
return ((model.model_type == ModelType.WHISPER or
|
||||
model.model_type == ModelType.WHISPER_CPP) and
|
||||
model.get_local_model_path() is not None)
|
||||
return (
|
||||
model.model_type == ModelType.WHISPER
|
||||
or model.model_type == ModelType.WHISPER_CPP
|
||||
) and model.get_local_model_path() is not None
|
||||
|
||||
def reset(self):
|
||||
# reset buttons
|
||||
|
@ -85,20 +108,21 @@ class ModelsPreferencesWidget(QWidget):
|
|||
# reset model list
|
||||
self.model_list_widget.clear()
|
||||
downloaded_item = QTreeWidgetItem(self.model_list_widget)
|
||||
downloaded_item.setText(0, _('Downloaded'))
|
||||
downloaded_item.setText(0, _("Downloaded"))
|
||||
downloaded_item.setFlags(
|
||||
downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
|
||||
downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable
|
||||
)
|
||||
available_item = QTreeWidgetItem(self.model_list_widget)
|
||||
available_item.setText(0, _('Available for Download'))
|
||||
available_item.setFlags(
|
||||
available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
|
||||
available_item.setText(0, _("Available for Download"))
|
||||
available_item.setFlags(available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
|
||||
self.model_list_widget.addTopLevelItems([downloaded_item, available_item])
|
||||
self.model_list_widget.expandToDepth(2)
|
||||
self.model_list_widget.setHeaderHidden(True)
|
||||
self.model_list_widget.setAlternatingRowColors(True)
|
||||
for model_size in WhisperModelSize:
|
||||
model = TranscriptionModel(model_type=self.model.model_type,
|
||||
whisper_model_size=model_size)
|
||||
model = TranscriptionModel(
|
||||
model_type=self.model.model_type, whisper_model_size=model_size
|
||||
)
|
||||
model_path = model.get_local_model_path()
|
||||
parent = downloaded_item if model_path is not None else available_item
|
||||
item = QTreeWidgetItem(parent)
|
||||
|
@ -115,7 +139,9 @@ class ModelsPreferencesWidget(QWidget):
|
|||
def on_download_button_clicked(self):
|
||||
self.progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.model.model_type,
|
||||
modality=self.progress_dialog_modality, parent=self)
|
||||
modality=self.progress_dialog_modality,
|
||||
parent=self,
|
||||
)
|
||||
self.progress_dialog.canceled.connect(self.on_progress_dialog_canceled)
|
||||
|
||||
self.download_button.setEnabled(False)
|
||||
|
@ -128,8 +154,10 @@ class ModelsPreferencesWidget(QWidget):
|
|||
|
||||
def on_delete_button_clicked(self):
|
||||
reply = QMessageBox.question(
|
||||
self, _('Delete Model'),
|
||||
_('Are you sure you want to delete the selected model?'))
|
||||
self,
|
||||
_("Delete Model"),
|
||||
_("Are you sure you want to delete the selected model?"),
|
||||
)
|
||||
if reply == QMessageBox.StandardButton.Yes:
|
||||
self.model.delete_local_file()
|
||||
self.reset()
|
||||
|
@ -147,7 +175,7 @@ class ModelsPreferencesWidget(QWidget):
|
|||
self.progress_dialog = None
|
||||
self.download_button.setEnabled(True)
|
||||
self.reset()
|
||||
QMessageBox.warning(self, _('Error'), f'Download failed: {error}')
|
||||
QMessageBox.warning(self, _("Error"), f"Download failed: {error}")
|
||||
|
||||
def on_download_progress(self, progress: tuple):
|
||||
self.progress_dialog.set_value(float(progress[0]) / progress[1])
|
||||
|
|
|
@ -4,12 +4,15 @@ from PyQt6.QtCore import pyqtSignal
|
|||
from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QTabWidget, QDialogButtonBox
|
||||
|
||||
from buzz.locale import _
|
||||
from buzz.widgets.preferences_dialog.general_preferences_widget import \
|
||||
GeneralPreferencesWidget
|
||||
from buzz.widgets.preferences_dialog.models_preferences_widget import \
|
||||
ModelsPreferencesWidget
|
||||
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import \
|
||||
ShortcutsEditorPreferencesWidget
|
||||
from buzz.widgets.preferences_dialog.general_preferences_widget import (
|
||||
GeneralPreferencesWidget,
|
||||
)
|
||||
from buzz.widgets.preferences_dialog.models_preferences_widget import (
|
||||
ModelsPreferencesWidget,
|
||||
)
|
||||
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import (
|
||||
ShortcutsEditorPreferencesWidget,
|
||||
)
|
||||
|
||||
|
||||
class PreferencesDialog(QDialog):
|
||||
|
@ -17,31 +20,38 @@ class PreferencesDialog(QDialog):
|
|||
openai_api_key_changed = pyqtSignal(str)
|
||||
default_export_file_name_changed = pyqtSignal(str)
|
||||
|
||||
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
|
||||
parent: Optional[QWidget] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
shortcuts: Dict[str, str],
|
||||
default_export_file_name: str,
|
||||
parent: Optional[QWidget] = None,
|
||||
) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
self.setWindowTitle('Preferences')
|
||||
self.setWindowTitle("Preferences")
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
tab_widget = QTabWidget(self)
|
||||
|
||||
general_tab_widget = GeneralPreferencesWidget(
|
||||
default_export_file_name=default_export_file_name, parent=self)
|
||||
default_export_file_name=default_export_file_name, parent=self
|
||||
)
|
||||
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
|
||||
general_tab_widget.default_export_file_name_changed.connect(
|
||||
self.default_export_file_name_changed)
|
||||
tab_widget.addTab(general_tab_widget, _('General'))
|
||||
self.default_export_file_name_changed
|
||||
)
|
||||
tab_widget.addTab(general_tab_widget, _("General"))
|
||||
|
||||
models_tab_widget = ModelsPreferencesWidget(parent=self)
|
||||
tab_widget.addTab(models_tab_widget, _('Models'))
|
||||
tab_widget.addTab(models_tab_widget, _("Models"))
|
||||
|
||||
shortcuts_table_widget = ShortcutsEditorPreferencesWidget(shortcuts, self)
|
||||
shortcuts_table_widget.shortcuts_changed.connect(self.shortcuts_changed)
|
||||
tab_widget.addTab(shortcuts_table_widget, _('Shortcuts'))
|
||||
tab_widget.addTab(shortcuts_table_widget, _("Shortcuts"))
|
||||
|
||||
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
|
||||
QDialogButtonBox.StandardButton.Ok), self)
|
||||
button_box = QDialogButtonBox(
|
||||
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self
|
||||
)
|
||||
button_box.accepted.connect(self.accept)
|
||||
button_box.rejected.connect(self.reject)
|
||||
|
||||
|
|
|
@ -18,12 +18,13 @@ class ShortcutsEditorPreferencesWidget(QWidget):
|
|||
|
||||
self.layout = QFormLayout(self)
|
||||
for shortcut in Shortcut:
|
||||
sequence_edit = SequenceEdit(shortcuts.get(shortcut.name, ''), self)
|
||||
sequence_edit = SequenceEdit(shortcuts.get(shortcut.name, ""), self)
|
||||
sequence_edit.keySequenceChanged.connect(
|
||||
self.get_key_sequence_changed(shortcut.name))
|
||||
self.get_key_sequence_changed(shortcut.name)
|
||||
)
|
||||
self.layout.addRow(shortcut.description, sequence_edit)
|
||||
|
||||
reset_to_defaults_button = QPushButton('Reset to Defaults', self)
|
||||
reset_to_defaults_button = QPushButton("Reset to Defaults", self)
|
||||
reset_to_defaults_button.setDefault(False)
|
||||
reset_to_defaults_button.setAutoDefault(False)
|
||||
reset_to_defaults_button.clicked.connect(self.reset_to_defaults)
|
||||
|
@ -41,8 +42,9 @@ class ShortcutsEditorPreferencesWidget(QWidget):
|
|||
self.shortcuts = Shortcut.get_default_shortcuts()
|
||||
|
||||
for i, shortcut in enumerate(Shortcut):
|
||||
sequence_edit = self.layout.itemAt(i,
|
||||
QFormLayout.ItemRole.FieldRole).widget()
|
||||
sequence_edit = self.layout.itemAt(
|
||||
i, QFormLayout.ItemRole.FieldRole
|
||||
).widget()
|
||||
assert isinstance(sequence_edit, SequenceEdit)
|
||||
sequence_edit.setKeySequence(QKeySequence(self.shortcuts[shortcut.name]))
|
||||
|
||||
|
|
|
@ -10,8 +10,8 @@ class SequenceEdit(QKeySequenceEdit):
|
|||
def __init__(self, sequence: str, parent: Optional[QWidget] = None):
|
||||
super().__init__(sequence, parent)
|
||||
self.setClearButtonEnabled(True)
|
||||
if platform.system() == 'Darwin':
|
||||
self.setStyleSheet('QLineEdit:focus { border: 2px solid #4d90fe; }')
|
||||
if platform.system() == "Darwin":
|
||||
self.setStyleSheet("QLineEdit:focus { border: 2px solid #4d90fe; }")
|
||||
|
||||
def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
|
||||
key = event.key()
|
||||
|
@ -23,7 +23,12 @@ class SequenceEdit(QKeySequenceEdit):
|
|||
return
|
||||
|
||||
# Ignore pressing *only* modifier keys
|
||||
if key == Qt.Key.Key_Control or key == Qt.Key.Key_Shift or key == Qt.Key.Key_Alt or key == Qt.Key.Key_Meta:
|
||||
if (
|
||||
key == Qt.Key.Key_Control
|
||||
or key == Qt.Key.Key_Shift
|
||||
or key == Qt.Key.Key_Alt
|
||||
or key == Qt.Key.Key_Meta
|
||||
):
|
||||
return
|
||||
|
||||
super().keyPressEvent(event)
|
||||
|
|
|
@ -11,7 +11,7 @@ class ToolBar(QToolBar):
|
|||
super().__init__(parent)
|
||||
|
||||
self.setIconSize(QSize(18, 18))
|
||||
self.setStyleSheet('QToolButton{margin: 6px 3px;}')
|
||||
self.setStyleSheet("QToolButton{margin: 6px 3px;}")
|
||||
self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
|
||||
|
||||
def addAction(self, action: QtGui.QAction) -> None:
|
||||
|
@ -23,6 +23,7 @@ class ToolBar(QToolBar):
|
|||
self.fix_spacing_on_mac()
|
||||
|
||||
def fix_spacing_on_mac(self):
|
||||
if platform.system() == 'Darwin':
|
||||
if platform.system() == "Darwin":
|
||||
self.widgetForAction(self.actions()[0]).setStyleSheet(
|
||||
'QToolButton { margin-left: 9px; margin-right: 1px; }')
|
||||
"QToolButton { margin-left: 9px; margin-right: 1px; }"
|
||||
)
|
||||
|
|
|
@ -5,4 +5,4 @@ from PyQt6.QtWidgets import QPushButton, QWidget
|
|||
|
||||
class AdvancedSettingsButton(QPushButton):
|
||||
def __init__(self, parent: Optional[QWidget]) -> None:
|
||||
super().__init__('Advanced...', parent)
|
||||
super().__init__("Advanced...", parent)
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
from PyQt6.QtCore import pyqtSignal
|
||||
from PyQt6.QtWidgets import QDialog, QWidget, QDialogButtonBox, QFormLayout, \
|
||||
QPlainTextEdit
|
||||
from PyQt6.QtWidgets import (
|
||||
QDialog,
|
||||
QWidget,
|
||||
QDialogButtonBox,
|
||||
QFormLayout,
|
||||
QPlainTextEdit,
|
||||
)
|
||||
|
||||
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
|
||||
from buzz.locale import _
|
||||
|
@ -13,39 +18,48 @@ class AdvancedSettingsDialog(QDialog):
|
|||
transcription_options: TranscriptionOptions
|
||||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
|
||||
def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None):
|
||||
def __init__(
|
||||
self, transcription_options: TranscriptionOptions, parent: QWidget | None = None
|
||||
):
|
||||
super().__init__(parent)
|
||||
|
||||
self.transcription_options = transcription_options
|
||||
|
||||
self.setWindowTitle(_('Advanced Settings'))
|
||||
self.setWindowTitle(_("Advanced Settings"))
|
||||
|
||||
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
|
||||
QDialogButtonBox.StandardButton.Ok), self)
|
||||
button_box = QDialogButtonBox(
|
||||
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self
|
||||
)
|
||||
button_box.accepted.connect(self.accept)
|
||||
|
||||
layout = QFormLayout(self)
|
||||
|
||||
default_temperature_text = ', '.join(
|
||||
[str(temp) for temp in transcription_options.temperature])
|
||||
default_temperature_text = ", ".join(
|
||||
[str(temp) for temp in transcription_options.temperature]
|
||||
)
|
||||
self.temperature_line_edit = LineEdit(default_temperature_text, self)
|
||||
self.temperature_line_edit.setPlaceholderText(
|
||||
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"'))
|
||||
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')
|
||||
)
|
||||
self.temperature_line_edit.setMinimumWidth(170)
|
||||
self.temperature_line_edit.textChanged.connect(
|
||||
self.on_temperature_changed)
|
||||
self.temperature_line_edit.textChanged.connect(self.on_temperature_changed)
|
||||
self.temperature_line_edit.setValidator(TemperatureValidator(self))
|
||||
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
|
||||
self.temperature_line_edit.setEnabled(
|
||||
transcription_options.model.model_type == ModelType.WHISPER
|
||||
)
|
||||
|
||||
self.initial_prompt_text_edit = QPlainTextEdit(
|
||||
transcription_options.initial_prompt, self)
|
||||
transcription_options.initial_prompt, self
|
||||
)
|
||||
self.initial_prompt_text_edit.textChanged.connect(
|
||||
self.on_initial_prompt_changed)
|
||||
self.on_initial_prompt_changed
|
||||
)
|
||||
self.initial_prompt_text_edit.setEnabled(
|
||||
transcription_options.model.model_type == ModelType.WHISPER)
|
||||
transcription_options.model.model_type == ModelType.WHISPER
|
||||
)
|
||||
|
||||
layout.addRow(_('Temperature:'), self.temperature_line_edit)
|
||||
layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit)
|
||||
layout.addRow(_("Temperature:"), self.temperature_line_edit)
|
||||
layout.addRow(_("Initial Prompt:"), self.initial_prompt_text_edit)
|
||||
layout.addWidget(button_box)
|
||||
|
||||
self.setLayout(layout)
|
||||
|
@ -53,12 +67,14 @@ class AdvancedSettingsDialog(QDialog):
|
|||
|
||||
def on_temperature_changed(self, text: str):
|
||||
try:
|
||||
temperatures = [float(temp.strip()) for temp in text.split(',')]
|
||||
temperatures = [float(temp.strip()) for temp in text.split(",")]
|
||||
self.transcription_options.temperature = tuple(temperatures)
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def on_initial_prompt_changed(self):
|
||||
self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText()
|
||||
self.transcription_options.initial_prompt = (
|
||||
self.initial_prompt_text_edit.toPlainText()
|
||||
)
|
||||
self.transcription_options_changed.emit(self.transcription_options)
|
||||
|
|
|
@ -2,8 +2,14 @@ from typing import Optional, List, Tuple
|
|||
|
||||
from PyQt6 import QtGui
|
||||
from PyQt6.QtCore import pyqtSignal, Qt, QThreadPool
|
||||
from PyQt6.QtWidgets import QWidget, QVBoxLayout, QCheckBox, QFormLayout, QHBoxLayout, \
|
||||
QPushButton
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QVBoxLayout,
|
||||
QCheckBox,
|
||||
QFormLayout,
|
||||
QHBoxLayout,
|
||||
QPushButton,
|
||||
)
|
||||
|
||||
from buzz.dialogs import show_model_download_error_dialog
|
||||
from buzz.locale import _
|
||||
|
@ -11,11 +17,17 @@ from buzz.model_loader import ModelDownloader, TranscriptionModel, ModelType
|
|||
from buzz.paths import file_paths_as_title
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.store.keyring_store import KeyringStore
|
||||
from buzz.transcriber import FileTranscriptionOptions, TranscriptionOptions, Task, \
|
||||
DEFAULT_WHISPER_TEMPERATURE, OutputFormat
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionOptions,
|
||||
TranscriptionOptions,
|
||||
Task,
|
||||
DEFAULT_WHISPER_TEMPERATURE,
|
||||
OutputFormat,
|
||||
)
|
||||
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
|
||||
from buzz.widgets.transcriber.transcription_options_group_box import \
|
||||
TranscriptionOptionsGroupBox
|
||||
from buzz.widgets.transcriber.transcription_options_group_box import (
|
||||
TranscriptionOptionsGroupBox,
|
||||
)
|
||||
|
||||
|
||||
class FileTranscriberWidget(QWidget):
|
||||
|
@ -29,74 +41,100 @@ class FileTranscriberWidget(QWidget):
|
|||
openai_access_token_changed = pyqtSignal(str)
|
||||
settings = Settings()
|
||||
|
||||
def __init__(self, file_paths: List[str],
|
||||
default_output_file_name: str,
|
||||
parent: Optional[QWidget] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
file_paths: List[str],
|
||||
default_output_file_name: str,
|
||||
parent: Optional[QWidget] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget,
|
||||
) -> None:
|
||||
super().__init__(parent, flags)
|
||||
|
||||
self.setWindowTitle(file_paths_as_title(file_paths))
|
||||
|
||||
openai_access_token = KeyringStore().get_password(
|
||||
KeyringStore.Key.OPENAI_API_KEY)
|
||||
KeyringStore.Key.OPENAI_API_KEY
|
||||
)
|
||||
|
||||
self.file_paths = file_paths
|
||||
default_language = self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='')
|
||||
key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value=""
|
||||
)
|
||||
self.transcription_options = TranscriptionOptions(
|
||||
openai_access_token=openai_access_token,
|
||||
model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL,
|
||||
default_value=TranscriptionModel()),
|
||||
task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK,
|
||||
default_value=Task.TRANSCRIBE),
|
||||
language=default_language if default_language != '' else None,
|
||||
model=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_MODEL,
|
||||
default_value=TranscriptionModel(),
|
||||
),
|
||||
task=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE
|
||||
),
|
||||
language=default_language if default_language != "" else None,
|
||||
initial_prompt=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
|
||||
key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=""
|
||||
),
|
||||
temperature=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
|
||||
default_value=DEFAULT_WHISPER_TEMPERATURE),
|
||||
default_value=DEFAULT_WHISPER_TEMPERATURE,
|
||||
),
|
||||
word_level_timings=self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
default_value=False))
|
||||
default_value=False,
|
||||
),
|
||||
)
|
||||
default_export_format_states: List[str] = self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
|
||||
default_value=[])
|
||||
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, default_value=[]
|
||||
)
|
||||
self.file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=self.file_paths,
|
||||
output_formats=set([OutputFormat(output_format) for output_format in
|
||||
default_export_format_states]),
|
||||
default_output_file_name=default_output_file_name)
|
||||
output_formats=set(
|
||||
[
|
||||
OutputFormat(output_format)
|
||||
for output_format in default_export_format_states
|
||||
]
|
||||
),
|
||||
default_output_file_name=default_output_file_name,
|
||||
)
|
||||
|
||||
layout = QVBoxLayout(self)
|
||||
|
||||
transcription_options_group_box = TranscriptionOptionsGroupBox(
|
||||
default_transcription_options=self.transcription_options, parent=self)
|
||||
default_transcription_options=self.transcription_options, parent=self
|
||||
)
|
||||
transcription_options_group_box.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
self.on_transcription_options_changed
|
||||
)
|
||||
|
||||
self.word_level_timings_checkbox = QCheckBox(_('Word-level timings'))
|
||||
self.word_level_timings_checkbox = QCheckBox(_("Word-level timings"))
|
||||
self.word_level_timings_checkbox.setChecked(
|
||||
self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
default_value=False))
|
||||
self.settings.value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
default_value=False,
|
||||
)
|
||||
)
|
||||
self.word_level_timings_checkbox.stateChanged.connect(
|
||||
self.on_word_level_timings_changed)
|
||||
self.on_word_level_timings_changed
|
||||
)
|
||||
|
||||
file_transcription_layout = QFormLayout()
|
||||
file_transcription_layout.addRow('', self.word_level_timings_checkbox)
|
||||
file_transcription_layout.addRow("", self.word_level_timings_checkbox)
|
||||
|
||||
export_format_layout = QHBoxLayout()
|
||||
for output_format in OutputFormat:
|
||||
export_format_checkbox = QCheckBox(f'{output_format.value.upper()}',
|
||||
parent=self)
|
||||
export_format_checkbox = QCheckBox(
|
||||
f"{output_format.value.upper()}", parent=self
|
||||
)
|
||||
export_format_checkbox.setChecked(
|
||||
output_format in self.file_transcription_options.output_formats)
|
||||
output_format in self.file_transcription_options.output_formats
|
||||
)
|
||||
export_format_checkbox.stateChanged.connect(
|
||||
self.get_on_checkbox_state_changed_callback(output_format))
|
||||
self.get_on_checkbox_state_changed_callback(output_format)
|
||||
)
|
||||
export_format_layout.addWidget(export_format_checkbox)
|
||||
|
||||
file_transcription_layout.addRow('Export:', export_format_layout)
|
||||
file_transcription_layout.addRow("Export:", export_format_layout)
|
||||
|
||||
self.run_button = QPushButton(_('Run'), self)
|
||||
self.run_button = QPushButton(_("Run"), self)
|
||||
self.run_button.setDefault(True)
|
||||
self.run_button.clicked.connect(self.on_click_run)
|
||||
|
||||
|
@ -116,15 +154,19 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
return on_checkbox_state_changed
|
||||
|
||||
def on_transcription_options_changed(self,
|
||||
transcription_options: TranscriptionOptions):
|
||||
def on_transcription_options_changed(
|
||||
self, transcription_options: TranscriptionOptions
|
||||
):
|
||||
self.transcription_options = transcription_options
|
||||
self.word_level_timings_checkbox.setDisabled(
|
||||
self.transcription_options.model.model_type == ModelType.HUGGING_FACE or
|
||||
self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API)
|
||||
if self.transcription_options.openai_access_token != '':
|
||||
self.transcription_options.model.model_type == ModelType.HUGGING_FACE
|
||||
or self.transcription_options.model.model_type
|
||||
== ModelType.OPEN_AI_WHISPER_API
|
||||
)
|
||||
if self.transcription_options.openai_access_token != "":
|
||||
self.openai_access_token_changed.emit(
|
||||
self.transcription_options.openai_access_token)
|
||||
self.transcription_options.openai_access_token
|
||||
)
|
||||
|
||||
def on_click_run(self):
|
||||
self.run_button.setDisabled(True)
|
||||
|
@ -143,8 +185,9 @@ class FileTranscriberWidget(QWidget):
|
|||
def on_model_loaded(self, model_path: str):
|
||||
self.reset_transcriber_controls()
|
||||
|
||||
self.triggered.emit((self.transcription_options,
|
||||
self.file_transcription_options, model_path))
|
||||
self.triggered.emit(
|
||||
(self.transcription_options, self.file_transcription_options, model_path)
|
||||
)
|
||||
self.close()
|
||||
|
||||
def on_download_model_progress(self, progress: Tuple[float, float]):
|
||||
|
@ -152,13 +195,16 @@ class FileTranscriberWidget(QWidget):
|
|||
|
||||
if self.model_download_progress_dialog is None:
|
||||
self.model_download_progress_dialog = ModelDownloadProgressDialog(
|
||||
model_type=self.transcription_options.model.model_type, parent=self)
|
||||
model_type=self.transcription_options.model.model_type, parent=self
|
||||
)
|
||||
self.model_download_progress_dialog.canceled.connect(
|
||||
self.on_cancel_model_progress_dialog)
|
||||
self.on_cancel_model_progress_dialog
|
||||
)
|
||||
|
||||
if self.model_download_progress_dialog is not None:
|
||||
self.model_download_progress_dialog.set_value(
|
||||
fraction_completed=current_size / total_size)
|
||||
fraction_completed=current_size / total_size
|
||||
)
|
||||
|
||||
def on_download_model_error(self, error: str):
|
||||
self.reset_model_download()
|
||||
|
@ -179,26 +225,41 @@ class FileTranscriberWidget(QWidget):
|
|||
self.model_download_progress_dialog = None
|
||||
|
||||
def on_word_level_timings_changed(self, value: int):
|
||||
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
|
||||
self.transcription_options.word_level_timings = (
|
||||
value == Qt.CheckState.Checked.value
|
||||
)
|
||||
|
||||
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
|
||||
if self.model_loader is not None:
|
||||
self.model_loader.cancel()
|
||||
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE,
|
||||
self.transcription_options.language)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK,
|
||||
self.transcription_options.task)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
|
||||
self.transcription_options.temperature)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT,
|
||||
self.transcription_options.initial_prompt)
|
||||
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL,
|
||||
self.transcription_options.model)
|
||||
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
value=self.transcription_options.word_level_timings)
|
||||
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
|
||||
value=[export_format.value for export_format in
|
||||
self.file_transcription_options.output_formats])
|
||||
self.settings.set_value(
|
||||
Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
|
||||
self.transcription_options.temperature,
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT,
|
||||
self.transcription_options.initial_prompt,
|
||||
)
|
||||
self.settings.set_value(
|
||||
Settings.Key.FILE_TRANSCRIBER_MODEL, self.transcription_options.model
|
||||
)
|
||||
self.settings.set_value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
|
||||
value=self.transcription_options.word_level_timings,
|
||||
)
|
||||
self.settings.set_value(
|
||||
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
|
||||
value=[
|
||||
export_format.value
|
||||
for export_format in self.file_transcription_options.output_formats
|
||||
],
|
||||
)
|
||||
|
||||
super().closeEvent(event)
|
||||
|
|
|
@ -2,8 +2,17 @@ import json
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from PyQt6.QtCore import pyqtSignal, QTimer, Qt, QMetaObject, QUrl, QUrlQuery, QPoint, \
|
||||
QObject, QEvent
|
||||
from PyQt6.QtCore import (
|
||||
pyqtSignal,
|
||||
QTimer,
|
||||
Qt,
|
||||
QMetaObject,
|
||||
QUrl,
|
||||
QUrlQuery,
|
||||
QPoint,
|
||||
QObject,
|
||||
QEvent,
|
||||
)
|
||||
from PyQt6.QtGui import QKeyEvent
|
||||
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
|
||||
from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem
|
||||
|
@ -16,12 +25,15 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
model_selected = pyqtSignal(str)
|
||||
popup: QListWidget
|
||||
|
||||
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QWidget] = None):
|
||||
super().__init__('', parent)
|
||||
def __init__(
|
||||
self,
|
||||
network_access_manager: Optional[QNetworkAccessManager] = None,
|
||||
parent: Optional[QWidget] = None,
|
||||
):
|
||||
super().__init__("", parent)
|
||||
|
||||
self.setMinimumWidth(150)
|
||||
self.setPlaceholderText('openai/whisper-tiny')
|
||||
self.setPlaceholderText("openai/whisper-tiny")
|
||||
|
||||
self.timer = QTimer(self)
|
||||
self.timer.setSingleShot(True)
|
||||
|
@ -56,7 +68,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
|
||||
item = self.popup.currentItem()
|
||||
self.setText(item.text())
|
||||
QMetaObject.invokeMethod(self, 'returnPressed')
|
||||
QMetaObject.invokeMethod(self, "returnPressed")
|
||||
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
|
||||
|
||||
def fetch_models(self):
|
||||
|
@ -79,7 +91,9 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
|
||||
def on_request_response(self, network_reply: QNetworkReply):
|
||||
if network_reply.error() != QNetworkReply.NetworkError.NoError:
|
||||
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
|
||||
logging.debug(
|
||||
"Error fetching Hugging Face models: %s", network_reply.error()
|
||||
)
|
||||
return
|
||||
|
||||
models = json.loads(network_reply.readAll().data())
|
||||
|
@ -88,7 +102,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
self.popup.clear()
|
||||
|
||||
for model in models:
|
||||
model_id = model.get('id')
|
||||
model_id = model.get("id")
|
||||
|
||||
item = QListWidgetItem(self.popup)
|
||||
item.setText(model_id)
|
||||
|
@ -96,14 +110,16 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
|
||||
self.popup.setCurrentItem(self.popup.item(0))
|
||||
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
|
||||
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
|
||||
self.popup.setFixedHeight(
|
||||
self.popup.sizeHintForRow(0) * min(len(models), 8)
|
||||
) # show max 8 models, then scroll
|
||||
self.popup.setUpdatesEnabled(True)
|
||||
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
|
||||
self.popup.setFocus()
|
||||
self.popup.show()
|
||||
|
||||
def eventFilter(self, target: QObject, event: QEvent):
|
||||
if hasattr(self, 'popup') is False or target != self.popup:
|
||||
if hasattr(self, "popup") is False or target != self.popup:
|
||||
return False
|
||||
|
||||
if event.type() == QEvent.Type.MouseButtonPress:
|
||||
|
@ -123,8 +139,14 @@ class HuggingFaceSearchLineEdit(LineEdit):
|
|||
self.popup.hide()
|
||||
return True
|
||||
|
||||
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
|
||||
Qt.Key.Key_PageDown]:
|
||||
if key in [
|
||||
Qt.Key.Key_Up,
|
||||
Qt.Key.Key_Down,
|
||||
Qt.Key.Key_Home,
|
||||
Qt.Key.Key_End,
|
||||
Qt.Key.Key_PageUp,
|
||||
Qt.Key.Key_PageDown,
|
||||
]:
|
||||
return False
|
||||
|
||||
self.setFocus()
|
||||
|
|
|
@ -9,20 +9,25 @@ from buzz.transcriber import LANGUAGES
|
|||
|
||||
class LanguagesComboBox(QComboBox):
|
||||
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
|
||||
|
||||
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
|
||||
languageChanged = pyqtSignal(str)
|
||||
|
||||
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
|
||||
def __init__(
|
||||
self, default_language: Optional[str], parent: Optional[QWidget] = None
|
||||
) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
whisper_languages = sorted(
|
||||
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1])
|
||||
self.languages = [('', _('Detect Language'))] + whisper_languages
|
||||
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES],
|
||||
key=lambda lang: lang[1],
|
||||
)
|
||||
self.languages = [("", _("Detect Language"))] + whisper_languages
|
||||
|
||||
self.addItems([lang[1] for lang in self.languages])
|
||||
self.currentIndexChanged.connect(self.on_index_changed)
|
||||
|
||||
default_language_key = default_language if default_language != '' else None
|
||||
default_language_key = default_language if default_language != "" else None
|
||||
for i, lang in enumerate(self.languages):
|
||||
if lang[0] == default_language_key:
|
||||
self.setCurrentIndex(i)
|
||||
|
|
|
@ -8,6 +8,7 @@ from buzz.transcriber import Task
|
|||
|
||||
class TasksComboBox(QComboBox):
|
||||
"""TasksComboBox displays a list of tasks available to use with Whisper"""
|
||||
|
||||
taskChanged = pyqtSignal(Task)
|
||||
|
||||
def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None:
|
||||
|
|
|
@ -8,10 +8,12 @@ class TemperatureValidator(QValidator):
|
|||
def __init__(self, parent: Optional[QObject] = ...) -> None:
|
||||
super().__init__(parent)
|
||||
|
||||
def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]:
|
||||
def validate(
|
||||
self, text: str, cursor_position: int
|
||||
) -> Tuple["QValidator.State", str, int]:
|
||||
try:
|
||||
temp_strings = [temp.strip() for temp in text.split(',')]
|
||||
if temp_strings[-1] == '':
|
||||
temp_strings = [temp.strip() for temp in text.split(",")]
|
||||
if temp_strings[-1] == "":
|
||||
return QValidator.State.Intermediate, text, cursor_position
|
||||
_ = [float(temp) for temp in temp_strings]
|
||||
return QValidator.State.Acceptable, text, cursor_position
|
||||
|
|
|
@ -10,8 +10,9 @@ from buzz.widgets.model_type_combo_box import ModelTypeComboBox
|
|||
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
||||
from buzz.widgets.transcriber.advanced_settings_button import AdvancedSettingsButton
|
||||
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
|
||||
HuggingFaceSearchLineEdit
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import (
|
||||
HuggingFaceSearchLineEdit,
|
||||
)
|
||||
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
|
||||
from buzz.widgets.transcriber.tasks_combo_box import TasksComboBox
|
||||
|
||||
|
@ -21,64 +22,70 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
transcription_options_changed = pyqtSignal(TranscriptionOptions)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
|
||||
model_types: Optional[List[ModelType]] = None,
|
||||
parent: Optional[QWidget] = None):
|
||||
super().__init__(title='', parent=parent)
|
||||
self,
|
||||
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
|
||||
model_types: Optional[List[ModelType]] = None,
|
||||
parent: Optional[QWidget] = None,
|
||||
):
|
||||
super().__init__(title="", parent=parent)
|
||||
self.transcription_options = default_transcription_options
|
||||
|
||||
self.form_layout = QFormLayout(self)
|
||||
|
||||
self.tasks_combo_box = TasksComboBox(
|
||||
default_task=self.transcription_options.task,
|
||||
parent=self)
|
||||
default_task=self.transcription_options.task, parent=self
|
||||
)
|
||||
self.tasks_combo_box.taskChanged.connect(self.on_task_changed)
|
||||
|
||||
self.languages_combo_box = LanguagesComboBox(
|
||||
default_language=self.transcription_options.language,
|
||||
parent=self)
|
||||
self.languages_combo_box.languageChanged.connect(
|
||||
self.on_language_changed)
|
||||
default_language=self.transcription_options.language, parent=self
|
||||
)
|
||||
self.languages_combo_box.languageChanged.connect(self.on_language_changed)
|
||||
|
||||
self.advanced_settings_button = AdvancedSettingsButton(self)
|
||||
self.advanced_settings_button.clicked.connect(
|
||||
self.open_advanced_settings)
|
||||
self.advanced_settings_button.clicked.connect(self.open_advanced_settings)
|
||||
|
||||
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
|
||||
self.hugging_face_search_line_edit.model_selected.connect(
|
||||
self.on_hugging_face_model_changed)
|
||||
self.on_hugging_face_model_changed
|
||||
)
|
||||
|
||||
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
|
||||
default_model=default_transcription_options.model.model_type,
|
||||
parent=self)
|
||||
self.model_type_combo_box = ModelTypeComboBox(
|
||||
model_types=model_types,
|
||||
default_model=default_transcription_options.model.model_type,
|
||||
parent=self,
|
||||
)
|
||||
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
|
||||
|
||||
self.whisper_model_size_combo_box = QComboBox(self)
|
||||
self.whisper_model_size_combo_box.addItems(
|
||||
[size.value.title() for size in WhisperModelSize])
|
||||
[size.value.title() for size in WhisperModelSize]
|
||||
)
|
||||
if default_transcription_options.model.whisper_model_size is not None:
|
||||
self.whisper_model_size_combo_box.setCurrentText(
|
||||
default_transcription_options.model.whisper_model_size.value.title())
|
||||
default_transcription_options.model.whisper_model_size.value.title()
|
||||
)
|
||||
self.whisper_model_size_combo_box.currentTextChanged.connect(
|
||||
self.on_whisper_model_size_changed)
|
||||
self.on_whisper_model_size_changed
|
||||
)
|
||||
|
||||
self.openai_access_token_edit = OpenAIAPIKeyLineEdit(
|
||||
key=default_transcription_options.openai_access_token,
|
||||
parent=self)
|
||||
key=default_transcription_options.openai_access_token, parent=self
|
||||
)
|
||||
self.openai_access_token_edit.key_changed.connect(
|
||||
self.on_openai_access_token_edit_changed)
|
||||
self.on_openai_access_token_edit_changed
|
||||
)
|
||||
|
||||
self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
|
||||
self.form_layout.addRow('', self.whisper_model_size_combo_box)
|
||||
self.form_layout.addRow('', self.hugging_face_search_line_edit)
|
||||
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
|
||||
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
|
||||
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
|
||||
self.form_layout.addRow(_("Model:"), self.model_type_combo_box)
|
||||
self.form_layout.addRow("", self.whisper_model_size_combo_box)
|
||||
self.form_layout.addRow("", self.hugging_face_search_line_edit)
|
||||
self.form_layout.addRow("Access Token:", self.openai_access_token_edit)
|
||||
self.form_layout.addRow(_("Task:"), self.tasks_combo_box)
|
||||
self.form_layout.addRow(_("Language:"), self.languages_combo_box)
|
||||
|
||||
self.reset_visible_rows()
|
||||
|
||||
self.form_layout.addRow('', self.advanced_settings_button)
|
||||
self.form_layout.addRow("", self.advanced_settings_button)
|
||||
|
||||
self.setLayout(self.form_layout)
|
||||
|
||||
|
@ -104,26 +111,33 @@ class TranscriptionOptionsGroupBox(QGroupBox):
|
|||
|
||||
def open_advanced_settings(self):
|
||||
dialog = AdvancedSettingsDialog(
|
||||
transcription_options=self.transcription_options, parent=self)
|
||||
transcription_options=self.transcription_options, parent=self
|
||||
)
|
||||
dialog.transcription_options_changed.connect(
|
||||
self.on_transcription_options_changed)
|
||||
self.on_transcription_options_changed
|
||||
)
|
||||
dialog.exec()
|
||||
|
||||
def on_transcription_options_changed(self,
|
||||
transcription_options: TranscriptionOptions):
|
||||
def on_transcription_options_changed(
|
||||
self, transcription_options: TranscriptionOptions
|
||||
):
|
||||
self.transcription_options = transcription_options
|
||||
self.transcription_options_changed.emit(transcription_options)
|
||||
|
||||
def reset_visible_rows(self):
|
||||
model_type = self.transcription_options.model.model_type
|
||||
self.form_layout.setRowVisible(self.hugging_face_search_line_edit,
|
||||
model_type == ModelType.HUGGING_FACE)
|
||||
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
|
||||
(model_type == ModelType.WHISPER) or (
|
||||
model_type == ModelType.WHISPER_CPP) or (
|
||||
model_type == ModelType.FASTER_WHISPER))
|
||||
self.form_layout.setRowVisible(self.openai_access_token_edit,
|
||||
model_type == ModelType.OPEN_AI_WHISPER_API)
|
||||
self.form_layout.setRowVisible(
|
||||
self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE
|
||||
)
|
||||
self.form_layout.setRowVisible(
|
||||
self.whisper_model_size_combo_box,
|
||||
(model_type == ModelType.WHISPER)
|
||||
or (model_type == ModelType.WHISPER_CPP)
|
||||
or (model_type == ModelType.FASTER_WHISPER),
|
||||
)
|
||||
self.form_layout.setRowVisible(
|
||||
self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API
|
||||
)
|
||||
|
||||
def on_model_type_changed(self, model_type: ModelType):
|
||||
self.transcription_options.model.model_type = model_type
|
||||
|
|
|
@ -27,9 +27,10 @@ class TranscriptionSegmentsEditorWidget(QTableWidget):
|
|||
self.setColumnCount(3)
|
||||
|
||||
self.verticalHeader().hide()
|
||||
self.setHorizontalHeaderLabels([_('Start'), _('End'), _('Text')])
|
||||
self.horizontalHeader().setSectionResizeMode(2,
|
||||
QHeaderView.ResizeMode.ResizeToContents)
|
||||
self.setHorizontalHeaderLabels([_("Start"), _("End"), _("Text")])
|
||||
self.horizontalHeader().setSectionResizeMode(
|
||||
2, QHeaderView.ResizeMode.ResizeToContents
|
||||
)
|
||||
self.setSelectionMode(QTableWidget.SelectionMode.SingleSelection)
|
||||
|
||||
for segment in segments:
|
||||
|
@ -38,12 +39,18 @@ class TranscriptionSegmentsEditorWidget(QTableWidget):
|
|||
|
||||
start_item = QTableWidgetItem(to_timestamp(segment.start))
|
||||
start_item.setFlags(
|
||||
start_item.flags() & ~Qt.ItemFlag.ItemIsEditable & ~Qt.ItemFlag.ItemIsSelectable)
|
||||
start_item.flags()
|
||||
& ~Qt.ItemFlag.ItemIsEditable
|
||||
& ~Qt.ItemFlag.ItemIsSelectable
|
||||
)
|
||||
self.setItem(row_index, self.Column.START.value, start_item)
|
||||
|
||||
end_item = QTableWidgetItem(to_timestamp(segment.end))
|
||||
end_item.setFlags(
|
||||
end_item.flags() & ~Qt.ItemFlag.ItemIsEditable & ~Qt.ItemFlag.ItemIsSelectable)
|
||||
end_item.flags()
|
||||
& ~Qt.ItemFlag.ItemIsEditable
|
||||
& ~Qt.ItemFlag.ItemIsSelectable
|
||||
)
|
||||
self.setItem(row_index, self.Column.END.value, end_item)
|
||||
|
||||
text_item = QTableWidgetItem(segment.text)
|
||||
|
@ -61,5 +68,4 @@ class TranscriptionSegmentsEditorWidget(QTableWidget):
|
|||
|
||||
def on_item_selection_changed(self):
|
||||
ranges = self.selectedRanges()
|
||||
self.segment_index_selected.emit(
|
||||
ranges[0].topRow() if len(ranges) > 0 else -1)
|
||||
self.segment_index_selected.emit(ranges[0].topRow() if len(ranges) > 0 else -1)
|
||||
|
|
|
@ -30,13 +30,12 @@ class TranscriptionTasksTableWidget(QTableWidget):
|
|||
self.setColumnHidden(0, True)
|
||||
|
||||
self.verticalHeader().hide()
|
||||
self.setHorizontalHeaderLabels([_('ID'), _('File Name'), _('Status')])
|
||||
self.setHorizontalHeaderLabels([_("ID"), _("File Name"), _("Status")])
|
||||
self.setColumnWidth(self.Column.FILE_NAME.value, 250)
|
||||
self.setColumnWidth(self.Column.STATUS.value, 180)
|
||||
self.horizontalHeader().setMinimumSectionSize(180)
|
||||
|
||||
self.setSelectionBehavior(
|
||||
QAbstractItemView.SelectionBehavior.SelectRows)
|
||||
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
|
||||
|
||||
def upsert_task(self, task: FileTranscriptionTask):
|
||||
task_row_index = self.task_row_index(task.id)
|
||||
|
@ -45,21 +44,19 @@ class TranscriptionTasksTableWidget(QTableWidget):
|
|||
|
||||
row_index = self.rowCount() - 1
|
||||
task_id_widget_item = QTableWidgetItem(str(task.id))
|
||||
self.setItem(row_index, self.Column.TASK_ID.value,
|
||||
task_id_widget_item)
|
||||
self.setItem(row_index, self.Column.TASK_ID.value, task_id_widget_item)
|
||||
|
||||
file_name_widget_item = QTableWidgetItem(
|
||||
os.path.basename(task.file_path))
|
||||
file_name_widget_item = QTableWidgetItem(os.path.basename(task.file_path))
|
||||
file_name_widget_item.setFlags(
|
||||
file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
|
||||
self.setItem(row_index, self.Column.FILE_NAME.value,
|
||||
file_name_widget_item)
|
||||
file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
|
||||
)
|
||||
self.setItem(row_index, self.Column.FILE_NAME.value, file_name_widget_item)
|
||||
|
||||
status_widget_item = QTableWidgetItem(self.get_status_text(task))
|
||||
status_widget_item.setFlags(
|
||||
status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
|
||||
self.setItem(row_index, self.Column.STATUS.value,
|
||||
status_widget_item)
|
||||
status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
|
||||
)
|
||||
self.setItem(row_index, self.Column.STATUS.value, status_widget_item)
|
||||
else:
|
||||
status_widget = self.item(task_row_index, self.Column.STATUS.value)
|
||||
status_widget.setText(self.get_status_text(task))
|
||||
|
@ -67,31 +64,30 @@ class TranscriptionTasksTableWidget(QTableWidget):
|
|||
@staticmethod
|
||||
def format_timedelta(delta: datetime.timedelta):
|
||||
mm, ss = divmod(delta.seconds, 60)
|
||||
result = f'{ss}s'
|
||||
result = f"{ss}s"
|
||||
if mm == 0:
|
||||
return result
|
||||
hh, mm = divmod(mm, 60)
|
||||
result = f'{mm}m {result}'
|
||||
result = f"{mm}m {result}"
|
||||
if hh == 0:
|
||||
return result
|
||||
return f'{hh}h {result}'
|
||||
return f"{hh}h {result}"
|
||||
|
||||
@staticmethod
|
||||
def get_status_text(task: FileTranscriptionTask):
|
||||
if task.status == FileTranscriptionTask.Status.IN_PROGRESS:
|
||||
return (
|
||||
f'{_("In Progress")} ({task.fraction_completed :.0%})')
|
||||
return f'{_("In Progress")} ({task.fraction_completed :.0%})'
|
||||
elif task.status == FileTranscriptionTask.Status.COMPLETED:
|
||||
status = _('Completed')
|
||||
status = _("Completed")
|
||||
if task.started_at is not None and task.completed_at is not None:
|
||||
status += f" ({TranscriptionTasksTableWidget.format_timedelta(task.completed_at - task.started_at)})"
|
||||
return status
|
||||
elif task.status == FileTranscriptionTask.Status.FAILED:
|
||||
return f'{_("Failed")} ({task.error})'
|
||||
elif task.status == FileTranscriptionTask.Status.CANCELED:
|
||||
return _('Canceled')
|
||||
return _("Canceled")
|
||||
elif task.status == FileTranscriptionTask.Status.QUEUED:
|
||||
return _('Queued')
|
||||
return _("Queued")
|
||||
|
||||
def clear_task(self, task_id: int):
|
||||
task_row_index = self.task_row_index(task_id)
|
||||
|
@ -99,15 +95,20 @@ class TranscriptionTasksTableWidget(QTableWidget):
|
|||
self.removeRow(task_row_index)
|
||||
|
||||
def task_row_index(self, task_id: int) -> int | None:
|
||||
table_items_matching_task_id = [item for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly) if
|
||||
item.column() == self.Column.TASK_ID.value]
|
||||
table_items_matching_task_id = [
|
||||
item
|
||||
for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly)
|
||||
if item.column() == self.Column.TASK_ID.value
|
||||
]
|
||||
if len(table_items_matching_task_id) == 0:
|
||||
return None
|
||||
return table_items_matching_task_id[0].row()
|
||||
|
||||
@staticmethod
|
||||
def find_task_id(index: QModelIndex):
|
||||
sibling_index = index.siblingAtColumn(TranscriptionTasksTableWidget.Column.TASK_ID.value).data()
|
||||
sibling_index = index.siblingAtColumn(
|
||||
TranscriptionTasksTableWidget.Column.TASK_ID.value
|
||||
).data()
|
||||
return int(sibling_index) if sibling_index is not None else None
|
||||
|
||||
def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
|
||||
|
|
|
@ -3,20 +3,32 @@ from typing import List, Optional
|
|||
|
||||
from PyQt6.QtCore import Qt, pyqtSignal
|
||||
from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence, QAction
|
||||
from PyQt6.QtWidgets import QWidget, QHBoxLayout, QMenu, QPushButton, QVBoxLayout, \
|
||||
QFileDialog
|
||||
from PyQt6.QtWidgets import (
|
||||
QWidget,
|
||||
QHBoxLayout,
|
||||
QMenu,
|
||||
QPushButton,
|
||||
QVBoxLayout,
|
||||
QFileDialog,
|
||||
)
|
||||
|
||||
from buzz.action import Action
|
||||
from buzz.assets import get_asset_path
|
||||
from buzz.locale import _
|
||||
from buzz.paths import file_path_as_title
|
||||
from buzz.transcriber import FileTranscriptionTask, Segment, OutputFormat, \
|
||||
get_default_output_file_path, write_output
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
Segment,
|
||||
OutputFormat,
|
||||
get_default_output_file_path,
|
||||
write_output,
|
||||
)
|
||||
from buzz.widgets.audio_player import AudioPlayer
|
||||
from buzz.widgets.icon import Icon
|
||||
from buzz.widgets.toolbar import ToolBar
|
||||
from buzz.widgets.transcription_segments_editor_widget import \
|
||||
TranscriptionSegmentsEditorWidget
|
||||
from buzz.widgets.transcription_segments_editor_widget import (
|
||||
TranscriptionSegmentsEditorWidget,
|
||||
)
|
||||
|
||||
|
||||
class TranscriptionViewerWidget(QWidget):
|
||||
|
@ -24,9 +36,14 @@ class TranscriptionViewerWidget(QWidget):
|
|||
task_changed = pyqtSignal()
|
||||
|
||||
class ChangeSegmentTextCommand(QUndoCommand):
|
||||
def __init__(self, table_widget: TranscriptionSegmentsEditorWidget,
|
||||
segments: List[Segment],
|
||||
segment_index: int, segment_text: str, task_changed: pyqtSignal):
|
||||
def __init__(
|
||||
self,
|
||||
table_widget: TranscriptionSegmentsEditorWidget,
|
||||
segments: List[Segment],
|
||||
segment_index: int,
|
||||
segment_text: str,
|
||||
task_changed: pyqtSignal,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.table_widget = table_widget
|
||||
|
@ -52,10 +69,11 @@ class TranscriptionViewerWidget(QWidget):
|
|||
self.task_changed.emit()
|
||||
|
||||
def __init__(
|
||||
self, transcription_task: FileTranscriptionTask,
|
||||
open_transcription_output=True,
|
||||
parent: Optional['QWidget'] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget,
|
||||
self,
|
||||
transcription_task: FileTranscriptionTask,
|
||||
open_transcription_output=True,
|
||||
parent: Optional["QWidget"] = None,
|
||||
flags: Qt.WindowType = Qt.WindowType.Widget,
|
||||
) -> None:
|
||||
super().__init__(parent, flags)
|
||||
self.transcription_task = transcription_task
|
||||
|
@ -71,20 +89,23 @@ class TranscriptionViewerWidget(QWidget):
|
|||
undo_action = self.undo_stack.createUndoAction(self, _("Undo"))
|
||||
undo_action.setShortcuts(QKeySequence.StandardKey.Undo)
|
||||
undo_action.setIcon(
|
||||
Icon(get_asset_path('assets/undo_FILL0_wght700_GRAD0_opsz48.svg'), self))
|
||||
Icon(get_asset_path("assets/undo_FILL0_wght700_GRAD0_opsz48.svg"), self)
|
||||
)
|
||||
undo_action.setToolTip(Action.get_tooltip(undo_action))
|
||||
|
||||
redo_action = self.undo_stack.createRedoAction(self, _("Redo"))
|
||||
redo_action.setShortcuts(QKeySequence.StandardKey.Redo)
|
||||
redo_action.setIcon(
|
||||
Icon(get_asset_path('assets/redo_FILL0_wght700_GRAD0_opsz48.svg'), self))
|
||||
Icon(get_asset_path("assets/redo_FILL0_wght700_GRAD0_opsz48.svg"), self)
|
||||
)
|
||||
redo_action.setToolTip(Action.get_tooltip(redo_action))
|
||||
|
||||
toolbar = ToolBar()
|
||||
toolbar.addActions([undo_action, redo_action])
|
||||
|
||||
self.table_widget = TranscriptionSegmentsEditorWidget(
|
||||
segments=transcription_task.segments, parent=self)
|
||||
segments=transcription_task.segments, parent=self
|
||||
)
|
||||
self.table_widget.segment_text_changed.connect(self.on_segment_text_changed)
|
||||
self.table_widget.segment_index_selected.connect(self.on_segment_index_selected)
|
||||
|
||||
|
@ -96,14 +117,16 @@ class TranscriptionViewerWidget(QWidget):
|
|||
buttons_layout.addStretch()
|
||||
|
||||
export_button_menu = QMenu()
|
||||
actions = [QAction(text=output_format.value.upper(), parent=self)
|
||||
for output_format in OutputFormat]
|
||||
actions = [
|
||||
QAction(text=output_format.value.upper(), parent=self)
|
||||
for output_format in OutputFormat
|
||||
]
|
||||
export_button_menu.addActions(actions)
|
||||
|
||||
export_button_menu.triggered.connect(self.on_menu_triggered)
|
||||
|
||||
export_button = QPushButton(self)
|
||||
export_button.setText(_('Export'))
|
||||
export_button.setText(_("Export"))
|
||||
export_button.setMenu(export_button_menu)
|
||||
|
||||
buttons_layout.addWidget(export_button)
|
||||
|
@ -120,11 +143,14 @@ class TranscriptionViewerWidget(QWidget):
|
|||
def on_segment_text_changed(self, event: tuple):
|
||||
segment_index, segment_text = event
|
||||
self.undo_stack.push(
|
||||
self.ChangeSegmentTextCommand(table_widget=self.table_widget,
|
||||
segments=self.transcription_task.segments,
|
||||
segment_index=segment_index,
|
||||
segment_text=segment_text,
|
||||
task_changed=self.task_changed))
|
||||
self.ChangeSegmentTextCommand(
|
||||
table_widget=self.table_widget,
|
||||
segments=self.transcription_task.segments,
|
||||
segment_index=segment_index,
|
||||
segment_text=segment_text,
|
||||
task_changed=self.task_changed,
|
||||
)
|
||||
)
|
||||
|
||||
def on_segment_index_selected(self, index: int):
|
||||
selected_segment = self.transcription_task.segments[index]
|
||||
|
@ -134,15 +160,22 @@ class TranscriptionViewerWidget(QWidget):
|
|||
def on_menu_triggered(self, action: QAction):
|
||||
output_format = OutputFormat[action.text()]
|
||||
|
||||
default_path = get_default_output_file_path(task=self.transcription_task,
|
||||
output_format=output_format)
|
||||
default_path = get_default_output_file_path(
|
||||
task=self.transcription_task, output_format=output_format
|
||||
)
|
||||
|
||||
(output_file_path, nil) = QFileDialog.getSaveFileName(self, _('Save File'),
|
||||
default_path,
|
||||
_('Text files') + f' (*.{output_format.value})')
|
||||
(output_file_path, nil) = QFileDialog.getSaveFileName(
|
||||
self,
|
||||
_("Save File"),
|
||||
default_path,
|
||||
_("Text files") + f" (*.{output_format.value})",
|
||||
)
|
||||
|
||||
if output_file_path == '':
|
||||
if output_file_path == "":
|
||||
return
|
||||
|
||||
write_output(path=output_file_path, segments=self.transcription_task.segments,
|
||||
output_format=output_format)
|
||||
write_output(
|
||||
path=output_file_path,
|
||||
segments=self.transcription_task.segments,
|
||||
output_format=output_format,
|
||||
)
|
||||
|
|
77
poetry.lock
generated
77
poetry.lock
generated
|
@ -266,6 +266,54 @@ files = [
|
|||
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "black"
|
||||
version = "23.7.0"
|
||||
description = "The uncompromising code formatter."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"},
|
||||
{file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"},
|
||||
{file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"},
|
||||
{file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"},
|
||||
{file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"},
|
||||
{file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"},
|
||||
{file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"},
|
||||
{file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"},
|
||||
{file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"},
|
||||
{file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"},
|
||||
{file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"},
|
||||
{file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"},
|
||||
{file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"},
|
||||
{file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"},
|
||||
{file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"},
|
||||
{file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"},
|
||||
{file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"},
|
||||
{file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"},
|
||||
{file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"},
|
||||
{file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"},
|
||||
{file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"},
|
||||
{file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = {version = ">=3.7.4", optional = true, markers = "extra == \"d\""}
|
||||
click = ">=8.0.0"
|
||||
mypy-extensions = ">=0.4.3"
|
||||
packaging = ">=22.0"
|
||||
pathspec = ">=0.9.0"
|
||||
platformdirs = ">=2"
|
||||
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
|
||||
typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}
|
||||
|
||||
[package.extras]
|
||||
colorama = ["colorama (>=0.4.3)"]
|
||||
d = ["aiohttp (>=3.7.4)"]
|
||||
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
||||
uvloop = ["uvloop (>=0.15.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2023.5.7"
|
||||
|
@ -452,6 +500,21 @@ files = [
|
|||
{file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "click"
|
||||
version = "8.1.7"
|
||||
description = "Composable command line interface toolkit"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
|
||||
{file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "3.26.4"
|
||||
|
@ -1534,6 +1597,18 @@ files = [
|
|||
{file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pathspec"
|
||||
version = "0.11.2"
|
||||
description = "Utility library for gitignore style pattern matching of file paths."
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"},
|
||||
{file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pefile"
|
||||
version = "2023.2.7"
|
||||
|
@ -2669,4 +2744,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9.13,<3.11"
|
||||
content-hash = "ceb6ce6c7083882f1499bd36f5e98f6aa1e0a872d8268ccbda91d67ee81fdd1e"
|
||||
content-hash = "fe7fae59602bd0ecdceafbfe274f6f36f0cb489b67bfc7d4bfae4998dbbe672a"
|
||||
|
|
|
@ -33,6 +33,7 @@ pytest-xvfb = "^2.0.0"
|
|||
pylint = "^2.15.5"
|
||||
pre-commit = "^2.20.0"
|
||||
pytest-benchmark = "^4.0.0"
|
||||
black = {extras = ["d"], version = "^23.7.0"}
|
||||
|
||||
[tool.poetry.group.build.dependencies]
|
||||
ctypesgen = "^1.1.1"
|
||||
|
|
|
@ -1,16 +1,31 @@
|
|||
from buzz.cache import TasksCache
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
|
||||
TranscriptionOptions)
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionOptions,
|
||||
FileTranscriptionTask,
|
||||
TranscriptionOptions,
|
||||
)
|
||||
|
||||
|
||||
class TestTasksCache:
|
||||
def test_should_save_and_load(self, tmp_path):
|
||||
cache = TasksCache(cache_dir=str(tmp_path))
|
||||
tasks = [FileTranscriptionTask(file_path='1.mp3', transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=['1.mp3']),
|
||||
model_path=''),
|
||||
FileTranscriptionTask(file_path='2.mp3', transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=['2.mp3']),
|
||||
model_path='')]
|
||||
tasks = [
|
||||
FileTranscriptionTask(
|
||||
file_path="1.mp3",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=["1.mp3"]
|
||||
),
|
||||
model_path="",
|
||||
),
|
||||
FileTranscriptionTask(
|
||||
file_path="2.mp3",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=["2.mp3"]
|
||||
),
|
||||
model_path="",
|
||||
),
|
||||
]
|
||||
cache.save(tasks)
|
||||
assert cache.load() == tasks
|
||||
|
|
|
@ -8,91 +8,100 @@ import pytest
|
|||
import sounddevice
|
||||
from PyQt6.QtCore import QSize, Qt
|
||||
from PyQt6.QtGui import QValidator, QKeyEvent
|
||||
from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication, QMessageBox
|
||||
from PyQt6.QtWidgets import (
|
||||
QPushButton,
|
||||
QToolBar,
|
||||
QTableWidget,
|
||||
QApplication,
|
||||
QMessageBox,
|
||||
)
|
||||
from _pytest.fixtures import SubRequest
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.__version__ import VERSION
|
||||
from buzz.cache import TasksCache
|
||||
from buzz.gui import (AudioDevicesComboBox, MainWindow,
|
||||
RecordingTranscriberWidget)
|
||||
from buzz.gui import AudioDevicesComboBox, MainWindow, RecordingTranscriberWidget
|
||||
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
|
||||
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
|
||||
HuggingFaceSearchLineEdit
|
||||
from buzz.widgets.transcriber.hugging_face_search_line_edit import (
|
||||
HuggingFaceSearchLineEdit,
|
||||
)
|
||||
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
|
||||
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
|
||||
from buzz.widgets.about_dialog import AboutDialog
|
||||
from buzz.model_loader import ModelType
|
||||
from buzz.settings.settings import Settings
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
|
||||
TranscriptionOptions)
|
||||
from buzz.widgets.transcriber.transcription_options_group_box import \
|
||||
TranscriptionOptionsGroupBox
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionOptions,
|
||||
FileTranscriptionTask,
|
||||
TranscriptionOptions,
|
||||
)
|
||||
from buzz.widgets.transcriber.transcription_options_group_box import (
|
||||
TranscriptionOptionsGroupBox,
|
||||
)
|
||||
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
||||
from tests.mock_sounddevice import MockInputStream, mock_query_devices
|
||||
from .mock_qt import MockNetworkAccessManager, MockNetworkReply
|
||||
|
||||
if platform.system() == 'Linux':
|
||||
multiprocessing.set_start_method('spawn')
|
||||
if platform.system() == "Linux":
|
||||
multiprocessing.set_start_method("spawn")
|
||||
|
||||
|
||||
@pytest.fixture(scope='module', autouse=True)
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def audio_setup():
|
||||
with patch('sounddevice.query_devices') as query_devices_mock, \
|
||||
patch('sounddevice.InputStream', side_effect=MockInputStream), \
|
||||
patch('sounddevice.check_input_settings'):
|
||||
with patch("sounddevice.query_devices") as query_devices_mock, patch(
|
||||
"sounddevice.InputStream", side_effect=MockInputStream
|
||||
), patch("sounddevice.check_input_settings"):
|
||||
query_devices_mock.return_value = mock_query_devices
|
||||
sounddevice.default.device = 3, 4
|
||||
yield
|
||||
|
||||
|
||||
class TestLanguagesComboBox:
|
||||
|
||||
def test_should_show_sorted_whisper_languages(self, qtbot):
|
||||
languages_combox_box = LanguagesComboBox('en')
|
||||
languages_combox_box = LanguagesComboBox("en")
|
||||
qtbot.add_widget(languages_combox_box)
|
||||
assert languages_combox_box.itemText(0) == 'Detect Language'
|
||||
assert languages_combox_box.itemText(10) == 'Belarusian'
|
||||
assert languages_combox_box.itemText(20) == 'Dutch'
|
||||
assert languages_combox_box.itemText(30) == 'Gujarati'
|
||||
assert languages_combox_box.itemText(40) == 'Japanese'
|
||||
assert languages_combox_box.itemText(50) == 'Lithuanian'
|
||||
assert languages_combox_box.itemText(0) == "Detect Language"
|
||||
assert languages_combox_box.itemText(10) == "Belarusian"
|
||||
assert languages_combox_box.itemText(20) == "Dutch"
|
||||
assert languages_combox_box.itemText(30) == "Gujarati"
|
||||
assert languages_combox_box.itemText(40) == "Japanese"
|
||||
assert languages_combox_box.itemText(50) == "Lithuanian"
|
||||
|
||||
def test_should_select_en_as_default_language(self, qtbot):
|
||||
languages_combox_box = LanguagesComboBox('en')
|
||||
languages_combox_box = LanguagesComboBox("en")
|
||||
qtbot.add_widget(languages_combox_box)
|
||||
assert languages_combox_box.currentText() == 'English'
|
||||
assert languages_combox_box.currentText() == "English"
|
||||
|
||||
def test_should_select_detect_language_as_default(self, qtbot):
|
||||
languages_combo_box = LanguagesComboBox(None)
|
||||
qtbot.add_widget(languages_combo_box)
|
||||
assert languages_combo_box.currentText() == 'Detect Language'
|
||||
assert languages_combo_box.currentText() == "Detect Language"
|
||||
|
||||
|
||||
class TestAudioDevicesComboBox:
|
||||
def test_get_devices(self):
|
||||
audio_devices_combo_box = AudioDevicesComboBox()
|
||||
|
||||
assert audio_devices_combo_box.itemText(0) == 'Background Music'
|
||||
assert audio_devices_combo_box.itemText(1) == 'Background Music (UI Sounds)'
|
||||
assert audio_devices_combo_box.itemText(2) == 'BlackHole 2ch'
|
||||
assert audio_devices_combo_box.itemText(3) == 'MacBook Pro Microphone'
|
||||
assert audio_devices_combo_box.itemText(4) == 'Null Audio Device'
|
||||
assert audio_devices_combo_box.itemText(0) == "Background Music"
|
||||
assert audio_devices_combo_box.itemText(1) == "Background Music (UI Sounds)"
|
||||
assert audio_devices_combo_box.itemText(2) == "BlackHole 2ch"
|
||||
assert audio_devices_combo_box.itemText(3) == "MacBook Pro Microphone"
|
||||
assert audio_devices_combo_box.itemText(4) == "Null Audio Device"
|
||||
|
||||
assert audio_devices_combo_box.currentText() == 'MacBook Pro Microphone'
|
||||
assert audio_devices_combo_box.currentText() == "MacBook Pro Microphone"
|
||||
|
||||
def test_select_default_mic_when_no_default(self):
|
||||
sounddevice.default.device = -1, 1
|
||||
|
||||
audio_devices_combo_box = AudioDevicesComboBox()
|
||||
assert audio_devices_combo_box.currentText() == 'Background Music'
|
||||
assert audio_devices_combo_box.currentText() == "Background Music"
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def tasks_cache(tmp_path, request: SubRequest):
|
||||
cache = TasksCache(cache_dir=str(tmp_path))
|
||||
if hasattr(request, 'param'):
|
||||
if hasattr(request, "param"):
|
||||
tasks: List[FileTranscriptionTask] = request.param
|
||||
cache.save(tasks)
|
||||
yield cache
|
||||
|
@ -100,28 +109,40 @@ def tasks_cache(tmp_path, request: SubRequest):
|
|||
|
||||
|
||||
def get_test_asset(filename: str):
|
||||
return os.path.join(os.path.dirname(__file__), '../testdata/', filename)
|
||||
return os.path.join(os.path.dirname(__file__), "../testdata/", filename)
|
||||
|
||||
|
||||
mock_tasks = [
|
||||
FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='',
|
||||
status=FileTranscriptionTask.Status.COMPLETED),
|
||||
FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='',
|
||||
status=FileTranscriptionTask.Status.CANCELED),
|
||||
FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='',
|
||||
status=FileTranscriptionTask.Status.FAILED, error='Error'),
|
||||
FileTranscriptionTask(
|
||||
file_path="",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.COMPLETED,
|
||||
),
|
||||
FileTranscriptionTask(
|
||||
file_path="",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.CANCELED,
|
||||
),
|
||||
FileTranscriptionTask(
|
||||
file_path="",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.FAILED,
|
||||
error="Error",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class TestMainWindow:
|
||||
|
||||
def test_should_set_window_title_and_icon(self, qtbot):
|
||||
window = MainWindow()
|
||||
qtbot.add_widget(window)
|
||||
assert window.windowTitle() == 'Buzz'
|
||||
assert window.windowTitle() == "Buzz"
|
||||
assert window.windowIcon().pixmap(QSize(64, 64)).isNull() is False
|
||||
window.close()
|
||||
|
||||
|
@ -132,13 +153,18 @@ class TestMainWindow:
|
|||
|
||||
self._start_new_transcription(window)
|
||||
|
||||
open_transcript_action = self._get_toolbar_action(window, 'Open Transcript')
|
||||
open_transcript_action = self._get_toolbar_action(window, "Open Transcript")
|
||||
assert open_transcript_action.isEnabled() is False
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
qtbot.wait_until(self._assert_task_status(table_widget, 0, 'Completed'), timeout=2 * 60 * 1000)
|
||||
qtbot.wait_until(
|
||||
self._assert_task_status(table_widget, 0, "Completed"),
|
||||
timeout=2 * 60 * 1000,
|
||||
)
|
||||
|
||||
table_widget.setCurrentIndex(table_widget.indexFromItem(table_widget.item(0, 1)))
|
||||
table_widget.setCurrentIndex(
|
||||
table_widget.indexFromItem(table_widget.item(0, 1))
|
||||
)
|
||||
assert open_transcript_action.isEnabled()
|
||||
|
||||
# @pytest.mark.skip(reason='Timing out or crashing')
|
||||
|
@ -152,8 +178,8 @@ class TestMainWindow:
|
|||
|
||||
def assert_task_in_progress():
|
||||
assert table_widget.rowCount() > 0
|
||||
assert table_widget.item(0, 1).text() == 'whisper-french.mp3'
|
||||
assert 'In Progress' in table_widget.item(0, 2).text()
|
||||
assert table_widget.item(0, 1).text() == "whisper-french.mp3"
|
||||
assert "In Progress" in table_widget.item(0, 2).text()
|
||||
|
||||
qtbot.wait_until(assert_task_in_progress, timeout=2 * 60 * 1000)
|
||||
|
||||
|
@ -161,7 +187,9 @@ class TestMainWindow:
|
|||
table_widget.selectRow(0)
|
||||
window.toolbar.stop_transcription_action.trigger()
|
||||
|
||||
qtbot.wait_until(self._assert_task_status(table_widget, 0, 'Canceled'), timeout=60 * 1000)
|
||||
qtbot.wait_until(
|
||||
self._assert_task_status(table_widget, 0, "Canceled"), timeout=60 * 1000
|
||||
)
|
||||
|
||||
table_widget.selectRow(0)
|
||||
assert window.toolbar.stop_transcription_action.isEnabled() is False
|
||||
|
@ -169,7 +197,7 @@ class TestMainWindow:
|
|||
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_load_tasks_from_cache(self, qtbot, tasks_cache):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
qtbot.add_widget(window)
|
||||
|
@ -177,43 +205,47 @@ class TestMainWindow:
|
|||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
assert table_widget.rowCount() == 3
|
||||
|
||||
assert table_widget.item(0, 2).text() == 'Completed'
|
||||
assert table_widget.item(0, 2).text() == "Completed"
|
||||
table_widget.selectRow(0)
|
||||
assert window.toolbar.open_transcript_action.isEnabled()
|
||||
|
||||
assert table_widget.item(1, 2).text() == 'Canceled'
|
||||
assert table_widget.item(1, 2).text() == "Canceled"
|
||||
table_widget.selectRow(1)
|
||||
assert window.toolbar.open_transcript_action.isEnabled() is False
|
||||
|
||||
assert table_widget.item(2, 2).text() == 'Failed (Error)'
|
||||
assert table_widget.item(2, 2).text() == "Failed (Error)"
|
||||
table_widget.selectRow(2)
|
||||
assert window.toolbar.open_transcript_action.isEnabled() is False
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_clear_history_with_rows_selected(self, qtbot, tasks_cache):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
table_widget.selectAll()
|
||||
|
||||
with patch('PyQt6.QtWidgets.QMessageBox.question') as question_message_box_mock:
|
||||
with patch("PyQt6.QtWidgets.QMessageBox.question") as question_message_box_mock:
|
||||
question_message_box_mock.return_value = QMessageBox.StandardButton.Yes
|
||||
window.toolbar.clear_history_action.trigger()
|
||||
|
||||
assert table_widget.rowCount() == 0
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
|
||||
def test_should_have_clear_history_action_disabled_with_no_rows_selected(self, qtbot, tasks_cache):
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_have_clear_history_action_disabled_with_no_rows_selected(
|
||||
self, qtbot, tasks_cache
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
assert window.toolbar.clear_history_action.isEnabled() is False
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
|
||||
def test_should_open_transcription_viewer_when_menu_action_is_clicked(self, qtbot, tasks_cache):
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_open_transcription_viewer_when_menu_action_is_clicked(
|
||||
self, qtbot, tasks_cache
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
|
@ -228,23 +260,33 @@ class TestMainWindow:
|
|||
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
|
||||
def test_should_open_transcription_viewer_when_return_clicked(self, qtbot, tasks_cache):
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_open_transcription_viewer_when_return_clicked(
|
||||
self, qtbot, tasks_cache
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
table_widget: QTableWidget = window.findChild(QTableWidget)
|
||||
table_widget.selectRow(0)
|
||||
table_widget.keyPressEvent(
|
||||
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Return, Qt.KeyboardModifier.NoModifier, '\r'))
|
||||
QKeyEvent(
|
||||
QKeyEvent.Type.KeyPress,
|
||||
Qt.Key.Key_Return,
|
||||
Qt.KeyboardModifier.NoModifier,
|
||||
"\r",
|
||||
)
|
||||
)
|
||||
|
||||
transcription_viewer = window.findChild(TranscriptionViewerWidget)
|
||||
assert transcription_viewer is not None
|
||||
|
||||
window.close()
|
||||
|
||||
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
|
||||
def test_should_have_open_transcript_action_disabled_with_no_rows_selected(self, qtbot, tasks_cache):
|
||||
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
|
||||
def test_should_have_open_transcript_action_disabled_with_no_rows_selected(
|
||||
self, qtbot, tasks_cache
|
||||
):
|
||||
window = MainWindow(tasks_cache=tasks_cache)
|
||||
qtbot.add_widget(window)
|
||||
|
||||
|
@ -253,20 +295,31 @@ class TestMainWindow:
|
|||
|
||||
@staticmethod
|
||||
def _start_new_transcription(window: MainWindow):
|
||||
with patch('PyQt6.QtWidgets.QFileDialog.getOpenFileNames') as open_file_names_mock:
|
||||
open_file_names_mock.return_value = ([get_test_asset('whisper-french.mp3')], '')
|
||||
new_transcription_action = TestMainWindow._get_toolbar_action(window, 'New Transcription')
|
||||
with patch(
|
||||
"PyQt6.QtWidgets.QFileDialog.getOpenFileNames"
|
||||
) as open_file_names_mock:
|
||||
open_file_names_mock.return_value = (
|
||||
[get_test_asset("whisper-french.mp3")],
|
||||
"",
|
||||
)
|
||||
new_transcription_action = TestMainWindow._get_toolbar_action(
|
||||
window, "New Transcription"
|
||||
)
|
||||
new_transcription_action.trigger()
|
||||
|
||||
file_transcriber_widget: FileTranscriberWidget = window.findChild(FileTranscriberWidget)
|
||||
file_transcriber_widget: FileTranscriberWidget = window.findChild(
|
||||
FileTranscriberWidget
|
||||
)
|
||||
run_button: QPushButton = file_transcriber_widget.findChild(QPushButton)
|
||||
run_button.click()
|
||||
|
||||
@staticmethod
|
||||
def _assert_task_status(table_widget: QTableWidget, row_index: int, expected_status: str):
|
||||
def _assert_task_status(
|
||||
table_widget: QTableWidget, row_index: int, expected_status: str
|
||||
):
|
||||
def assert_task_canceled():
|
||||
assert table_widget.rowCount() > 0
|
||||
assert table_widget.item(row_index, 1).text() == 'whisper-french.mp3'
|
||||
assert table_widget.item(row_index, 1).text() == "whisper-french.mp3"
|
||||
assert expected_status in table_widget.item(row_index, 2).text()
|
||||
|
||||
return assert_task_canceled
|
||||
|
@ -277,7 +330,7 @@ class TestMainWindow:
|
|||
return [action for action in toolbar.actions() if action.text() == text][0]
|
||||
|
||||
|
||||
@pytest.fixture(scope='module', autouse=True)
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def clear_settings():
|
||||
settings = Settings()
|
||||
settings.clear()
|
||||
|
@ -285,7 +338,7 @@ def clear_settings():
|
|||
|
||||
class TestAboutDialog:
|
||||
def test_should_check_for_updates(self, qtbot: QtBot):
|
||||
reply = MockNetworkReply(data={'name': 'v' + VERSION})
|
||||
reply = MockNetworkReply(data={"name": "v" + VERSION})
|
||||
manager = MockNetworkAccessManager(reply=reply)
|
||||
dialog = AboutDialog(network_access_manager=manager)
|
||||
qtbot.add_widget(dialog)
|
||||
|
@ -296,41 +349,45 @@ class TestAboutDialog:
|
|||
with qtbot.wait_signal(dialog.network_access_manager.finished):
|
||||
dialog.check_updates_button.click()
|
||||
|
||||
mock_message_box_information.assert_called_with(dialog, '', "You're up to date!")
|
||||
mock_message_box_information.assert_called_with(
|
||||
dialog, "", "You're up to date!"
|
||||
)
|
||||
|
||||
|
||||
class TestAdvancedSettingsDialog:
|
||||
def test_should_update_advanced_settings(self, qtbot: QtBot):
|
||||
dialog = AdvancedSettingsDialog(
|
||||
transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt'))
|
||||
transcription_options=TranscriptionOptions(
|
||||
temperature=(0.0, 0.8), initial_prompt="prompt"
|
||||
)
|
||||
)
|
||||
qtbot.add_widget(dialog)
|
||||
|
||||
transcription_options_mock = Mock()
|
||||
dialog.transcription_options_changed.connect(
|
||||
transcription_options_mock)
|
||||
dialog.transcription_options_changed.connect(transcription_options_mock)
|
||||
|
||||
assert dialog.windowTitle() == 'Advanced Settings'
|
||||
assert dialog.temperature_line_edit.text() == '0.0, 0.8'
|
||||
assert dialog.initial_prompt_text_edit.toPlainText() == 'prompt'
|
||||
assert dialog.windowTitle() == "Advanced Settings"
|
||||
assert dialog.temperature_line_edit.text() == "0.0, 0.8"
|
||||
assert dialog.initial_prompt_text_edit.toPlainText() == "prompt"
|
||||
|
||||
dialog.temperature_line_edit.setText('0.0, 0.8, 1.0')
|
||||
dialog.initial_prompt_text_edit.setPlainText('new prompt')
|
||||
dialog.temperature_line_edit.setText("0.0, 0.8, 1.0")
|
||||
dialog.initial_prompt_text_edit.setPlainText("new prompt")
|
||||
|
||||
assert transcription_options_mock.call_args[0][0].temperature == (
|
||||
0.0, 0.8, 1.0)
|
||||
assert transcription_options_mock.call_args[0][0].initial_prompt == 'new prompt'
|
||||
assert transcription_options_mock.call_args[0][0].temperature == (0.0, 0.8, 1.0)
|
||||
assert transcription_options_mock.call_args[0][0].initial_prompt == "new prompt"
|
||||
|
||||
|
||||
class TestTemperatureValidator:
|
||||
validator = TemperatureValidator(None)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'text,state',
|
||||
"text,state",
|
||||
[
|
||||
('0.0,0.5,1.0', QValidator.State.Acceptable),
|
||||
('0.0,0.5,', QValidator.State.Intermediate),
|
||||
('0.0,0.5,p', QValidator.State.Invalid),
|
||||
])
|
||||
("0.0,0.5,1.0", QValidator.State.Acceptable),
|
||||
("0.0,0.5,", QValidator.State.Intermediate),
|
||||
("0.0,0.5,p", QValidator.State.Invalid),
|
||||
],
|
||||
)
|
||||
def test_should_validate_temperature(self, text: str, state: QValidator.State):
|
||||
assert self.validator.validate(text, 0)[0] == state
|
||||
|
||||
|
@ -339,9 +396,9 @@ class TestRecordingTranscriberWidget:
|
|||
def test_should_set_window_title(self, qtbot: QtBot):
|
||||
widget = RecordingTranscriberWidget()
|
||||
qtbot.add_widget(widget)
|
||||
assert widget.windowTitle() == 'Live Recording'
|
||||
assert widget.windowTitle() == "Live Recording"
|
||||
|
||||
@pytest.mark.skip(reason='Seg faults on CI')
|
||||
@pytest.mark.skip(reason="Seg faults on CI")
|
||||
def test_should_transcribe(self, qtbot):
|
||||
widget = RecordingTranscriberWidget()
|
||||
qtbot.add_widget(widget)
|
||||
|
@ -355,31 +412,37 @@ class TestRecordingTranscriberWidget:
|
|||
with qtbot.wait_signal(widget.transcription_thread.finished, timeout=60 * 1000):
|
||||
widget.stop_recording()
|
||||
|
||||
assert 'Welcome to Passe' in widget.text_box.toPlainText()
|
||||
assert "Welcome to Passe" in widget.text_box.toPlainText()
|
||||
|
||||
|
||||
class TestHuggingFaceSearchLineEdit:
|
||||
def test_should_update_selected_model_on_type(self, qtbot: QtBot):
|
||||
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
|
||||
widget = HuggingFaceSearchLineEdit(
|
||||
network_access_manager=self.network_access_manager()
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
mock_model_selected = Mock()
|
||||
widget.model_selected.connect(mock_model_selected)
|
||||
|
||||
self._set_text_and_wait_response(qtbot, widget)
|
||||
mock_model_selected.assert_called_with('openai/whisper-tiny')
|
||||
mock_model_selected.assert_called_with("openai/whisper-tiny")
|
||||
|
||||
def test_should_show_list_of_models(self, qtbot: QtBot):
|
||||
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
|
||||
widget = HuggingFaceSearchLineEdit(
|
||||
network_access_manager=self.network_access_manager()
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
self._set_text_and_wait_response(qtbot, widget)
|
||||
|
||||
assert widget.popup.count() > 0
|
||||
assert 'openai/whisper-tiny' in widget.popup.item(0).text()
|
||||
assert "openai/whisper-tiny" in widget.popup.item(0).text()
|
||||
|
||||
def test_should_select_model_from_list(self, qtbot: QtBot):
|
||||
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
|
||||
widget = HuggingFaceSearchLineEdit(
|
||||
network_access_manager=self.network_access_manager()
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
mock_model_selected = Mock()
|
||||
|
@ -388,23 +451,35 @@ class TestHuggingFaceSearchLineEdit:
|
|||
self._set_text_and_wait_response(qtbot, widget)
|
||||
|
||||
# press down arrow and enter to select next item
|
||||
QApplication.sendEvent(widget.popup,
|
||||
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier))
|
||||
QApplication.sendEvent(widget.popup,
|
||||
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Enter, Qt.KeyboardModifier.NoModifier))
|
||||
QApplication.sendEvent(
|
||||
widget.popup,
|
||||
QKeyEvent(
|
||||
QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier
|
||||
),
|
||||
)
|
||||
QApplication.sendEvent(
|
||||
widget.popup,
|
||||
QKeyEvent(
|
||||
QKeyEvent.Type.KeyPress,
|
||||
Qt.Key.Key_Enter,
|
||||
Qt.KeyboardModifier.NoModifier,
|
||||
),
|
||||
)
|
||||
|
||||
mock_model_selected.assert_called_with('openai/whisper-tiny.en')
|
||||
mock_model_selected.assert_called_with("openai/whisper-tiny.en")
|
||||
|
||||
@staticmethod
|
||||
def network_access_manager():
|
||||
reply = MockNetworkReply(data=[{'id': 'openai/whisper-tiny'}, {'id': 'openai/whisper-tiny.en'}])
|
||||
reply = MockNetworkReply(
|
||||
data=[{"id": "openai/whisper-tiny"}, {"id": "openai/whisper-tiny.en"}]
|
||||
)
|
||||
return MockNetworkAccessManager(reply=reply)
|
||||
|
||||
@staticmethod
|
||||
def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit):
|
||||
with qtbot.wait_signal(widget.network_manager.finished):
|
||||
widget.setText('openai/whisper-tiny')
|
||||
widget.textEdited.emit('openai/whisper-tiny')
|
||||
widget.setText("openai/whisper-tiny")
|
||||
widget.textEdited.emit("openai/whisper-tiny")
|
||||
|
||||
|
||||
class TestTranscriptionOptionsGroupBox:
|
||||
|
@ -417,5 +492,7 @@ class TestTranscriptionOptionsGroupBox:
|
|||
|
||||
widget.model_type_combo_box.setCurrentIndex(1)
|
||||
|
||||
transcription_options: TranscriptionOptions = mock_transcription_options_changed.call_args[0][0]
|
||||
transcription_options: TranscriptionOptions = (
|
||||
mock_transcription_options_changed.call_args[0][0]
|
||||
)
|
||||
assert transcription_options.model.model_type == ModelType.WHISPER_CPP
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
|
@ -10,10 +9,10 @@ class MockNetworkReply(QNetworkReply):
|
|||
def __init__(self, data: object, _: Optional[QObject] = None) -> None:
|
||||
self.data = data
|
||||
|
||||
def readAll(self) -> 'QByteArray':
|
||||
return QByteArray(json.dumps(self.data).encode('utf-8'))
|
||||
def readAll(self) -> "QByteArray":
|
||||
return QByteArray(json.dumps(self.data).encode("utf-8"))
|
||||
|
||||
def error(self) -> 'QNetworkReply.NetworkError':
|
||||
def error(self) -> "QNetworkReply.NetworkError":
|
||||
return QNetworkReply.NetworkError.NoError
|
||||
|
||||
|
||||
|
@ -21,10 +20,12 @@ class MockNetworkAccessManager(QNetworkAccessManager):
|
|||
finished = pyqtSignal(object)
|
||||
reply: MockNetworkReply
|
||||
|
||||
def __init__(self, reply: MockNetworkReply, parent: Optional[QObject] = None) -> None:
|
||||
def __init__(
|
||||
self, reply: MockNetworkReply, parent: Optional[QObject] = None
|
||||
) -> None:
|
||||
super().__init__(parent)
|
||||
self.reply = reply
|
||||
|
||||
def get(self, _: 'QNetworkRequest') -> 'QNetworkReply':
|
||||
def get(self, _: "QNetworkRequest") -> "QNetworkReply":
|
||||
self.finished.emit(self.reply)
|
||||
return self.reply
|
||||
|
|
|
@ -9,34 +9,90 @@ import sounddevice
|
|||
import whisper
|
||||
|
||||
mock_query_devices = [
|
||||
{'name': 'Background Music', 'index': 0, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
|
||||
'default_low_input_latency': 0.01,
|
||||
'default_low_output_latency': 0.008, 'default_high_input_latency': 0.1, 'default_high_output_latency': 0.064,
|
||||
'default_samplerate': 8000.0},
|
||||
{'name': 'Background Music (UI Sounds)', 'index': 1, 'hostapi': 0, 'max_input_channels': 2,
|
||||
'max_output_channels': 2, 'default_low_input_latency': 0.01,
|
||||
'default_low_output_latency': 0.008, 'default_high_input_latency': 0.1, 'default_high_output_latency': 0.064,
|
||||
'default_samplerate': 8000.0},
|
||||
{'name': 'BlackHole 2ch', 'index': 2, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
|
||||
'default_low_input_latency': 0.01,
|
||||
'default_low_output_latency': 0.0013333333333333333, 'default_high_input_latency': 0.1,
|
||||
'default_high_output_latency': 0.010666666666666666, 'default_samplerate': 48000.0},
|
||||
{'name': 'MacBook Pro Microphone', 'index': 3, 'hostapi': 0, 'max_input_channels': 1, 'max_output_channels': 0,
|
||||
'default_low_input_latency': 0.034520833333333334,
|
||||
'default_low_output_latency': 0.01, 'default_high_input_latency': 0.043854166666666666,
|
||||
'default_high_output_latency': 0.1, 'default_samplerate': 48000.0},
|
||||
{'name': 'MacBook Pro Speakers', 'index': 4, 'hostapi': 0, 'max_input_channels': 0, 'max_output_channels': 2,
|
||||
'default_low_input_latency': 0.01,
|
||||
'default_low_output_latency': 0.0070416666666666666, 'default_high_input_latency': 0.1,
|
||||
'default_high_output_latency': 0.016375, 'default_samplerate': 48000.0},
|
||||
{'name': 'Null Audio Device', 'index': 5, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
|
||||
'default_low_input_latency': 0.01,
|
||||
'default_low_output_latency': 0.0014512471655328798, 'default_high_input_latency': 0.1,
|
||||
'default_high_output_latency': 0.011609977324263039, 'default_samplerate': 44100.0},
|
||||
{'name': 'Multi-Output Device', 'index': 6, 'hostapi': 0, 'max_input_channels': 0, 'max_output_channels': 2,
|
||||
'default_low_input_latency': 0.01,
|
||||
'default_low_output_latency': 0.0033333333333333335, 'default_high_input_latency': 0.1,
|
||||
'default_high_output_latency': 0.012666666666666666, 'default_samplerate': 48000.0},
|
||||
{
|
||||
"name": "Background Music",
|
||||
"index": 0,
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 2,
|
||||
"max_output_channels": 2,
|
||||
"default_low_input_latency": 0.01,
|
||||
"default_low_output_latency": 0.008,
|
||||
"default_high_input_latency": 0.1,
|
||||
"default_high_output_latency": 0.064,
|
||||
"default_samplerate": 8000.0,
|
||||
},
|
||||
{
|
||||
"name": "Background Music (UI Sounds)",
|
||||
"index": 1,
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 2,
|
||||
"max_output_channels": 2,
|
||||
"default_low_input_latency": 0.01,
|
||||
"default_low_output_latency": 0.008,
|
||||
"default_high_input_latency": 0.1,
|
||||
"default_high_output_latency": 0.064,
|
||||
"default_samplerate": 8000.0,
|
||||
},
|
||||
{
|
||||
"name": "BlackHole 2ch",
|
||||
"index": 2,
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 2,
|
||||
"max_output_channels": 2,
|
||||
"default_low_input_latency": 0.01,
|
||||
"default_low_output_latency": 0.0013333333333333333,
|
||||
"default_high_input_latency": 0.1,
|
||||
"default_high_output_latency": 0.010666666666666666,
|
||||
"default_samplerate": 48000.0,
|
||||
},
|
||||
{
|
||||
"name": "MacBook Pro Microphone",
|
||||
"index": 3,
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 1,
|
||||
"max_output_channels": 0,
|
||||
"default_low_input_latency": 0.034520833333333334,
|
||||
"default_low_output_latency": 0.01,
|
||||
"default_high_input_latency": 0.043854166666666666,
|
||||
"default_high_output_latency": 0.1,
|
||||
"default_samplerate": 48000.0,
|
||||
},
|
||||
{
|
||||
"name": "MacBook Pro Speakers",
|
||||
"index": 4,
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 0,
|
||||
"max_output_channels": 2,
|
||||
"default_low_input_latency": 0.01,
|
||||
"default_low_output_latency": 0.0070416666666666666,
|
||||
"default_high_input_latency": 0.1,
|
||||
"default_high_output_latency": 0.016375,
|
||||
"default_samplerate": 48000.0,
|
||||
},
|
||||
{
|
||||
"name": "Null Audio Device",
|
||||
"index": 5,
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 2,
|
||||
"max_output_channels": 2,
|
||||
"default_low_input_latency": 0.01,
|
||||
"default_low_output_latency": 0.0014512471655328798,
|
||||
"default_high_input_latency": 0.1,
|
||||
"default_high_output_latency": 0.011609977324263039,
|
||||
"default_samplerate": 44100.0,
|
||||
},
|
||||
{
|
||||
"name": "Multi-Output Device",
|
||||
"index": 6,
|
||||
"hostapi": 0,
|
||||
"max_input_channels": 0,
|
||||
"max_output_channels": 2,
|
||||
"default_low_input_latency": 0.01,
|
||||
"default_low_output_latency": 0.0033333333333333335,
|
||||
"default_high_input_latency": 0.1,
|
||||
"default_high_output_latency": 0.012666666666666666,
|
||||
"default_samplerate": 48000.0,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
|
@ -44,7 +100,12 @@ class MockInputStream(MagicMock):
|
|||
running = False
|
||||
thread: Thread
|
||||
|
||||
def __init__(self, callback: Callable[[np.ndarray, int, Any, sounddevice.CallbackFlags], None], *args, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
callback: Callable[[np.ndarray, int, Any, sounddevice.CallbackFlags], None],
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(spec=sounddevice.InputStream)
|
||||
self.thread = Thread(target=self.target)
|
||||
self.callback = callback
|
||||
|
@ -54,7 +115,9 @@ class MockInputStream(MagicMock):
|
|||
|
||||
def target(self):
|
||||
sample_rate = whisper.audio.SAMPLE_RATE
|
||||
file_path = os.path.join(os.path.dirname(__file__), '../testdata/whisper-french.mp3')
|
||||
file_path = os.path.join(
|
||||
os.path.dirname(__file__), "../testdata/whisper-french.mp3"
|
||||
)
|
||||
audio = whisper.load_audio(file_path, sr=sample_rate)
|
||||
|
||||
chunk_duration_secs = 1
|
||||
|
@ -65,7 +128,7 @@ class MockInputStream(MagicMock):
|
|||
|
||||
while self.running:
|
||||
time.sleep(chunk_duration_secs)
|
||||
chunk = audio[seek:seek + num_samples_in_chunk]
|
||||
chunk = audio[seek : seek + num_samples_in_chunk]
|
||||
self.callback(chunk, 0, None, sounddevice.CallbackFlags())
|
||||
seek += num_samples_in_chunk
|
||||
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
from buzz.model_loader import TranscriptionModel, ModelDownloader
|
||||
|
||||
|
||||
|
||||
def get_model_path(transcription_model: TranscriptionModel) -> str:
|
||||
path = transcription_model.get_local_model_path()
|
||||
if path is not None:
|
||||
return path
|
||||
|
||||
model_loader = ModelDownloader(model=transcription_model)
|
||||
model_path = ''
|
||||
model_path = ""
|
||||
|
||||
def on_load_model(path: str):
|
||||
nonlocal model_path
|
||||
|
|
|
@ -4,20 +4,32 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
|
||||
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, Task, WhisperCppFileTranscriber,
|
||||
TranscriptionOptions, WhisperFileTranscriber, FileTranscriber)
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionOptions,
|
||||
FileTranscriptionTask,
|
||||
Task,
|
||||
WhisperCppFileTranscriber,
|
||||
TranscriptionOptions,
|
||||
WhisperFileTranscriber,
|
||||
FileTranscriber,
|
||||
)
|
||||
from tests.model_loader import get_model_path
|
||||
|
||||
|
||||
def get_task(model: TranscriptionModel):
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3'])
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
|
||||
word_level_timings=False,
|
||||
model=model)
|
||||
file_paths=["testdata/whisper-french.mp3"]
|
||||
)
|
||||
transcription_options = TranscriptionOptions(
|
||||
language="fr", task=Task.TRANSCRIBE, word_level_timings=False, model=model
|
||||
)
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
return FileTranscriptionTask(file_path='testdata/audio-long.mp3', transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options, model_path=model_path)
|
||||
return FileTranscriptionTask(
|
||||
file_path="testdata/audio-long.mp3",
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
model_path=model_path,
|
||||
)
|
||||
|
||||
|
||||
def transcribe(qtbot, transcriber: FileTranscriber):
|
||||
|
@ -31,24 +43,53 @@ def transcribe(qtbot, transcriber: FileTranscriber):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'transcriber',
|
||||
"transcriber",
|
||||
[
|
||||
pytest.param(
|
||||
WhisperCppFileTranscriber(task=(get_task(
|
||||
TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY)))),
|
||||
id="Whisper.cpp - Tiny"),
|
||||
pytest.param(
|
||||
WhisperFileTranscriber(task=(get_task(
|
||||
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)))),
|
||||
id="Whisper - Tiny"),
|
||||
pytest.param(
|
||||
WhisperFileTranscriber(task=(get_task(
|
||||
TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY)))),
|
||||
id="Faster Whisper - Tiny",
|
||||
marks=pytest.mark.skipif(platform.system() == 'Darwin',
|
||||
reason='Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087')
|
||||
WhisperCppFileTranscriber(
|
||||
task=(
|
||||
get_task(
|
||||
TranscriptionModel(
|
||||
model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
id="Whisper.cpp - Tiny",
|
||||
),
|
||||
])
|
||||
pytest.param(
|
||||
WhisperFileTranscriber(
|
||||
task=(
|
||||
get_task(
|
||||
TranscriptionModel(
|
||||
model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
id="Whisper - Tiny",
|
||||
),
|
||||
pytest.param(
|
||||
WhisperFileTranscriber(
|
||||
task=(
|
||||
get_task(
|
||||
TranscriptionModel(
|
||||
model_type=ModelType.FASTER_WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
id="Faster Whisper - Tiny",
|
||||
marks=pytest.mark.skipif(
|
||||
platform.system() == "Darwin",
|
||||
reason="Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_should_transcribe_and_benchmark(qtbot, benchmark, transcriber):
|
||||
segments = benchmark(transcribe, qtbot, transcriber)
|
||||
assert len(segments) > 0
|
||||
|
|
|
@ -11,26 +11,42 @@ from PyQt6.QtCore import QThread
|
|||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
|
||||
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
|
||||
WhisperFileTranscriber,
|
||||
get_default_output_file_path, to_timestamp,
|
||||
whisper_cpp_params, write_output, TranscriptionOptions)
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionOptions,
|
||||
FileTranscriptionTask,
|
||||
OutputFormat,
|
||||
Segment,
|
||||
Task,
|
||||
WhisperCpp,
|
||||
WhisperCppFileTranscriber,
|
||||
WhisperFileTranscriber,
|
||||
get_default_output_file_path,
|
||||
to_timestamp,
|
||||
whisper_cpp_params,
|
||||
write_output,
|
||||
TranscriptionOptions,
|
||||
)
|
||||
from buzz.recording_transcriber import RecordingTranscriber
|
||||
from tests.mock_sounddevice import MockInputStream
|
||||
from tests.model_loader import get_model_path
|
||||
|
||||
|
||||
class TestRecordingTranscriber:
|
||||
@pytest.mark.skip(reason='Hanging')
|
||||
@pytest.mark.skip(reason="Hanging")
|
||||
def test_should_transcribe(self, qtbot):
|
||||
thread = QThread()
|
||||
|
||||
transcription_model = TranscriptionModel(model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY)
|
||||
transcription_model = TranscriptionModel(
|
||||
model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY
|
||||
)
|
||||
|
||||
transcriber = RecordingTranscriber(transcription_options=TranscriptionOptions(
|
||||
model=transcription_model, language='fr', task=Task.TRANSCRIBE),
|
||||
input_device_index=0, sample_rate=16_000)
|
||||
transcriber = RecordingTranscriber(
|
||||
transcription_options=TranscriptionOptions(
|
||||
model=transcription_model, language="fr", task=Task.TRANSCRIBE
|
||||
),
|
||||
input_device_index=0,
|
||||
sample_rate=16_000,
|
||||
)
|
||||
transcriber.moveToThread(thread)
|
||||
|
||||
thread.finished.connect(thread.deleteLater)
|
||||
|
@ -41,39 +57,55 @@ class TestRecordingTranscriber:
|
|||
transcriber.finished.connect(thread.quit)
|
||||
transcriber.finished.connect(transcriber.deleteLater)
|
||||
|
||||
with patch('sounddevice.InputStream', side_effect=MockInputStream), patch(
|
||||
'sounddevice.check_input_settings'), qtbot.wait_signal(transcriber.transcription, timeout=60 * 1000):
|
||||
with patch("sounddevice.InputStream", side_effect=MockInputStream), patch(
|
||||
"sounddevice.check_input_settings"
|
||||
), qtbot.wait_signal(transcriber.transcription, timeout=60 * 1000):
|
||||
thread.start()
|
||||
|
||||
with qtbot.wait_signal(thread.finished, timeout=60 * 1000):
|
||||
transcriber.stop_recording()
|
||||
|
||||
text = mock_transcription.call_args[0][0]
|
||||
assert 'Bienvenue dans Passe' in text
|
||||
assert "Bienvenue dans Passe" in text
|
||||
|
||||
|
||||
class TestWhisperCppFileTranscriber:
|
||||
@pytest.mark.parametrize(
|
||||
'word_level_timings,expected_segments',
|
||||
"word_level_timings,expected_segments",
|
||||
[
|
||||
(False, [Segment(0, 6560,
|
||||
'Bienvenue dans Passe-Relle. Un podcast pensé pour')]),
|
||||
(True, [Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')])
|
||||
])
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
|
||||
(
|
||||
False,
|
||||
[Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")],
|
||||
),
|
||||
(True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]),
|
||||
],
|
||||
)
|
||||
def test_transcribe(
|
||||
self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]
|
||||
):
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3'])
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings,
|
||||
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY))
|
||||
file_paths=["testdata/whisper-french.mp3"]
|
||||
)
|
||||
transcription_options = TranscriptionOptions(
|
||||
language="fr",
|
||||
task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings,
|
||||
model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
)
|
||||
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
transcriber = WhisperCppFileTranscriber(
|
||||
task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3',
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options, model_path=model_path))
|
||||
mock_progress = Mock(side_effect=lambda value: print('progress: ', value))
|
||||
task=FileTranscriptionTask(
|
||||
file_path="testdata/whisper-french.mp3",
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
model_path=model_path,
|
||||
)
|
||||
)
|
||||
mock_progress = Mock(side_effect=lambda value: print("progress: ", value))
|
||||
mock_completed = Mock()
|
||||
transcriber.progress.connect(mock_progress)
|
||||
transcriber.completed.connect(mock_completed)
|
||||
|
@ -81,7 +113,11 @@ class TestWhisperCppFileTranscriber:
|
|||
transcriber.run()
|
||||
|
||||
mock_progress.assert_called()
|
||||
segments = [segment for segment in mock_completed.call_args[0][0] if len(segment.text) > 0]
|
||||
segments = [
|
||||
segment
|
||||
for segment in mock_completed.call_args[0][0]
|
||||
if len(segment.text) > 0
|
||||
]
|
||||
for i, expected_segment in enumerate(expected_segments):
|
||||
assert expected_segment.start == segments[i].start
|
||||
assert expected_segment.end == segments[i].end
|
||||
|
@ -90,82 +126,164 @@ class TestWhisperCppFileTranscriber:
|
|||
|
||||
class TestWhisperFileTranscriber:
|
||||
@pytest.mark.parametrize(
|
||||
'output_format,expected_file_path,default_output_file_name',
|
||||
"output_format,expected_file_path,default_output_file_name",
|
||||
[
|
||||
(OutputFormat.SRT, '/a/b/c-translate--Whisper-tiny.srt', '{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}'),
|
||||
])
|
||||
def test_default_output_file2(self, output_format: OutputFormat, expected_file_path: str, default_output_file_name: str):
|
||||
(
|
||||
OutputFormat.SRT,
|
||||
"/a/b/c-translate--Whisper-tiny.srt",
|
||||
"{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_default_output_file2(
|
||||
self,
|
||||
output_format: OutputFormat,
|
||||
expected_file_path: str,
|
||||
default_output_file_name: str,
|
||||
):
|
||||
file_path = get_default_output_file_path(
|
||||
task=FileTranscriptionTask(
|
||||
file_path='/a/b/c.mp4',
|
||||
file_path="/a/b/c.mp4",
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name=default_output_file_name),
|
||||
model_path=''),
|
||||
output_format=output_format)
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=[], default_output_file_name=default_output_file_name
|
||||
),
|
||||
model_path="",
|
||||
),
|
||||
output_format=output_format,
|
||||
)
|
||||
assert file_path == expected_file_path
|
||||
|
||||
def test_default_output_file(self):
|
||||
srt = get_default_output_file_path(
|
||||
task=FileTranscriptionTask(
|
||||
file_path='/a/b/c.mp4',
|
||||
file_path="/a/b/c.mp4",
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
|
||||
model_path=''),
|
||||
output_format=OutputFormat.TXT)
|
||||
assert srt.startswith('/a/b/c (Translated on ')
|
||||
assert srt.endswith('.txt')
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=[],
|
||||
default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})",
|
||||
),
|
||||
model_path="",
|
||||
),
|
||||
output_format=OutputFormat.TXT,
|
||||
)
|
||||
assert srt.startswith("/a/b/c (Translated on ")
|
||||
assert srt.endswith(".txt")
|
||||
|
||||
srt = get_default_output_file_path(
|
||||
task=FileTranscriptionTask(
|
||||
file_path='/a/b/c.mp4',
|
||||
file_path="/a/b/c.mp4",
|
||||
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
|
||||
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
|
||||
model_path=''),
|
||||
output_format=OutputFormat.SRT)
|
||||
assert srt.startswith('/a/b/c (Translated on ')
|
||||
assert srt.endswith('.srt')
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=[],
|
||||
default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})",
|
||||
),
|
||||
model_path="",
|
||||
),
|
||||
output_format=OutputFormat.SRT,
|
||||
)
|
||||
assert srt.startswith("/a/b/c (Translated on ")
|
||||
assert srt.endswith(".srt")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'word_level_timings,expected_segments,model,check_progress',
|
||||
"word_level_timings,expected_segments,model,check_progress",
|
||||
[
|
||||
(False, [Segment(0, 6560,
|
||||
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances')],
|
||||
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
|
||||
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')],
|
||||
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
|
||||
(False, [Segment(0, 8517,
|
||||
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances '
|
||||
'et des apprenances de français.')],
|
||||
TranscriptionModel(model_type=ModelType.HUGGING_FACE,
|
||||
hugging_face_model_id='openai/whisper-tiny'), False),
|
||||
(
|
||||
False,
|
||||
[
|
||||
Segment(
|
||||
0,
|
||||
6560,
|
||||
" Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances",
|
||||
)
|
||||
],
|
||||
TranscriptionModel(
|
||||
model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
True,
|
||||
[Segment(40, 299, " Bien"), Segment(299, 329, "venue dans")],
|
||||
TranscriptionModel(
|
||||
model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
True,
|
||||
),
|
||||
(
|
||||
False,
|
||||
[
|
||||
Segment(
|
||||
0,
|
||||
8517,
|
||||
" Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances "
|
||||
"et des apprenances de français.",
|
||||
)
|
||||
],
|
||||
TranscriptionModel(
|
||||
model_type=ModelType.HUGGING_FACE,
|
||||
hugging_face_model_id="openai/whisper-tiny",
|
||||
),
|
||||
False,
|
||||
),
|
||||
pytest.param(
|
||||
False, [Segment(start=0, end=8400,
|
||||
text=' Bienvenue dans Passrel, un podcast pensé pour éveiller la curiosité des apprenances et des apprenances de français.')],
|
||||
TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY), True,
|
||||
marks=pytest.mark.skipif(platform.system() == 'Darwin',
|
||||
reason='Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087')
|
||||
)
|
||||
])
|
||||
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment],
|
||||
model: TranscriptionModel, check_progress):
|
||||
False,
|
||||
[
|
||||
Segment(
|
||||
start=0,
|
||||
end=8400,
|
||||
text=" Bienvenue dans Passrel, un podcast pensé pour éveiller la curiosité des apprenances et des apprenances de français.",
|
||||
)
|
||||
],
|
||||
TranscriptionModel(
|
||||
model_type=ModelType.FASTER_WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
True,
|
||||
marks=pytest.mark.skipif(
|
||||
platform.system() == "Darwin",
|
||||
reason="Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087",
|
||||
),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_transcribe(
|
||||
self,
|
||||
qtbot: QtBot,
|
||||
word_level_timings: bool,
|
||||
expected_segments: List[Segment],
|
||||
model: TranscriptionModel,
|
||||
check_progress,
|
||||
):
|
||||
mock_progress = Mock()
|
||||
mock_completed = Mock()
|
||||
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings,
|
||||
model=model)
|
||||
transcription_options = TranscriptionOptions(
|
||||
language="fr",
|
||||
task=Task.TRANSCRIBE,
|
||||
word_level_timings=word_level_timings,
|
||||
model=model,
|
||||
)
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'testdata/whisper-french.mp3'))
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=[file_path])
|
||||
file_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "testdata/whisper-french.mp3")
|
||||
)
|
||||
file_transcription_options = FileTranscriptionOptions(file_paths=[file_path])
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
task=FileTranscriptionTask(transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
file_path=file_path, model_path=model_path))
|
||||
task=FileTranscriptionTask(
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
file_path=file_path,
|
||||
model_path=model_path,
|
||||
)
|
||||
)
|
||||
transcriber.progress.connect(mock_progress)
|
||||
transcriber.completed.connect(mock_completed)
|
||||
with qtbot.wait_signal(transcriber.progress, timeout=10 * 6000), qtbot.wait_signal(transcriber.completed,
|
||||
timeout=10 * 6000):
|
||||
with qtbot.wait_signal(
|
||||
transcriber.progress, timeout=10 * 6000
|
||||
), qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
|
||||
transcriber.run()
|
||||
|
||||
# Skip checking progress...
|
||||
|
@ -182,26 +300,37 @@ class TestWhisperFileTranscriber:
|
|||
mock_completed.assert_called()
|
||||
segments = mock_completed.call_args[0][0]
|
||||
assert len(segments) >= len(expected_segments)
|
||||
for (i, expected_segment) in enumerate(expected_segments):
|
||||
for i, expected_segment in enumerate(expected_segments):
|
||||
assert segments[i] == expected_segment
|
||||
|
||||
@pytest.mark.skip()
|
||||
def test_transcribe_stop(self):
|
||||
output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt')
|
||||
output_file_path = os.path.join(tempfile.gettempdir(), "whisper.txt")
|
||||
if os.path.exists(output_file_path):
|
||||
os.remove(output_file_path)
|
||||
|
||||
file_transcription_options = FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3'])
|
||||
file_paths=["testdata/whisper-french.mp3"]
|
||||
)
|
||||
transcription_options = TranscriptionOptions(
|
||||
language='fr', task=Task.TRANSCRIBE, word_level_timings=False,
|
||||
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
|
||||
language="fr",
|
||||
task=Task.TRANSCRIBE,
|
||||
word_level_timings=False,
|
||||
model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
),
|
||||
)
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
|
||||
transcriber = WhisperFileTranscriber(
|
||||
task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
file_path='testdata/whisper-french.mp3'))
|
||||
task=FileTranscriptionTask(
|
||||
model_path=model_path,
|
||||
transcription_options=transcription_options,
|
||||
file_transcription_options=file_transcription_options,
|
||||
file_path="testdata/whisper-french.mp3",
|
||||
)
|
||||
)
|
||||
transcriber.run()
|
||||
time.sleep(1)
|
||||
transcriber.stop()
|
||||
|
@ -212,40 +341,54 @@ class TestWhisperFileTranscriber:
|
|||
|
||||
class TestToTimestamp:
|
||||
def test_to_timestamp(self):
|
||||
assert to_timestamp(0) == '00:00:00.000'
|
||||
assert to_timestamp(123456789) == '34:17:36.789'
|
||||
assert to_timestamp(0) == "00:00:00.000"
|
||||
assert to_timestamp(123456789) == "34:17:36.789"
|
||||
|
||||
|
||||
class TestWhisperCpp:
|
||||
def test_transcribe(self):
|
||||
transcription_options = TranscriptionOptions(
|
||||
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
|
||||
model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER_CPP,
|
||||
whisper_model_size=WhisperModelSize.TINY,
|
||||
)
|
||||
)
|
||||
model_path = get_model_path(transcription_options.model)
|
||||
|
||||
whisper_cpp = WhisperCpp(model=model_path)
|
||||
params = whisper_cpp_params(
|
||||
language='fr', task=Task.TRANSCRIBE, word_level_timings=False)
|
||||
language="fr", task=Task.TRANSCRIBE, word_level_timings=False
|
||||
)
|
||||
result = whisper_cpp.transcribe(
|
||||
audio='testdata/whisper-french.mp3', params=params)
|
||||
audio="testdata/whisper-french.mp3", params=params
|
||||
)
|
||||
|
||||
assert 'Bienvenue dans Passe' in result['text']
|
||||
assert "Bienvenue dans Passe" in result["text"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'output_format,output_text',
|
||||
"output_format,output_text",
|
||||
[
|
||||
(OutputFormat.TXT, 'Bien\nvenue dans\n'),
|
||||
(OutputFormat.TXT, "Bien\nvenue dans\n"),
|
||||
(
|
||||
OutputFormat.SRT,
|
||||
'1\n00:00:00,040 --> 00:00:00,299\nBien\n\n2\n00:00:00,299 --> 00:00:00,329\nvenue dans\n\n'),
|
||||
(OutputFormat.VTT,
|
||||
'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
|
||||
])
|
||||
def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str):
|
||||
output_file_path = tmp_path / 'whisper.txt'
|
||||
segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')]
|
||||
OutputFormat.SRT,
|
||||
"1\n00:00:00,040 --> 00:00:00,299\nBien\n\n2\n00:00:00,299 --> 00:00:00,329\nvenue dans\n\n",
|
||||
),
|
||||
(
|
||||
OutputFormat.VTT,
|
||||
"WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_write_output(
|
||||
tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str
|
||||
):
|
||||
output_file_path = tmp_path / "whisper.txt"
|
||||
segments = [Segment(40, 299, "Bien"), Segment(299, 329, "venue dans")]
|
||||
|
||||
write_output(path=str(output_file_path), segments=segments, output_format=output_format)
|
||||
write_output(
|
||||
path=str(output_file_path), segments=segments, output_format=output_format
|
||||
)
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
output_file = open(output_file_path, "r", encoding="utf-8")
|
||||
assert output_text == output_file.read()
|
||||
|
|
|
@ -3,8 +3,9 @@ from buzz.transformers_whisper import load_model
|
|||
|
||||
class TestTransformersWhisper:
|
||||
def test_should_transcribe(self):
|
||||
model = load_model('openai/whisper-tiny')
|
||||
model = load_model("openai/whisper-tiny")
|
||||
result = model.transcribe(
|
||||
audio='testdata/whisper-french.mp3', language='fr', task='transcribe')
|
||||
audio="testdata/whisper-french.mp3", language="fr", task="transcribe"
|
||||
)
|
||||
|
||||
assert 'Bienvenue dans Passe' in result['text']
|
||||
assert "Bienvenue dans Passe" in result["text"]
|
||||
|
|
|
@ -9,13 +9,19 @@ from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidg
|
|||
class TestFileTranscriberWidget:
|
||||
def test_should_set_window_title(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
|
||||
file_paths=["testdata/whisper-french.mp3"],
|
||||
default_output_file_name="",
|
||||
parent=None,
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
assert widget.windowTitle() == 'whisper-french.mp3'
|
||||
assert widget.windowTitle() == "whisper-french.mp3"
|
||||
|
||||
def test_should_emit_triggered_event(self, qtbot: QtBot):
|
||||
widget = FileTranscriberWidget(
|
||||
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
|
||||
file_paths=["testdata/whisper-french.mp3"],
|
||||
default_output_file_name="",
|
||||
parent=None,
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
mock_triggered = Mock()
|
||||
|
@ -24,9 +30,11 @@ class TestFileTranscriberWidget:
|
|||
with qtbot.wait_signal(widget.triggered, timeout=30 * 1000):
|
||||
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
|
||||
|
||||
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
|
||||
0][0]
|
||||
(
|
||||
transcription_options,
|
||||
file_transcription_options,
|
||||
model_path,
|
||||
) = mock_triggered.call_args[0][0]
|
||||
assert transcription_options.language is None
|
||||
assert file_transcription_options.file_paths == [
|
||||
'testdata/whisper-french.mp3']
|
||||
assert file_transcription_options.file_paths == ["testdata/whisper-french.mp3"]
|
||||
assert len(model_path) > 0
|
||||
|
|
|
@ -8,7 +8,7 @@ class TestModelDownloadProgressDialog:
|
|||
def test_should_show_dialog(self, qtbot):
|
||||
dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
|
||||
qtbot.add_widget(dialog)
|
||||
assert dialog.labelText() == 'Downloading model (0%)'
|
||||
assert dialog.labelText() == "Downloading model (0%)"
|
||||
|
||||
def test_should_update_label_on_progress(self, qtbot):
|
||||
dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
|
||||
|
@ -16,12 +16,10 @@ class TestModelDownloadProgressDialog:
|
|||
dialog.set_value(0.0)
|
||||
|
||||
dialog.set_value(0.01)
|
||||
assert dialog.labelText().startswith(
|
||||
'Downloading model (1%')
|
||||
assert dialog.labelText().startswith("Downloading model (1%")
|
||||
|
||||
dialog.set_value(0.1)
|
||||
assert dialog.labelText().startswith(
|
||||
'Downloading model (10%')
|
||||
assert dialog.labelText().startswith("Downloading model (10%")
|
||||
|
||||
# Other windows should not be processing while models are being downloaded
|
||||
def test_should_be_an_application_modal(self, qtbot):
|
||||
|
|
|
@ -7,8 +7,8 @@ class TestModelTypeComboBox:
|
|||
qtbot.add_widget(widget)
|
||||
|
||||
assert widget.count() == 5
|
||||
assert widget.itemText(0) == 'Whisper'
|
||||
assert widget.itemText(1) == 'Whisper.cpp'
|
||||
assert widget.itemText(2) == 'Hugging Face'
|
||||
assert widget.itemText(3) == 'Faster Whisper'
|
||||
assert widget.itemText(4) == 'OpenAI Whisper API'
|
||||
assert widget.itemText(0) == "Whisper"
|
||||
assert widget.itemText(1) == "Whisper.cpp"
|
||||
assert widget.itemText(2) == "Hugging Face"
|
||||
assert widget.itemText(3) == "Faster Whisper"
|
||||
assert widget.itemText(4) == "OpenAI Whisper API"
|
||||
|
|
|
@ -3,14 +3,14 @@ from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
|
|||
|
||||
class TestOpenAIAPIKeyLineEdit:
|
||||
def test_should_emit_key_changed(self, qtbot):
|
||||
line_edit = OpenAIAPIKeyLineEdit(key='')
|
||||
line_edit = OpenAIAPIKeyLineEdit(key="")
|
||||
qtbot.add_widget(line_edit)
|
||||
|
||||
with qtbot.wait_signal(line_edit.key_changed):
|
||||
line_edit.setText('abcdefg')
|
||||
line_edit.setText("abcdefg")
|
||||
|
||||
def test_should_toggle_visibility(self, qtbot):
|
||||
line_edit = OpenAIAPIKeyLineEdit(key='')
|
||||
line_edit = OpenAIAPIKeyLineEdit(key="")
|
||||
qtbot.add_widget(line_edit)
|
||||
|
||||
assert line_edit.echoMode() == OpenAIAPIKeyLineEdit.EchoMode.Password
|
||||
|
|
|
@ -4,32 +4,35 @@ import pytest
|
|||
from PyQt6.QtWidgets import QPushButton, QMessageBox, QLineEdit
|
||||
|
||||
from buzz.store.keyring_store import KeyringStore
|
||||
from buzz.widgets.preferences_dialog.general_preferences_widget import \
|
||||
GeneralPreferencesWidget
|
||||
from buzz.widgets.preferences_dialog.general_preferences_widget import (
|
||||
GeneralPreferencesWidget,
|
||||
)
|
||||
|
||||
|
||||
class TestGeneralPreferencesWidget:
|
||||
def test_should_disable_test_button_if_no_api_key(self, qtbot):
|
||||
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''),
|
||||
default_export_file_name='')
|
||||
widget = GeneralPreferencesWidget(
|
||||
keyring_store=self.get_keyring_store(""), default_export_file_name=""
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
test_button = widget.findChild(QPushButton)
|
||||
assert isinstance(test_button, QPushButton)
|
||||
|
||||
assert test_button.text() == 'Test'
|
||||
assert test_button.text() == "Test"
|
||||
assert not test_button.isEnabled()
|
||||
|
||||
line_edit = widget.findChild(QLineEdit)
|
||||
assert isinstance(line_edit, QLineEdit)
|
||||
line_edit.setText('123')
|
||||
line_edit.setText("123")
|
||||
|
||||
assert test_button.isEnabled()
|
||||
|
||||
def test_should_test_openai_api_key(self, qtbot):
|
||||
widget = GeneralPreferencesWidget(
|
||||
keyring_store=self.get_keyring_store('wrong-api-key'),
|
||||
default_export_file_name='')
|
||||
keyring_store=self.get_keyring_store("wrong-api-key"),
|
||||
default_export_file_name="",
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
test_button = widget.findChild(QPushButton)
|
||||
|
@ -42,9 +45,11 @@ class TestGeneralPreferencesWidget:
|
|||
|
||||
def mock_called():
|
||||
mock.assert_called()
|
||||
assert mock.call_args[0][1] == 'OpenAI API Key Test'
|
||||
assert mock.call_args[0][
|
||||
2] == 'Incorrect API key provided: wrong-ap*-key. You can find your API key at https://platform.openai.com/account/api-keys.'
|
||||
assert mock.call_args[0][1] == "OpenAI API Key Test"
|
||||
assert (
|
||||
mock.call_args[0][2]
|
||||
== "Incorrect API key provided: wrong-ap*-key. You can find your API key at https://platform.openai.com/account/api-keys."
|
||||
)
|
||||
|
||||
qtbot.waitUntil(mock_called)
|
||||
|
||||
|
|
|
@ -5,16 +5,20 @@ from PyQt6.QtCore import Qt
|
|||
from PyQt6.QtWidgets import QComboBox, QPushButton
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.model_loader import get_whisper_file_path, WhisperModelSize, \
|
||||
TranscriptionModel, \
|
||||
ModelType
|
||||
from buzz.widgets.preferences_dialog.models_preferences_widget import \
|
||||
ModelsPreferencesWidget
|
||||
from buzz.model_loader import (
|
||||
get_whisper_file_path,
|
||||
WhisperModelSize,
|
||||
TranscriptionModel,
|
||||
ModelType,
|
||||
)
|
||||
from buzz.widgets.preferences_dialog.models_preferences_widget import (
|
||||
ModelsPreferencesWidget,
|
||||
)
|
||||
from tests.model_loader import get_model_path
|
||||
|
||||
|
||||
class TestModelsPreferencesWidget:
|
||||
@pytest.fixture(scope='class')
|
||||
@pytest.fixture(scope="class")
|
||||
def clear_model_cache(self):
|
||||
file_path = get_whisper_file_path(size=WhisperModelSize.TINY)
|
||||
if os.path.isfile(file_path):
|
||||
|
@ -25,10 +29,10 @@ class TestModelsPreferencesWidget:
|
|||
qtbot.add_widget(widget)
|
||||
|
||||
first_item = widget.model_list_widget.topLevelItem(0)
|
||||
assert first_item.text(0) == 'Downloaded'
|
||||
assert first_item.text(0) == "Downloaded"
|
||||
|
||||
second_item = widget.model_list_widget.topLevelItem(1)
|
||||
assert second_item.text(0) == 'Available for Download'
|
||||
assert second_item.text(0) == "Available for Download"
|
||||
|
||||
def test_should_change_model_type(self, qtbot):
|
||||
widget = ModelsPreferencesWidget()
|
||||
|
@ -36,36 +40,38 @@ class TestModelsPreferencesWidget:
|
|||
|
||||
combo_box = widget.findChild(QComboBox)
|
||||
assert isinstance(combo_box, QComboBox)
|
||||
combo_box.setCurrentText('Faster Whisper')
|
||||
combo_box.setCurrentText("Faster Whisper")
|
||||
|
||||
first_item = widget.model_list_widget.topLevelItem(0)
|
||||
assert first_item.text(0) == 'Downloaded'
|
||||
assert first_item.text(0) == "Downloaded"
|
||||
|
||||
second_item = widget.model_list_widget.topLevelItem(1)
|
||||
assert second_item.text(0) == 'Available for Download'
|
||||
assert second_item.text(0) == "Available for Download"
|
||||
|
||||
def test_should_download_model(self, qtbot: QtBot, clear_model_cache):
|
||||
# make progress dialog non-modal to unblock qtbot.wait_until
|
||||
widget = ModelsPreferencesWidget(
|
||||
progress_dialog_modality=Qt.WindowModality.NonModal)
|
||||
progress_dialog_modality=Qt.WindowModality.NonModal
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
model = TranscriptionModel(model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY)
|
||||
model = TranscriptionModel(
|
||||
model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY
|
||||
)
|
||||
|
||||
assert model.get_local_model_path() is None
|
||||
|
||||
available_item = widget.model_list_widget.topLevelItem(1)
|
||||
assert available_item.text(0) == 'Available for Download'
|
||||
assert available_item.text(0) == "Available for Download"
|
||||
|
||||
tiny_item = available_item.child(0)
|
||||
assert tiny_item.text(0) == 'Tiny'
|
||||
assert tiny_item.text(0) == "Tiny"
|
||||
tiny_item.setSelected(True)
|
||||
|
||||
download_button = widget.findChild(QPushButton, 'DownloadButton')
|
||||
download_button = widget.findChild(QPushButton, "DownloadButton")
|
||||
assert isinstance(download_button, QPushButton)
|
||||
|
||||
assert download_button.text() == 'Download'
|
||||
assert download_button.text() == "Download"
|
||||
download_button.click()
|
||||
|
||||
def downloaded_model():
|
||||
|
@ -73,22 +79,26 @@ class TestModelsPreferencesWidget:
|
|||
|
||||
_downloaded_item = widget.model_list_widget.topLevelItem(0)
|
||||
assert _downloaded_item.childCount() > 0
|
||||
assert _downloaded_item.child(0).text(0) == 'Tiny'
|
||||
assert _downloaded_item.child(0).text(0) == "Tiny"
|
||||
|
||||
_available_item = widget.model_list_widget.topLevelItem(1)
|
||||
assert _available_item.childCount() == 0 or _available_item.child(0).text(
|
||||
0) != 'Tiny'
|
||||
assert (
|
||||
_available_item.childCount() == 0
|
||||
or _available_item.child(0).text(0) != "Tiny"
|
||||
)
|
||||
|
||||
# model file exists
|
||||
assert os.path.isfile(get_whisper_file_path(size=model.whisper_model_size))
|
||||
|
||||
qtbot.wait_until(callback=downloaded_model, timeout=60_000)
|
||||
|
||||
@pytest.fixture(scope='class')
|
||||
@pytest.fixture(scope="class")
|
||||
def whisper_tiny_model_path(self) -> str:
|
||||
return get_model_path(transcription_model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER,
|
||||
whisper_model_size=WhisperModelSize.TINY))
|
||||
return get_model_path(
|
||||
transcription_model=TranscriptionModel(
|
||||
model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY
|
||||
)
|
||||
)
|
||||
|
||||
def test_should_show_downloaded_model(self, qtbot, whisper_tiny_model_path):
|
||||
widget = ModelsPreferencesWidget()
|
||||
|
@ -96,15 +106,16 @@ class TestModelsPreferencesWidget:
|
|||
qtbot.add_widget(widget)
|
||||
|
||||
available_item = widget.model_list_widget.topLevelItem(0)
|
||||
assert available_item.text(0) == 'Downloaded'
|
||||
assert available_item.text(0) == "Downloaded"
|
||||
|
||||
tiny_item = available_item.child(0)
|
||||
assert tiny_item.text(0) == 'Tiny'
|
||||
assert tiny_item.text(0) == "Tiny"
|
||||
tiny_item.setSelected(True)
|
||||
|
||||
delete_button = widget.findChild(QPushButton, 'DeleteButton')
|
||||
delete_button = widget.findChild(QPushButton, "DeleteButton")
|
||||
assert delete_button.isVisible()
|
||||
|
||||
show_file_location_button = widget.findChild(QPushButton,
|
||||
'ShowFileLocationButton')
|
||||
show_file_location_button = widget.findChild(
|
||||
QPushButton, "ShowFileLocationButton"
|
||||
)
|
||||
assert show_file_location_button.isVisible()
|
||||
|
|
|
@ -6,14 +6,14 @@ from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog
|
|||
|
||||
class TestPreferencesDialog:
|
||||
def test_create(self, qtbot: QtBot):
|
||||
dialog = PreferencesDialog(shortcuts={}, default_export_file_name='')
|
||||
dialog = PreferencesDialog(shortcuts={}, default_export_file_name="")
|
||||
qtbot.add_widget(dialog)
|
||||
|
||||
assert dialog.windowTitle() == 'Preferences'
|
||||
assert dialog.windowTitle() == "Preferences"
|
||||
|
||||
tab_widget = dialog.findChild(QTabWidget)
|
||||
assert isinstance(tab_widget, QTabWidget)
|
||||
assert tab_widget.count() == 3
|
||||
assert tab_widget.tabText(0) == 'General'
|
||||
assert tab_widget.tabText(1) == 'Models'
|
||||
assert tab_widget.tabText(2) == 'Shortcuts'
|
||||
assert tab_widget.tabText(0) == "General"
|
||||
assert tab_widget.tabText(1) == "Models"
|
||||
assert tab_widget.tabText(2) == "Shortcuts"
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
from PyQt6.QtWidgets import QPushButton, QLabel
|
||||
|
||||
from buzz.settings.shortcut import Shortcut
|
||||
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import \
|
||||
ShortcutsEditorPreferencesWidget
|
||||
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import (
|
||||
ShortcutsEditorPreferencesWidget,
|
||||
)
|
||||
from buzz.widgets.sequence_edit import SequenceEdit
|
||||
|
||||
|
||||
class TestShortcutsEditorWidget:
|
||||
def test_should_reset_to_defaults(self, qtbot):
|
||||
widget = ShortcutsEditorPreferencesWidget(shortcuts=Shortcut.get_default_shortcuts())
|
||||
widget = ShortcutsEditorPreferencesWidget(
|
||||
shortcuts=Shortcut.get_default_shortcuts()
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
reset_button = widget.findChild(QPushButton)
|
||||
|
@ -19,12 +22,13 @@ class TestShortcutsEditorWidget:
|
|||
sequence_edits = widget.findChildren(SequenceEdit)
|
||||
|
||||
expected = (
|
||||
('Open Record Window', 'Ctrl+R'),
|
||||
('Import File', 'Ctrl+O'),
|
||||
('Open Preferences Window', 'Ctrl+,'),
|
||||
('Open Transcript Viewer', 'Ctrl+E'),
|
||||
('Clear History', 'Ctrl+S'),
|
||||
('Cancel Transcription', 'Ctrl+X'))
|
||||
("Open Record Window", "Ctrl+R"),
|
||||
("Import File", "Ctrl+O"),
|
||||
("Open Preferences Window", "Ctrl+,"),
|
||||
("Open Transcript Viewer", "Ctrl+E"),
|
||||
("Clear History", "Ctrl+S"),
|
||||
("Cancel Transcription", "Ctrl+X"),
|
||||
)
|
||||
|
||||
for i, (label, sequence_edit) in enumerate(zip(labels, sequence_edits)):
|
||||
assert isinstance(label, QLabel)
|
||||
|
|
|
@ -2,57 +2,70 @@ import datetime
|
|||
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.transcriber import FileTranscriptionTask, TranscriptionOptions, FileTranscriptionOptions
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
TranscriptionOptions,
|
||||
FileTranscriptionOptions,
|
||||
)
|
||||
from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
|
||||
|
||||
|
||||
class TestTranscriptionTasksTableWidget:
|
||||
|
||||
def test_upsert_task(self, qtbot: QtBot):
|
||||
widget = TranscriptionTasksTableWidget()
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3']), model_path='',
|
||||
status=FileTranscriptionTask.Status.QUEUED)
|
||||
task = FileTranscriptionTask(
|
||||
id=0,
|
||||
file_path="testdata/whisper-french.mp3",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=["testdata/whisper-french.mp3"]
|
||||
),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.QUEUED,
|
||||
)
|
||||
task.queued_at = datetime.datetime(2023, 4, 12, 0, 0, 0)
|
||||
task.started_at = datetime.datetime(2023, 4, 12, 0, 0, 5)
|
||||
|
||||
widget.upsert_task(task)
|
||||
|
||||
assert widget.rowCount() == 1
|
||||
assert widget.item(0, 1).text() == 'whisper-french.mp3'
|
||||
assert widget.item(0, 2).text() == 'Queued'
|
||||
assert widget.item(0, 1).text() == "whisper-french.mp3"
|
||||
assert widget.item(0, 2).text() == "Queued"
|
||||
|
||||
task.status = FileTranscriptionTask.Status.IN_PROGRESS
|
||||
task.fraction_completed = 0.3524
|
||||
widget.upsert_task(task)
|
||||
|
||||
assert widget.rowCount() == 1
|
||||
assert widget.item(0, 1).text() == 'whisper-french.mp3'
|
||||
assert widget.item(0, 2).text() == 'In Progress (35%)'
|
||||
assert widget.item(0, 1).text() == "whisper-french.mp3"
|
||||
assert widget.item(0, 2).text() == "In Progress (35%)"
|
||||
|
||||
task.status = FileTranscriptionTask.Status.COMPLETED
|
||||
task.completed_at = datetime.datetime(2023, 4, 12, 0, 0, 10)
|
||||
widget.upsert_task(task)
|
||||
|
||||
assert widget.rowCount() == 1
|
||||
assert widget.item(0, 1).text() == 'whisper-french.mp3'
|
||||
assert widget.item(0, 2).text() == 'Completed (5s)'
|
||||
assert widget.item(0, 1).text() == "whisper-french.mp3"
|
||||
assert widget.item(0, 2).text() == "Completed (5s)"
|
||||
|
||||
def test_upsert_task_no_timings(self, qtbot: QtBot):
|
||||
widget = TranscriptionTasksTableWidget()
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3']), model_path='',
|
||||
status=FileTranscriptionTask.Status.COMPLETED)
|
||||
task = FileTranscriptionTask(
|
||||
id=0,
|
||||
file_path="testdata/whisper-french.mp3",
|
||||
transcription_options=TranscriptionOptions(),
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=["testdata/whisper-french.mp3"]
|
||||
),
|
||||
model_path="",
|
||||
status=FileTranscriptionTask.Status.COMPLETED,
|
||||
)
|
||||
widget.upsert_task(task)
|
||||
|
||||
assert widget.rowCount() == 1
|
||||
assert widget.item(0, 1).text() == 'whisper-french.mp3'
|
||||
assert widget.item(0, 2).text() == 'Completed'
|
||||
assert widget.item(0, 1).text() == "whisper-french.mp3"
|
||||
assert widget.item(0, 2).text() == "Completed"
|
||||
|
|
|
@ -6,69 +6,85 @@ from PyQt6.QtGui import QKeyEvent
|
|||
from PyQt6.QtWidgets import QPushButton, QToolBar, QToolButton
|
||||
from pytestqt.qtbot import QtBot
|
||||
|
||||
from buzz.transcriber import FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, Segment
|
||||
from buzz.widgets.transcription_segments_editor_widget import TranscriptionSegmentsEditorWidget
|
||||
from buzz.transcriber import (
|
||||
FileTranscriptionTask,
|
||||
FileTranscriptionOptions,
|
||||
TranscriptionOptions,
|
||||
Segment,
|
||||
)
|
||||
from buzz.widgets.transcription_segments_editor_widget import (
|
||||
TranscriptionSegmentsEditorWidget,
|
||||
)
|
||||
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
|
||||
|
||||
|
||||
class TestTranscriptionViewerWidget:
|
||||
@pytest.fixture()
|
||||
def task(self) -> FileTranscriptionTask:
|
||||
return FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=['testdata/whisper-french.mp3']),
|
||||
transcription_options=TranscriptionOptions(),
|
||||
segments=[Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')],
|
||||
model_path='')
|
||||
return FileTranscriptionTask(
|
||||
id=0,
|
||||
file_path="testdata/whisper-french.mp3",
|
||||
file_transcription_options=FileTranscriptionOptions(
|
||||
file_paths=["testdata/whisper-french.mp3"]
|
||||
),
|
||||
transcription_options=TranscriptionOptions(),
|
||||
segments=[Segment(40, 299, "Bien"), Segment(299, 329, "venue dans")],
|
||||
model_path="",
|
||||
)
|
||||
|
||||
def test_should_display_segments(self, qtbot: QtBot, task):
|
||||
widget = TranscriptionViewerWidget(
|
||||
transcription_task=task, open_transcription_output=False)
|
||||
transcription_task=task, open_transcription_output=False
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
assert widget.windowTitle() == 'whisper-french.mp3'
|
||||
assert widget.windowTitle() == "whisper-french.mp3"
|
||||
|
||||
editor = widget.findChild(TranscriptionSegmentsEditorWidget)
|
||||
assert isinstance(editor, TranscriptionSegmentsEditorWidget)
|
||||
|
||||
assert editor.item(0, 0).text() == '00:00:00.040'
|
||||
assert editor.item(0, 1).text() == '00:00:00.299'
|
||||
assert editor.item(0, 2).text() == 'Bien'
|
||||
assert editor.item(0, 0).text() == "00:00:00.040"
|
||||
assert editor.item(0, 1).text() == "00:00:00.299"
|
||||
assert editor.item(0, 2).text() == "Bien"
|
||||
|
||||
def test_should_update_segment_text(self, qtbot, task):
|
||||
widget = TranscriptionViewerWidget(
|
||||
transcription_task=task, open_transcription_output=False)
|
||||
transcription_task=task, open_transcription_output=False
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
editor = widget.findChild(TranscriptionSegmentsEditorWidget)
|
||||
assert isinstance(editor, TranscriptionSegmentsEditorWidget)
|
||||
|
||||
# Change text
|
||||
editor.item(0, 2).setText('Biens')
|
||||
assert task.segments[0].text == 'Biens'
|
||||
editor.item(0, 2).setText("Biens")
|
||||
assert task.segments[0].text == "Biens"
|
||||
|
||||
# Undo
|
||||
toolbar = widget.findChild(QToolBar)
|
||||
undo_action, redo_action = toolbar.actions()
|
||||
|
||||
undo_action.trigger()
|
||||
assert task.segments[0].text == 'Bien'
|
||||
assert task.segments[0].text == "Bien"
|
||||
|
||||
redo_action.trigger()
|
||||
assert task.segments[0].text == 'Biens'
|
||||
assert task.segments[0].text == "Biens"
|
||||
|
||||
def test_should_export_segments(self, tmp_path: pathlib.Path, qtbot: QtBot, task):
|
||||
widget = TranscriptionViewerWidget(
|
||||
transcription_task=task, open_transcription_output=False)
|
||||
transcription_task=task, open_transcription_output=False
|
||||
)
|
||||
qtbot.add_widget(widget)
|
||||
|
||||
export_button = widget.findChild(QPushButton)
|
||||
assert isinstance(export_button, QPushButton)
|
||||
|
||||
output_file_path = tmp_path / 'whisper.txt'
|
||||
with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock:
|
||||
save_file_name_mock.return_value = (str(output_file_path), '')
|
||||
output_file_path = tmp_path / "whisper.txt"
|
||||
with patch(
|
||||
"PyQt6.QtWidgets.QFileDialog.getSaveFileName"
|
||||
) as save_file_name_mock:
|
||||
save_file_name_mock.return_value = (str(output_file_path), "")
|
||||
export_button.menu().actions()[0].trigger()
|
||||
|
||||
output_file = open(output_file_path, 'r', encoding='utf-8')
|
||||
assert 'Bien\nvenue dans' in output_file.read()
|
||||
output_file = open(output_file_path, "r", encoding="utf-8")
|
||||
assert "Bien\nvenue dans" in output_file.read()
|
||||
|
|
Loading…
Reference in a new issue