Add black formatting (#571)

Chidi Williams 2023-08-18 23:32:18 +01:00 committed by GitHub
parent f5f77b3908
commit c498e60949
66 changed files with 2544 additions and 1372 deletions
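The files below were rewritten mechanically: black normalizes each file to its canonical style (double quotes, 88-column line wrapping, trailing commas in multi-line constructs) without changing runtime behavior. A minimal sketch of the rewrite, using black's Python API rather than its CLI (assuming black is installed; the sample line is taken from the first hunk below):

import black

# A line in the pre-commit style: single quotes.
src = "subprocess.call(['make', 'buzz/whisper_cpp.py'])\n"

# format_str applies the same rewrite that running `black` applies to a file;
# Mode() carries the defaults seen throughout this diff (88 columns, double quotes).
formatted = black.format_str(src, mode=black.Mode())
print(formatted, end="")  # prints: subprocess.call(["make", "buzz/whisper_cpp.py"])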


@@ -2,7 +2,7 @@ import subprocess
def build(setup_kwargs):
subprocess.call(['make', 'buzz/whisper_cpp.py'])
subprocess.call(["make", "buzz/whisper_cpp.py"])
if __name__ == "__main__":


@@ -4,7 +4,10 @@ from PyQt6.QtGui import QAction, QKeySequence
class Action(QAction):
def setShortcut(self, shortcut: typing.Union['QKeySequence', 'QKeySequence.StandardKey', str, int]) -> None:
def setShortcut(
self,
shortcut: typing.Union["QKeySequence", "QKeySequence.StandardKey", str, int],
) -> None:
super().setShortcut(shortcut)
self.setToolTip(Action.get_tooltip(self))


@@ -3,6 +3,6 @@ import sys
def get_asset_path(path: str):
if getattr(sys, 'frozen', False):
if getattr(sys, "frozen", False):
return os.path.join(os.path.dirname(sys.executable), path)
return os.path.join(os.path.dirname(__file__), '..', path)
return os.path.join(os.path.dirname(__file__), "..", path)


@@ -9,7 +9,7 @@ from typing import TextIO
from appdirs import user_log_dir
# Check for segfaults if not running in frozen mode
if getattr(sys, 'frozen', False) is False:
if getattr(sys, "frozen", False) is False:
faulthandler.enable()
# Sets stderr to no-op TextIO when None (run as Windows GUI).
@@ -19,30 +19,35 @@ if sys.stderr is None:
# Adds the current directory to the PATH, so the ffmpeg binary gets picked up:
# https://stackoverflow.com/a/44352931/9830227
app_dir = getattr(sys, '_MEIPASS', os.path.dirname(
os.path.abspath(__file__)))
app_dir = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
os.environ["PATH"] += os.pathsep + app_dir
# Add the app directory to the DLL list: https://stackoverflow.com/a/64303856
if platform.system() == 'Windows':
if platform.system() == "Windows":
os.add_dll_directory(app_dir)
def main():
if platform.system() == 'Linux':
multiprocessing.set_start_method('spawn')
if platform.system() == "Linux":
multiprocessing.set_start_method("spawn")
# Fixes opening new window when app has been frozen on Windows:
# https://stackoverflow.com/a/33979091
multiprocessing.freeze_support()
log_dir = user_log_dir(appname='Buzz')
log_dir = user_log_dir(appname="Buzz")
os.makedirs(log_dir, exist_ok=True)
log_format = "[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
logging.basicConfig(filename=os.path.join(log_dir, 'logs.txt'), level=logging.DEBUG, format=log_format)
log_format = (
"[%(asctime)s] %(module)s.%(funcName)s:%(lineno)d %(levelname)s -> %(message)s"
)
logging.basicConfig(
filename=os.path.join(log_dir, "logs.txt"),
level=logging.DEBUG,
format=log_format,
)
if getattr(sys, 'frozen', False) is False:
if getattr(sys, "frozen", False) is False:
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setLevel(logging.DEBUG)
stdout_handler.setFormatter(logging.Formatter(log_format))


@@ -9,11 +9,11 @@ from .transcriber import FileTranscriptionTask
class TasksCache:
def __init__(self, cache_dir=user_cache_dir('Buzz')):
def __init__(self, cache_dir=user_cache_dir("Buzz")):
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir = cache_dir
self.pickle_cache_file_path = os.path.join(cache_dir, 'tasks')
self.tasks_list_file_path = os.path.join(cache_dir, 'tasks.json')
self.pickle_cache_file_path = os.path.join(cache_dir, "tasks")
self.tasks_list_file_path = os.path.join(cache_dir, "tasks.json")
def save(self, tasks: List[FileTranscriptionTask]):
self.save_json_tasks(tasks=tasks)
@@ -23,16 +23,20 @@ class TasksCache:
return self.load_json_tasks()
try:
with open(self.pickle_cache_file_path, 'rb') as file:
with open(self.pickle_cache_file_path, "rb") as file:
return pickle.load(file)
except FileNotFoundError:
return []
except (pickle.UnpicklingError, AttributeError, ValueError): # delete corrupted cache
except (
pickle.UnpicklingError,
AttributeError,
ValueError,
): # delete corrupted cache
os.remove(self.pickle_cache_file_path)
return []
def load_json_tasks(self) -> List[FileTranscriptionTask]:
with open(self.tasks_list_file_path, 'r') as file:
with open(self.tasks_list_file_path, "r") as file:
task_ids = json.load(file)
tasks = []
@@ -57,7 +61,7 @@ class TasksCache:
file.write(json_str)
def get_task_path(self, task_id: int):
path = os.path.join(self.cache_dir, 'transcriptions', f'{task_id}.json')
path = os.path.join(self.cache_dir, "transcriptions", f"{task_id}.json")
os.makedirs(os.path.dirname(path), exist_ok=True)
return path


@@ -7,8 +7,14 @@ from PyQt6.QtCore import QCommandLineParser, QCommandLineOption
from buzz.gui import Application
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel
from buzz.store.keyring_store import KeyringStore
from buzz.transcriber import Task, FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, LANGUAGES, \
OutputFormat
from buzz.transcriber import (
Task,
FileTranscriptionTask,
FileTranscriptionOptions,
TranscriptionOptions,
LANGUAGES,
OutputFormat,
)
class CommandLineError(Exception):
@@ -17,11 +23,11 @@ class CommandLineError(Exception):
class CommandLineModelType(enum.Enum):
WHISPER = 'whisper'
WHISPER_CPP = 'whispercpp'
HUGGING_FACE = 'huggingface'
FASTER_WHISPER = 'fasterwhisper'
OPEN_AI_WHISPER_API = 'openaiapi'
WHISPER = "whisper"
WHISPER_CPP = "whispercpp"
HUGGING_FACE = "huggingface"
FASTER_WHISPER = "fasterwhisper"
OPEN_AI_WHISPER_API = "openaiapi"
def parse_command_line(app: Application):
@@ -29,13 +35,13 @@ def parse_command_line(app: Application):
try:
parse(app, parser)
except CommandLineError as exc:
print(f'Error: {str(exc)}\n', file=sys.stderr)
print(f"Error: {str(exc)}\n", file=sys.stderr)
print(parser.helpText())
sys.exit(1)
def parse(app: Application, parser: QCommandLineParser):
parser.addPositionalArgument('<command>', 'One of the following commands:\n- add')
parser.addPositionalArgument("<command>", "One of the following commands:\n- add")
parser.parse(app.arguments())
args = parser.positionalArguments()
@@ -50,36 +56,63 @@ def parse(app: Application, parser: QCommandLineParser):
if command == "add":
parser.clearPositionalArguments()
parser.addPositionalArgument('files', 'Input file paths', '[file file file...]')
parser.addPositionalArgument("files", "Input file paths", "[file file file...]")
task_option = QCommandLineOption(['t', 'task'],
f'The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.',
'task',
Task.TRANSCRIBE.value)
model_type_option = QCommandLineOption(['m', 'model-type'],
f'Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.',
'model-type',
CommandLineModelType.WHISPER.value)
model_size_option = QCommandLineOption(['s', 'model-size'],
f'Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.',
'model-size', WhisperModelSize.TINY.value)
hugging_face_model_id_option = QCommandLineOption(['hfid'],
f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"',
'id')
language_option = QCommandLineOption(['l', 'language'],
f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.',
'code', '')
initial_prompt_option = QCommandLineOption(['p', 'prompt'], f'Initial prompt', 'prompt', '')
open_ai_access_token_option = QCommandLineOption('openai-token',
f'OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.',
'token')
srt_option = QCommandLineOption(['srt'], 'Output result in an SRT file.')
vtt_option = QCommandLineOption(['vtt'], 'Output result in a VTT file.')
txt_option = QCommandLineOption('txt', 'Output result in a TXT file.')
task_option = QCommandLineOption(
["t", "task"],
f"The task to perform. Allowed: {join_values(Task)}. Default: {Task.TRANSCRIBE.value}.",
"task",
Task.TRANSCRIBE.value,
)
model_type_option = QCommandLineOption(
["m", "model-type"],
f"Model type. Allowed: {join_values(CommandLineModelType)}. Default: {CommandLineModelType.WHISPER.value}.",
"model-type",
CommandLineModelType.WHISPER.value,
)
model_size_option = QCommandLineOption(
["s", "model-size"],
f"Model size. Use only when --model-type is whisper, whispercpp, or fasterwhisper. Allowed: {join_values(WhisperModelSize)}. Default: {WhisperModelSize.TINY.value}.",
"model-size",
WhisperModelSize.TINY.value,
)
hugging_face_model_id_option = QCommandLineOption(
["hfid"],
f'Hugging Face model ID. Use only when --model-type is huggingface. Example: "openai/whisper-tiny"',
"id",
)
language_option = QCommandLineOption(
["l", "language"],
f'Language code. Allowed: {", ".join(sorted([k + " (" + LANGUAGES[k].title() + ")" for k in LANGUAGES]))}. Leave empty to detect language.',
"code",
"",
)
initial_prompt_option = QCommandLineOption(
["p", "prompt"], f"Initial prompt", "prompt", ""
)
open_ai_access_token_option = QCommandLineOption(
"openai-token",
f"OpenAI access token. Use only when --model-type is {CommandLineModelType.OPEN_AI_WHISPER_API.value}. Defaults to your previously saved access token, if one exists.",
"token",
)
srt_option = QCommandLineOption(["srt"], "Output result in an SRT file.")
vtt_option = QCommandLineOption(["vtt"], "Output result in a VTT file.")
txt_option = QCommandLineOption("txt", "Output result in a TXT file.")
parser.addOptions(
[task_option, model_type_option, model_size_option, hugging_face_model_id_option, language_option,
initial_prompt_option, open_ai_access_token_option, srt_option, vtt_option, txt_option])
[
task_option,
model_type_option,
model_size_option,
hugging_face_model_id_option,
language_option,
initial_prompt_option,
open_ai_access_token_option,
srt_option,
vtt_option,
txt_option,
]
)
parser.addHelpOption()
parser.addVersionOption()
@@ -89,7 +122,7 @@ def parse(app: Application, parser: QCommandLineParser):
# slice after first argument, the command
file_paths = parser.positionalArguments()[1:]
if len(file_paths) == 0:
raise CommandLineError('No input files')
raise CommandLineError("No input files")
task = parse_enum_option(task_option, parser, Task)
@@ -98,21 +131,29 @@ def parse(app: Application, parser: QCommandLineParser):
hugging_face_model_id = parser.value(hugging_face_model_id_option)
if hugging_face_model_id == '' and model_type == CommandLineModelType.HUGGING_FACE:
raise CommandLineError('--hfid is required when --model-type is huggingface')
if (
hugging_face_model_id == ""
and model_type == CommandLineModelType.HUGGING_FACE
):
raise CommandLineError(
"--hfid is required when --model-type is huggingface"
)
model = TranscriptionModel(model_type=ModelType[model_type.name], whisper_model_size=model_size,
hugging_face_model_id=hugging_face_model_id)
model = TranscriptionModel(
model_type=ModelType[model_type.name],
whisper_model_size=model_size,
hugging_face_model_id=hugging_face_model_id,
)
model_path = model.get_local_model_path()
if model_path is None:
raise CommandLineError('Model not found')
raise CommandLineError("Model not found")
language = parser.value(language_option)
if language == '':
if language == "":
language = None
elif LANGUAGES.get(language) is None:
raise CommandLineError('Invalid language option')
raise CommandLineError("Invalid language option")
initial_prompt = parser.value(initial_prompt_option)
@@ -125,33 +166,49 @@ def parse(app: Application, parser: QCommandLineParser):
output_formats.add(OutputFormat.TXT)
openai_access_token = parser.value(open_ai_access_token_option)
if model.model_type == ModelType.OPEN_AI_WHISPER_API and openai_access_token == '':
openai_access_token = KeyringStore().get_password(key=KeyringStore.Key.OPENAI_API_KEY)
if (
model.model_type == ModelType.OPEN_AI_WHISPER_API
and openai_access_token == ""
):
openai_access_token = KeyringStore().get_password(
key=KeyringStore.Key.OPENAI_API_KEY
)
if openai_access_token == '':
raise CommandLineError('No OpenAI access token found')
if openai_access_token == "":
raise CommandLineError("No OpenAI access token found")
transcription_options = TranscriptionOptions(model=model, task=task, language=language,
initial_prompt=initial_prompt,
openai_access_token=openai_access_token)
file_transcription_options = FileTranscriptionOptions(file_paths=file_paths, output_formats=output_formats)
transcription_options = TranscriptionOptions(
model=model,
task=task,
language=language,
initial_prompt=initial_prompt,
openai_access_token=openai_access_token,
)
file_transcription_options = FileTranscriptionOptions(
file_paths=file_paths, output_formats=output_formats
)
for file_path in file_paths:
transcription_task = FileTranscriptionTask(file_path=file_path, model_path=model_path,
transcription_options=transcription_options,
file_transcription_options=file_transcription_options)
transcription_task = FileTranscriptionTask(
file_path=file_path,
model_path=model_path,
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
)
app.add_task(transcription_task)
T = typing.TypeVar("T", bound=enum.Enum)
def parse_enum_option(option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T]) -> T:
def parse_enum_option(
option: QCommandLineOption, parser: QCommandLineParser, enum_class: typing.Type[T]
) -> T:
try:
return enum_class(parser.value(option))
except ValueError:
raise CommandLineError(f'Invalid value for --{option.names()[-1]} option.')
raise CommandLineError(f"Invalid value for --{option.names()[-1]} option.")
def join_values(enum_class: typing.Type[enum.Enum]) -> str:
return ', '.join([v.value for v in enum_class])
return ", ".join([v.value for v in enum_class])


@@ -2,9 +2,10 @@ from PyQt6.QtWidgets import QWidget, QMessageBox
def show_model_download_error_dialog(parent: QWidget, error: str):
message = parent.tr(
'An error occurred while loading the Whisper model') + \
f": {error}{'' if error.endswith('.') else '.'}" + \
parent.tr("Please retry or check the application logs for more information.")
message = (
parent.tr("An error occurred while loading the Whisper model")
+ f": {error}{'' if error.endswith('.') else '.'}"
+ parent.tr("Please retry or check the application logs for more information.")
)
QMessageBox.critical(parent, '', message)
QMessageBox.critical(parent, "", message)


@@ -7,8 +7,14 @@ from typing import Optional, Tuple, List
from PyQt6.QtCore import QObject, QThread, pyqtSignal, pyqtSlot
from buzz.model_loader import ModelType
from buzz.transcriber import FileTranscriptionTask, FileTranscriber, WhisperCppFileTranscriber, \
OpenAIWhisperAPIFileTranscriber, WhisperFileTranscriber, Segment
from buzz.transcriber import (
FileTranscriptionTask,
FileTranscriber,
WhisperCppFileTranscriber,
OpenAIWhisperAPIFileTranscriber,
WhisperFileTranscriber,
Segment,
)
class FileTranscriberQueueWorker(QObject):
@@ -26,7 +32,7 @@ class FileTranscriberQueueWorker(QObject):
@pyqtSlot()
def run(self):
logging.debug('Waiting for next transcription task')
logging.debug("Waiting for next transcription task")
# Get next non-canceled task from queue
while True:
@@ -42,38 +48,37 @@
break
logging.debug('Starting next transcription task')
logging.debug("Starting next transcription task")
model_type = self.current_task.transcription_options.model.model_type
if model_type == ModelType.WHISPER_CPP:
self.current_transcriber = WhisperCppFileTranscriber(
task=self.current_task)
self.current_transcriber = WhisperCppFileTranscriber(task=self.current_task)
elif model_type == ModelType.OPEN_AI_WHISPER_API:
self.current_transcriber = OpenAIWhisperAPIFileTranscriber(task=self.current_task)
elif model_type == ModelType.HUGGING_FACE or \
model_type == ModelType.WHISPER or \
model_type == ModelType.FASTER_WHISPER:
self.current_transcriber = OpenAIWhisperAPIFileTranscriber(
task=self.current_task
)
elif (
model_type == ModelType.HUGGING_FACE
or model_type == ModelType.WHISPER
or model_type == ModelType.FASTER_WHISPER
):
self.current_transcriber = WhisperFileTranscriber(task=self.current_task)
else:
raise Exception(f'Unknown model type: {model_type}')
raise Exception(f"Unknown model type: {model_type}")
self.current_transcriber_thread = QThread(self)
self.current_transcriber.moveToThread(self.current_transcriber_thread)
self.current_transcriber_thread.started.connect(
self.current_transcriber.run)
self.current_transcriber.completed.connect(
self.current_transcriber_thread.quit)
self.current_transcriber.error.connect(
self.current_transcriber_thread.quit)
self.current_transcriber_thread.started.connect(self.current_transcriber.run)
self.current_transcriber.completed.connect(self.current_transcriber_thread.quit)
self.current_transcriber.error.connect(self.current_transcriber_thread.quit)
self.current_transcriber.completed.connect(
self.current_transcriber.deleteLater)
self.current_transcriber.error.connect(
self.current_transcriber.deleteLater)
self.current_transcriber.completed.connect(self.current_transcriber.deleteLater)
self.current_transcriber.error.connect(self.current_transcriber.deleteLater)
self.current_transcriber_thread.finished.connect(
self.current_transcriber_thread.deleteLater)
self.current_transcriber_thread.deleteLater
)
self.current_transcriber.progress.connect(self.on_task_progress)
self.current_transcriber.error.connect(self.on_task_error)
@@ -104,7 +109,10 @@ class FileTranscriberQueueWorker(QObject):
@pyqtSlot(Exception)
def on_task_error(self, error: Exception):
if self.current_task is not None and self.current_task.id not in self.canceled_tasks:
if (
self.current_task is not None
and self.current_task.id not in self.canceled_tasks
):
self.current_task.status = FileTranscriptionTask.Status.FAILED
self.current_task.error = str(error)
self.task_updated.emit(self.current_task)


@@ -5,13 +5,23 @@ from typing import Dict, List, Optional, Tuple
import sounddevice
from PyQt6 import QtGui
from PyQt6.QtCore import (Qt, QThread,
pyqtSignal, QModelIndex, QThreadPool)
from PyQt6.QtGui import (QCloseEvent, QIcon,
QKeySequence, QTextCursor, QPainter, QColor)
from PyQt6.QtWidgets import (QApplication, QComboBox, QFileDialog, QLabel, QMainWindow, QMessageBox, QPlainTextEdit,
QPushButton, QVBoxLayout, QHBoxLayout, QWidget, QFormLayout,
QSizePolicy)
from PyQt6.QtCore import Qt, QThread, pyqtSignal, QModelIndex, QThreadPool
from PyQt6.QtGui import QCloseEvent, QIcon, QKeySequence, QTextCursor, QPainter, QColor
from PyQt6.QtWidgets import (
QApplication,
QComboBox,
QFileDialog,
QLabel,
QMainWindow,
QMessageBox,
QPlainTextEdit,
QPushButton,
QVBoxLayout,
QHBoxLayout,
QWidget,
QFormLayout,
QSizePolicy,
)
from buzz.cache import TasksCache
from .__version__ import VERSION
@@ -20,25 +30,35 @@ from .assets import get_asset_path
from .dialogs import show_model_download_error_dialog
from .widgets.icon import Icon, BUZZ_ICON_PATH
from .locale import _
from .model_loader import WhisperModelSize, ModelType, TranscriptionModel, \
ModelDownloader
from .model_loader import (
WhisperModelSize,
ModelType,
TranscriptionModel,
ModelDownloader,
)
from .recording import RecordingAmplitudeListener
from .settings.settings import Settings, APP_NAME
from .settings.shortcut import Shortcut
from .settings.shortcut_settings import ShortcutSettings
from .store.keyring_store import KeyringStore
from .transcriber import (SUPPORTED_OUTPUT_FORMATS, FileTranscriptionOptions, Task,
TranscriptionOptions,
FileTranscriptionTask, LOADED_WHISPER_DLL,
DEFAULT_WHISPER_TEMPERATURE)
from .transcriber import (
SUPPORTED_OUTPUT_FORMATS,
FileTranscriptionOptions,
Task,
TranscriptionOptions,
FileTranscriptionTask,
LOADED_WHISPER_DLL,
DEFAULT_WHISPER_TEMPERATURE,
)
from .recording_transcriber import RecordingTranscriber
from .file_transcriber_queue_worker import FileTranscriberQueueWorker
from .widgets.menu_bar import MenuBar
from .widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from .widgets.toolbar import ToolBar
from .widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
from .widgets.transcriber.transcription_options_group_box import \
TranscriptionOptionsGroupBox
from .widgets.transcriber.transcription_options_group_box import (
TranscriptionOptionsGroupBox,
)
from .widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
from .widgets.transcription_viewer_widget import TranscriptionViewerWidget
@@ -46,13 +66,17 @@ from .widgets.transcription_viewer_widget import TranscriptionViewerWidget
class FormLabel(QLabel):
def __init__(self, name: str, parent: Optional[QWidget], *args) -> None:
super().__init__(name, parent, *args)
self.setStyleSheet('QLabel { text-align: right; }')
self.setAlignment(Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight))
self.setStyleSheet("QLabel { text-align: right; }")
self.setAlignment(
Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignRight
)
)
class AudioDevicesComboBox(QComboBox):
"""AudioDevicesComboBox displays a list of available audio input devices"""
device_changed = pyqtSignal(int)
audio_devices: List[Tuple[int, str]]
@@ -71,13 +95,18 @@ class AudioDevicesComboBox(QComboBox):
def get_audio_devices(self) -> List[Tuple[int, str]]:
try:
devices: sounddevice.DeviceList = sounddevice.query_devices()
return [(device.get('index'), device.get('name'))
for device in devices if device.get('max_input_channels') > 0]
return [
(device.get("index"), device.get("name"))
for device in devices
if device.get("max_input_channels") > 0
]
except UnicodeDecodeError:
QMessageBox.critical(
self, '',
'An error occurred while loading your audio devices. Please check the application logs for more '
'information.')
self,
"",
"An error occurred while loading your audio devices. Please check the application logs for more "
"information.",
)
return []
def on_index_changed(self, index: int):
@@ -107,14 +136,16 @@ class RecordButton(QPushButton):
def __init__(self, parent: Optional[QWidget]) -> None:
super().__init__(_("Record"), parent)
self.setDefault(True)
self.setSizePolicy(QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed))
self.setSizePolicy(
QSizePolicy(QSizePolicy.Policy.Fixed, QSizePolicy.Policy.Fixed)
)
def set_stopped(self):
self.setText(_('Record'))
self.setText(_("Record"))
self.setDefault(True)
def set_recording(self):
self.setText(_('Stop'))
self.setText(_("Stop"))
self.setDefault(False)
@@ -143,11 +174,11 @@ class AudioMeterWidget(QWidget):
self.AMPLITUDE_SCALE_FACTOR = 15 # scale the amplitudes such that 1/AMPLITUDE_SCALE_FACTOR will show all bars
if self.palette().window().color().black() > 127:
self.BAR_INACTIVE_COLOR = QColor('#555')
self.BAR_ACTIVE_COLOR = QColor('#999')
self.BAR_INACTIVE_COLOR = QColor("#555")
self.BAR_ACTIVE_COLOR = QColor("#999")
else:
self.BAR_INACTIVE_COLOR = QColor('#BBB')
self.BAR_ACTIVE_COLOR = QColor('#555')
self.BAR_INACTIVE_COLOR = QColor("#BBB")
self.BAR_ACTIVE_COLOR = QColor("#555")
def paintEvent(self, event: QtGui.QPaintEvent) -> None:
painter = QPainter(self)
@@ -157,26 +188,38 @@
center_x = rect.center().x()
num_bars_in_half = int((rect.width() / 2) / (self.BAR_MARGIN + self.BAR_WIDTH))
for i in range(num_bars_in_half):
is_bar_active = ((self.current_amplitude - self.MINIMUM_AMPLITUDE) * self.AMPLITUDE_SCALE_FACTOR) > (
i / num_bars_in_half)
painter.setBrush(self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR)
is_bar_active = (
(self.current_amplitude - self.MINIMUM_AMPLITUDE)
* self.AMPLITUDE_SCALE_FACTOR
) > (i / num_bars_in_half)
painter.setBrush(
self.BAR_ACTIVE_COLOR if is_bar_active else self.BAR_INACTIVE_COLOR
)
# draw to left
painter.drawRect(center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)), rect.top() + self.PADDING_TOP,
self.BAR_WIDTH,
rect.height() - self.PADDING_TOP)
painter.drawRect(
center_x - ((i + 1) * (self.BAR_MARGIN + self.BAR_WIDTH)),
rect.top() + self.PADDING_TOP,
self.BAR_WIDTH,
rect.height() - self.PADDING_TOP,
)
# draw to right
painter.drawRect(center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))),
rect.top() + self.PADDING_TOP,
self.BAR_WIDTH, rect.height() - self.PADDING_TOP)
painter.drawRect(
center_x + (self.BAR_MARGIN + (i * (self.BAR_MARGIN + self.BAR_WIDTH))),
rect.top() + self.PADDING_TOP,
self.BAR_WIDTH,
rect.height() - self.PADDING_TOP,
)
def update_amplitude(self, amplitude: float):
self.current_amplitude = max(amplitude, self.current_amplitude * self.SMOOTHING_FACTOR)
self.current_amplitude = max(
amplitude, self.current_amplitude * self.SMOOTHING_FACTOR
)
self.repaint()
class RecordingTranscriberWidget(QWidget):
current_status: 'RecordingStatus'
current_status: "RecordingStatus"
transcription_options: TranscriptionOptions
selected_device_id: Optional[int]
model_download_progress_dialog: Optional[ModelDownloadProgressDialog] = None
@@ -190,7 +233,9 @@ class RecordingTranscriberWidget(QWidget):
STOPPED = auto()
RECORDING = auto()
def __init__(self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None) -> None:
def __init__(
self, parent: Optional[QWidget] = None, flags: Optional[Qt.WindowType] = None
) -> None:
super().__init__(parent)
if flags is not None:
@@ -199,42 +244,63 @@
layout = QVBoxLayout(self)
self.current_status = self.RecordingStatus.STOPPED
self.setWindowTitle(_('Live Recording'))
self.setWindowTitle(_("Live Recording"))
self.settings = Settings()
default_language = self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, default_value='')
default_language = self.settings.value(
key=Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, default_value=""
)
self.transcription_options = TranscriptionOptions(
model=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_MODEL, default_value=TranscriptionModel(
model_type=ModelType.WHISPER_CPP if LOADED_WHISPER_DLL else ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY)),
task=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE),
language=default_language if default_language != '' else None,
initial_prompt=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
temperature=self.settings.value(key=Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE,
default_value=DEFAULT_WHISPER_TEMPERATURE), word_level_timings=False)
model=self.settings.value(
key=Settings.Key.RECORDING_TRANSCRIBER_MODEL,
default_value=TranscriptionModel(
model_type=ModelType.WHISPER_CPP
if LOADED_WHISPER_DLL
else ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY,
),
),
task=self.settings.value(
key=Settings.Key.RECORDING_TRANSCRIBER_TASK,
default_value=Task.TRANSCRIBE,
),
language=default_language if default_language != "" else None,
initial_prompt=self.settings.value(
key=Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT, default_value=""
),
temperature=self.settings.value(
key=Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE,
default_value=DEFAULT_WHISPER_TEMPERATURE,
),
word_level_timings=False,
)
self.audio_devices_combo_box = AudioDevicesComboBox(self)
self.audio_devices_combo_box.device_changed.connect(
self.on_device_changed)
self.audio_devices_combo_box.device_changed.connect(self.on_device_changed)
self.selected_device_id = self.audio_devices_combo_box.get_default_device_id()
self.record_button = RecordButton(self)
self.record_button.clicked.connect(self.on_record_button_clicked)
self.text_box = TextDisplayBox(self)
self.text_box.setPlaceholderText(_('Click Record to begin...'))
self.text_box.setPlaceholderText(_("Click Record to begin..."))
transcription_options_group_box = TranscriptionOptionsGroupBox(
default_transcription_options=self.transcription_options,
# Live transcription with OpenAI Whisper API not implemented
model_types=[model_type for model_type in ModelType if model_type is not ModelType.OPEN_AI_WHISPER_API],
parent=self)
model_types=[
model_type
for model_type in ModelType
if model_type is not ModelType.OPEN_AI_WHISPER_API
],
parent=self,
)
transcription_options_group_box.transcription_options_changed.connect(
self.on_transcription_options_changed)
self.on_transcription_options_changed
)
recording_options_layout = QFormLayout()
recording_options_layout.addRow(
_('Microphone:'), self.audio_devices_combo_box)
recording_options_layout.addRow(_("Microphone:"), self.audio_devices_combo_box)
self.audio_meter_widget = AudioMeterWidget(self)
@@ -252,7 +318,9 @@ class RecordingTranscriberWidget(QWidget):
self.reset_recording_amplitude_listener()
def on_transcription_options_changed(self, transcription_options: TranscriptionOptions):
def on_transcription_options_changed(
self, transcription_options: TranscriptionOptions
):
self.transcription_options = transcription_options
def on_device_changed(self, device_id: int):
@@ -269,11 +337,16 @@
# Get the device sample rate before starting the listener as the PortAudio function
# fails if you try to get the device's settings while recording is in progress.
self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(self.selected_device_id)
self.device_sample_rate = RecordingTranscriber.get_device_sample_rate(
self.selected_device_id
)
self.recording_amplitude_listener = RecordingAmplitudeListener(input_device_index=self.selected_device_id,
parent=self)
self.recording_amplitude_listener.amplitude_changed.connect(self.on_recording_amplitude_changed)
self.recording_amplitude_listener = RecordingAmplitudeListener(
input_device_index=self.selected_device_id, parent=self
)
self.recording_amplitude_listener.amplitude_changed.connect(
self.on_recording_amplitude_changed
)
self.recording_amplitude_listener.start_recording()
def on_record_button_clicked(self):
@@ -306,16 +379,19 @@
self.transcription_thread = QThread()
# TODO: make runnable
self.transcriber = RecordingTranscriber(input_device_index=self.selected_device_id,
sample_rate=self.device_sample_rate,
transcription_options=self.transcription_options,
model_path=model_path)
self.transcriber = RecordingTranscriber(
input_device_index=self.selected_device_id,
sample_rate=self.device_sample_rate,
transcription_options=self.transcription_options,
model_path=model_path,
)
self.transcriber.moveToThread(self.transcription_thread)
self.transcription_thread.started.connect(self.transcriber.start)
self.transcription_thread.finished.connect(
self.transcription_thread.deleteLater)
self.transcription_thread.deleteLater
)
self.transcriber.transcription.connect(self.on_next_transcription)
@@ -334,12 +410,16 @@
if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
model_type=self.transcription_options.model.model_type, parent=self
)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)
self.on_cancel_model_progress_dialog
)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_value(fraction_completed=current_size / total_size)
self.model_download_progress_dialog.set_value(
fraction_completed=current_size / total_size
)
def set_recording_status_stopped(self):
self.record_button.set_stopped()
@@ -357,7 +437,7 @@
if len(text) > 0:
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
if len(self.text_box.toPlainText()) > 0:
self.text_box.insertPlainText('\n\n')
self.text_box.insertPlainText("\n\n")
self.text_box.insertPlainText(text)
self.text_box.moveCursor(QTextCursor.MoveOperation.End)
@@ -374,9 +454,15 @@
self.reset_record_button()
self.set_recording_status_stopped()
QMessageBox.critical(
self, '',
_('An error occurred while starting a new recording:') + error + '. ' +
_('Please check your audio devices or check the application logs for more information.'))
self,
"",
_("An error occurred while starting a new recording:")
+ error
+ ". "
+ _(
"Please check your audio devices or check the application logs for more information."
),
)
def on_cancel_model_progress_dialog(self):
if self.model_loader is not None:
@@ -392,7 +478,7 @@
def reset_recording_controls(self):
# Clear text box placeholder because the first chunk takes a while to process
self.text_box.setPlaceholderText('')
self.text_box.setPlaceholderText("")
self.reset_record_button()
self.reset_model_download()
@@ -411,49 +497,72 @@
self.recording_amplitude_listener.stop_recording()
self.recording_amplitude_listener.deleteLater()
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE, self.transcription_options.language)
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_TASK, self.transcription_options.task)
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE, self.transcription_options.temperature)
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT,
self.transcription_options.initial_prompt)
self.settings.set_value(Settings.Key.RECORDING_TRANSCRIBER_MODEL, self.transcription_options.model)
self.settings.set_value(
Settings.Key.RECORDING_TRANSCRIBER_LANGUAGE,
self.transcription_options.language,
)
self.settings.set_value(
Settings.Key.RECORDING_TRANSCRIBER_TASK, self.transcription_options.task
)
self.settings.set_value(
Settings.Key.RECORDING_TRANSCRIBER_TEMPERATURE,
self.transcription_options.temperature,
)
self.settings.set_value(
Settings.Key.RECORDING_TRANSCRIBER_INITIAL_PROMPT,
self.transcription_options.initial_prompt,
)
self.settings.set_value(
Settings.Key.RECORDING_TRANSCRIBER_MODEL, self.transcription_options.model
)
return super().closeEvent(event)
RECORD_ICON_PATH = get_asset_path('assets/mic_FILL0_wght700_GRAD0_opsz48.svg')
EXPAND_ICON_PATH = get_asset_path('assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg')
ADD_ICON_PATH = get_asset_path('assets/add_FILL0_wght700_GRAD0_opsz48.svg')
TRASH_ICON_PATH = get_asset_path('assets/delete_FILL0_wght700_GRAD0_opsz48.svg')
CANCEL_ICON_PATH = get_asset_path('assets/cancel_FILL0_wght700_GRAD0_opsz48.svg')
RECORD_ICON_PATH = get_asset_path("assets/mic_FILL0_wght700_GRAD0_opsz48.svg")
EXPAND_ICON_PATH = get_asset_path("assets/open_in_full_FILL0_wght700_GRAD0_opsz48.svg")
ADD_ICON_PATH = get_asset_path("assets/add_FILL0_wght700_GRAD0_opsz48.svg")
TRASH_ICON_PATH = get_asset_path("assets/delete_FILL0_wght700_GRAD0_opsz48.svg")
CANCEL_ICON_PATH = get_asset_path("assets/cancel_FILL0_wght700_GRAD0_opsz48.svg")
class MainWindowToolbar(ToolBar):
new_transcription_action_triggered: pyqtSignal
open_transcript_action_triggered: pyqtSignal
clear_history_action_triggered: pyqtSignal
ICON_LIGHT_THEME_BACKGROUND = '#555'
ICON_DARK_THEME_BACKGROUND = '#AAA'
ICON_LIGHT_THEME_BACKGROUND = "#555"
ICON_DARK_THEME_BACKGROUND = "#AAA"
def __init__(self, shortcuts: Dict[str, str], parent: Optional[QWidget]):
super().__init__(parent)
self.record_action = Action(Icon(RECORD_ICON_PATH, self), _('Record'), self)
self.record_action = Action(Icon(RECORD_ICON_PATH, self), _("Record"), self)
self.record_action.triggered.connect(self.on_record_action_triggered)
self.new_transcription_action = Action(Icon(ADD_ICON_PATH, self), _('New Transcription'), self)
self.new_transcription_action_triggered = self.new_transcription_action.triggered
self.new_transcription_action = Action(
Icon(ADD_ICON_PATH, self), _("New Transcription"), self
)
self.new_transcription_action_triggered = (
self.new_transcription_action.triggered
)
self.open_transcript_action = Action(Icon(EXPAND_ICON_PATH, self),
_('Open Transcript'), self)
self.open_transcript_action = Action(
Icon(EXPAND_ICON_PATH, self), _("Open Transcript"), self
)
self.open_transcript_action_triggered = self.open_transcript_action.triggered
self.open_transcript_action.setDisabled(True)
self.stop_transcription_action = Action(Icon(CANCEL_ICON_PATH, self), _('Cancel Transcription'), self)
self.stop_transcription_action_triggered = self.stop_transcription_action.triggered
self.stop_transcription_action = Action(
Icon(CANCEL_ICON_PATH, self), _("Cancel Transcription"), self
)
self.stop_transcription_action_triggered = (
self.stop_transcription_action.triggered
)
self.stop_transcription_action.setDisabled(True)
self.clear_history_action = Action(Icon(TRASH_ICON_PATH, self), _('Clear History'), self)
self.clear_history_action = Action(
Icon(TRASH_ICON_PATH, self), _("Clear History"), self
)
self.clear_history_action_triggered = self.clear_history_action.triggered
self.clear_history_action.setDisabled(True)
@@ -461,21 +570,38 @@
self.addAction(self.record_action)
self.addSeparator()
self.addActions([self.new_transcription_action, self.open_transcript_action, self.stop_transcription_action,
self.clear_history_action])
self.addActions(
[
self.new_transcription_action,
self.open_transcript_action,
self.stop_transcription_action,
self.clear_history_action,
]
)
self.setMovable(False)
self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
def set_shortcuts(self, shortcuts: Dict[str, str]):
self.record_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_RECORD_WINDOW.name]))
self.new_transcription_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name]))
self.record_action.setShortcut(
QKeySequence.fromString(shortcuts[Shortcut.OPEN_RECORD_WINDOW.name])
)
self.new_transcription_action.setShortcut(
QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name])
)
self.open_transcript_action.setShortcut(
QKeySequence.fromString(shortcuts[Shortcut.OPEN_TRANSCRIPT_EDITOR.name]))
self.stop_transcription_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.STOP_TRANSCRIPTION.name]))
self.clear_history_action.setShortcut(QKeySequence.fromString(shortcuts[Shortcut.CLEAR_HISTORY.name]))
QKeySequence.fromString(shortcuts[Shortcut.OPEN_TRANSCRIPT_EDITOR.name])
)
self.stop_transcription_action.setShortcut(
QKeySequence.fromString(shortcuts[Shortcut.STOP_TRANSCRIPTION.name])
)
self.clear_history_action.setShortcut(
QKeySequence.fromString(shortcuts[Shortcut.CLEAR_HISTORY.name])
)
def on_record_action_triggered(self):
recording_transcriber_window = RecordingTranscriberWidget(self, flags=Qt.WindowType.Window)
recording_transcriber_window = RecordingTranscriberWidget(
self, flags=Qt.WindowType.Window
)
recording_transcriber_window.show()
def set_stop_transcription_action_enabled(self, enabled: bool):
@@ -490,7 +616,7 @@
class MainWindow(QMainWindow):
table_widget: TranscriptionTasksTableWidget
tasks: Dict[int, 'FileTranscriptionTask']
tasks: Dict[int, "FileTranscriptionTask"]
tasks_changed = pyqtSignal()
openai_access_token: Optional[str]
@@ -511,34 +637,49 @@
self.shortcuts = self.shortcut_settings.load()
self.default_export_file_name = self.settings.value(
Settings.Key.DEFAULT_EXPORT_FILE_NAME,
'{{ input_file_name }} ({{ task }}d on {{ date_time }})')
"{{ input_file_name }} ({{ task }}d on {{ date_time }})",
)
self.tasks = {}
self.tasks_changed.connect(self.on_tasks_changed)
self.toolbar = MainWindowToolbar(shortcuts=self.shortcuts, parent=self)
self.toolbar.new_transcription_action_triggered.connect(self.on_new_transcription_action_triggered)
self.toolbar.open_transcript_action_triggered.connect(self.open_transcript_viewer)
self.toolbar.clear_history_action_triggered.connect(self.on_clear_history_action_triggered)
self.toolbar.stop_transcription_action_triggered.connect(self.on_stop_transcription_action_triggered)
self.toolbar.new_transcription_action_triggered.connect(
self.on_new_transcription_action_triggered
)
self.toolbar.open_transcript_action_triggered.connect(
self.open_transcript_viewer
)
self.toolbar.clear_history_action_triggered.connect(
self.on_clear_history_action_triggered
)
self.toolbar.stop_transcription_action_triggered.connect(
self.on_stop_transcription_action_triggered
)
self.addToolBar(self.toolbar)
self.setUnifiedTitleAndToolBarOnMac(True)
self.menu_bar = MenuBar(shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
parent=self)
self.menu_bar = MenuBar(
shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
parent=self,
)
self.menu_bar.import_action_triggered.connect(
self.on_new_transcription_action_triggered)
self.on_new_transcription_action_triggered
)
self.menu_bar.shortcuts_changed.connect(self.on_shortcuts_changed)
self.menu_bar.openai_api_key_changed.connect(self.on_openai_access_token_changed)
self.menu_bar.default_export_file_name_changed.connect(self.default_export_file_name_changed)
self.menu_bar.openai_api_key_changed.connect(
self.on_openai_access_token_changed
)
self.menu_bar.default_export_file_name_changed.connect(
self.default_export_file_name_changed
)
self.setMenuBar(self.menu_bar)
self.table_widget = TranscriptionTasksTableWidget(self)
self.table_widget.doubleClicked.connect(self.on_table_double_clicked)
self.table_widget.return_clicked.connect(self.open_transcript_viewer)
self.table_widget.itemSelectionChanged.connect(
self.on_table_selection_changed)
self.table_widget.itemSelectionChanged.connect(self.on_table_selection_changed)
self.setCentralWidget(self.table_widget)
@@ -548,8 +689,7 @@
self.transcriber_worker = FileTranscriberQueueWorker()
self.transcriber_worker.moveToThread(self.transcriber_thread)
self.transcriber_worker.task_updated.connect(
self.update_task_table_row)
self.transcriber_worker.task_updated.connect(self.update_task_table_row)
self.transcriber_worker.completed.connect(self.transcriber_thread.quit)
self.transcriber_thread.started.connect(self.transcriber_worker.run)
@@ -569,11 +709,14 @@
file_paths = [url.toLocalFile() for url in event.mimeData().urls()]
self.open_file_transcriber_widget(file_paths=file_paths)
def on_file_transcriber_triggered(self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str]):
def on_file_transcriber_triggered(
self, options: Tuple[TranscriptionOptions, FileTranscriptionOptions, str]
):
transcription_options, file_transcription_options, model_path = options
for file_path in file_transcription_options.file_paths:
task = FileTranscriptionTask(
file_path, transcription_options, file_transcription_options, model_path)
file_path, transcription_options, file_transcription_options, model_path
)
self.add_task(task)
def load_task(self, task: FileTranscriptionTask):
@@ -586,8 +729,10 @@
@staticmethod
def task_completed_or_errored(task: FileTranscriptionTask):
return task.status == FileTranscriptionTask.Status.COMPLETED or \
task.status == FileTranscriptionTask.Status.FAILED
return (
task.status == FileTranscriptionTask.Status.COMPLETED
or task.status == FileTranscriptionTask.Status.FAILED
)
def on_clear_history_action_triggered(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
@@ -595,10 +740,17 @@
return
reply = QMessageBox.question(
self, _('Clear History'),
_('Are you sure you want to delete the selected transcription(s)? This action cannot be undone.'))
self,
_("Clear History"),
_(
"Are you sure you want to delete the selected transcription(s)? This action cannot be undone."
),
)
if reply == QMessageBox.StandardButton.Yes:
task_ids = [TranscriptionTasksTableWidget.find_task_id(selected_row) for selected_row in selected_rows]
task_ids = [
TranscriptionTasksTableWidget.find_task_id(selected_row)
for selected_row in selected_rows
]
for task_id in task_ids:
self.table_widget.clear_task(task_id)
self.tasks.pop(task_id)
@@ -617,20 +769,24 @@
def on_new_transcription_action_triggered(self):
(file_paths, __) = QFileDialog.getOpenFileNames(
self, _('Select audio file'), '', SUPPORTED_OUTPUT_FORMATS)
self, _("Select audio file"), "", SUPPORTED_OUTPUT_FORMATS
)
if len(file_paths) == 0:
return
self.open_file_transcriber_widget(file_paths)
def open_file_transcriber_widget(self, file_paths: List[str]):
file_transcriber_window = FileTranscriberWidget(file_paths=file_paths,
default_output_file_name=self.default_export_file_name,
parent=self,
flags=Qt.WindowType.Window)
file_transcriber_window.triggered.connect(
self.on_file_transcriber_triggered)
file_transcriber_window.openai_access_token_changed.connect(self.on_openai_access_token_changed)
file_transcriber_window = FileTranscriberWidget(
file_paths=file_paths,
default_output_file_name=self.default_export_file_name,
parent=self,
flags=Qt.WindowType.Window,
)
file_transcriber_window.triggered.connect(self.on_file_transcriber_triggered)
file_transcriber_window.openai_access_token_changed.connect(
self.on_openai_access_token_changed
)
file_transcriber_window.show()
@staticmethod
@@ -639,7 +795,9 @@
def default_export_file_name_changed(self, default_export_file_name: str):
self.default_export_file_name = default_export_file_name
self.settings.set_value(Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name)
self.settings.set_value(
Settings.Key.DEFAULT_EXPORT_FILE_NAME, default_export_file_name
)
def open_transcript_viewer(self):
selected_rows = self.table_widget.selectionModel().selectedRows()
@@ -648,29 +806,49 @@
self.open_transcription_viewer(task_id)
def on_table_selection_changed(self):
self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action())
self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action())
self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action())
self.toolbar.set_open_transcript_action_enabled(
self.should_enable_open_transcript_action()
)
self.toolbar.set_stop_transcription_action_enabled(
self.should_enable_stop_transcription_action()
)
self.toolbar.set_clear_history_action_enabled(
self.should_enable_clear_history_action()
)
def should_enable_open_transcript_action(self):
return self.selected_tasks_have_status([FileTranscriptionTask.Status.COMPLETED])
def should_enable_stop_transcription_action(self):
return self.selected_tasks_have_status(
[FileTranscriptionTask.Status.IN_PROGRESS, FileTranscriptionTask.Status.QUEUED])
[
FileTranscriptionTask.Status.IN_PROGRESS,
FileTranscriptionTask.Status.QUEUED,
]
)
def should_enable_clear_history_action(self):
return self.selected_tasks_have_status(
[FileTranscriptionTask.Status.COMPLETED, FileTranscriptionTask.Status.FAILED,
FileTranscriptionTask.Status.CANCELED])
[
FileTranscriptionTask.Status.COMPLETED,
FileTranscriptionTask.Status.FAILED,
FileTranscriptionTask.Status.CANCELED,
]
)
def selected_tasks_have_status(self, statuses: List[FileTranscriptionTask.Status]):
selected_rows = self.table_widget.selectionModel().selectedRows()
if len(selected_rows) == 0:
return False
return all(
[self.tasks[TranscriptionTasksTableWidget.find_task_id(selected_row)].status in statuses for selected_row in
selected_rows])
[
self.tasks[
TranscriptionTasksTableWidget.find_task_id(selected_row)
].status
in statuses
for selected_row in selected_rows
]
)
def on_table_double_clicked(self, index: QModelIndex):
task_id = TranscriptionTasksTableWidget.find_task_id(index)
@@ -682,7 +860,8 @@
return
transcription_viewer_widget = TranscriptionViewerWidget(
transcription_task=task, parent=self, flags=Qt.WindowType.Window)
transcription_task=task, parent=self, flags=Qt.WindowType.Window
)
transcription_viewer_widget.task_changed.connect(self.on_tasks_changed)
transcription_viewer_widget.show()
@@ -692,8 +871,10 @@
def load_tasks_from_cache(self):
tasks = self.tasks_cache.load()
for task in tasks:
if task.status == FileTranscriptionTask.Status.QUEUED or \
task.status == FileTranscriptionTask.Status.IN_PROGRESS:
if (
task.status == FileTranscriptionTask.Status.QUEUED
or task.status == FileTranscriptionTask.Status.IN_PROGRESS
):
task.status = None
self.transcriber_worker.add_task(task)
else:
@@ -703,9 +884,15 @@
self.tasks_cache.save(list(self.tasks.values()))
def on_tasks_changed(self):
self.toolbar.set_open_transcript_action_enabled(self.should_enable_open_transcript_action())
self.toolbar.set_stop_transcription_action_enabled(self.should_enable_stop_transcription_action())
self.toolbar.set_clear_history_action_enabled(self.should_enable_clear_history_action())
self.toolbar.set_open_transcript_action_enabled(
self.should_enable_open_transcript_action()
)
self.toolbar.set_stop_transcription_action_enabled(
self.should_enable_stop_transcription_action()
)
self.toolbar.set_clear_history_action_enabled(
self.should_enable_clear_history_action()
)
self.save_tasks_to_cache()
def on_shortcuts_changed(self, shortcuts: dict):
@@ -723,7 +910,6 @@
super().closeEvent(event)
class Application(QApplication):
window: MainWindow


@@ -6,12 +6,12 @@ from PyQt6.QtCore import QLocale
from buzz.assets import get_asset_path
from buzz.settings.settings import APP_NAME
if 'LANG' not in os.environ:
if "LANG" not in os.environ:
language = str(QLocale().uiLanguages()[0]).replace("-", "_")
os.environ['LANG'] = language
os.environ["LANG"] = language
locale_dir = get_asset_path('locale')
gettext.bindtextdomain('buzz', locale_dir)
locale_dir = get_asset_path("locale")
gettext.bindtextdomain("buzz", locale_dir)
translate = gettext.translation(APP_NAME, locale_dir, fallback=True)


@@ -20,11 +20,11 @@ from tqdm.auto import tqdm
class WhisperModelSize(str, enum.Enum):
TINY = 'tiny'
BASE = 'base'
SMALL = 'small'
MEDIUM = 'medium'
LARGE = 'large'
TINY = "tiny"
BASE = "base"
SMALL = "small"
MEDIUM = "medium"
LARGE = "large"
def to_faster_whisper_model_size(self) -> str:
if self == WhisperModelSize.LARGE:
@@ -33,11 +33,11 @@
class ModelType(enum.Enum):
WHISPER = 'Whisper'
WHISPER_CPP = 'Whisper.cpp'
HUGGING_FACE = 'Hugging Face'
FASTER_WHISPER = 'Faster Whisper'
OPEN_AI_WHISPER_API = 'OpenAI Whisper API'
WHISPER = "Whisper"
WHISPER_CPP = "Whisper.cpp"
HUGGING_FACE = "Hugging Face"
FASTER_WHISPER = "Faster Whisper"
OPEN_AI_WHISPER_API = "OpenAI Whisper API"
@dataclass()
@@ -47,9 +47,10 @@ class TranscriptionModel:
hugging_face_model_id: Optional[str] = None
def is_deletable(self):
return ((self.model_type == ModelType.WHISPER or
self.model_type == ModelType.WHISPER_CPP) and
self.get_local_model_path() is not None)
return (
self.model_type == ModelType.WHISPER
or self.model_type == ModelType.WHISPER_CPP
) and self.get_local_model_path() is not None
def open_file_location(self):
model_path = self.get_local_model_path()
@@ -84,18 +85,20 @@
if self.model_type == ModelType.FASTER_WHISPER:
try:
return download_faster_whisper_model(size=self.whisper_model_size.value,
local_files_only=True)
return download_faster_whisper_model(
size=self.whisper_model_size.value, local_files_only=True
)
except (ValueError, FileNotFoundError):
return None
if self.model_type == ModelType.OPEN_AI_WHISPER_API:
return ''
return ""
if self.model_type == ModelType.HUGGING_FACE:
try:
return huggingface_hub.snapshot_download(self.hugging_face_model_id,
local_files_only=True)
return huggingface_hub.snapshot_download(
self.hugging_face_model_id, local_files_only=True
)
except (ValueError, FileNotFoundError):
return None
@@ -103,36 +106,38 @@
WHISPER_CPP_MODELS_SHA256 = {
'tiny': 'be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21',
'base': '60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe',
'small': '1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b',
'medium': '6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208',
'large': '9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487'
"tiny": "be07e048e1e599ad46341c8d2a135645097a538221678b7acdd1b1919c6e1b21",
"base": "60ed5bc3dd14eea856493d334349b405782ddcaf0028d4b5df4088345fba2efe",
"small": "1be3a9b2063867b937e64e2ec7483364a79917e157fa98c5d94b5c1fffea987b",
"medium": "6c14d5adee5f86394037b4e4e8b59f1673b6cee10e3cf0b11bbdbee79c156208",
"large": "9a423fe4d40c82774b6af34115b8b935f34152246eb19e80e376071d3f999487",
}
def get_hugging_face_file_url(author: str, repository_name: str, filename: str):
return f'https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}'
return f"https://huggingface.co/{author}/{repository_name}/resolve/main/{filename}"
def get_whisper_cpp_file_path(size: WhisperModelSize) -> str:
root_dir = user_cache_dir('Buzz')
return os.path.join(root_dir, f'ggml-model-whisper-{size.value}.bin')
root_dir = user_cache_dir("Buzz")
return os.path.join(root_dir, f"ggml-model-whisper-{size.value}.bin")
def get_whisper_file_path(size: WhisperModelSize) -> str:
root_dir = os.getenv("XDG_CACHE_HOME", os.path.join(
os.path.expanduser("~"), ".cache", "whisper"))
root_dir = os.getenv(
"XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache", "whisper")
)
url = whisper._MODELS[size.value]
return os.path.join(root_dir, os.path.basename(url))
def download_faster_whisper_model(size: str, local_files_only=False,
tqdm_class: Optional[tqdm] = None):
def download_faster_whisper_model(
size: str, local_files_only=False, tqdm_class: Optional[tqdm] = None
):
if size not in faster_whisper.utils._MODELS:
raise ValueError(
"Invalid model size '%s', expected one of: %s" % (
size, ", ".join(faster_whisper.utils._MODELS))
"Invalid model size '%s', expected one of: %s"
% (size, ", ".join(faster_whisper.utils._MODELS))
)
repo_id = "guillaumekln/faster-whisper-%s" % size
@@ -144,9 +149,12 @@ def download_faster_whisper_model(size: str, local_files_only=False,
"vocabulary.txt",
]
return huggingface_hub.snapshot_download(repo_id, allow_patterns=allow_patterns,
local_files_only=local_files_only,
tqdm_class=tqdm_class)
return huggingface_hub.snapshot_download(
repo_id,
allow_patterns=allow_patterns,
local_files_only=local_files_only,
tqdm_class=tqdm_class,
)
class ModelDownloader(QRunnable):
@@ -165,22 +173,24 @@
def run(self) -> None:
if self.model.model_type == ModelType.WHISPER_CPP:
model_name = self.model.whisper_model_size.value
url = get_hugging_face_file_url(author='ggerganov',
repository_name='whisper.cpp',
filename=f'ggml-{model_name}.bin')
file_path = get_whisper_cpp_file_path(
size=self.model.whisper_model_size)
url = get_hugging_face_file_url(
author="ggerganov",
repository_name="whisper.cpp",
filename=f"ggml-{model_name}.bin",
)
file_path = get_whisper_cpp_file_path(size=self.model.whisper_model_size)
expected_sha256 = WHISPER_CPP_MODELS_SHA256[model_name]
return self.download_model_to_path(url=url, file_path=file_path,
expected_sha256=expected_sha256)
return self.download_model_to_path(
url=url, file_path=file_path, expected_sha256=expected_sha256
)
if self.model.model_type == ModelType.WHISPER:
url = whisper._MODELS[self.model.whisper_model_size.value]
file_path = get_whisper_file_path(
size=self.model.whisper_model_size)
expected_sha256 = url.split('/')[-2]
return self.download_model_to_path(url=url, file_path=file_path,
expected_sha256=expected_sha256)
file_path = get_whisper_file_path(size=self.model.whisper_model_size)
expected_sha256 = url.split("/")[-2]
return self.download_model_to_path(
url=url, file_path=file_path, expected_sha256=expected_sha256
)
progress = self.signals.progress
@@ -197,44 +207,47 @@
if self.model.model_type == ModelType.FASTER_WHISPER:
model_path = download_faster_whisper_model(
size=self.model.whisper_model_size.to_faster_whisper_model_size(),
tqdm_class=_tqdm)
tqdm_class=_tqdm,
)
self.signals.finished.emit(model_path)
return
if self.model.model_type == ModelType.HUGGING_FACE:
model_path = huggingface_hub.snapshot_download(
self.model.hugging_face_model_id, tqdm_class=_tqdm)
self.model.hugging_face_model_id, tqdm_class=_tqdm
)
self.signals.finished.emit(model_path)
return
if self.model.model_type == ModelType.OPEN_AI_WHISPER_API:
self.signals.finished.emit('')
self.signals.finished.emit("")
return
raise Exception("Invalid model type: " + self.model.model_type.value)
def download_model_to_path(self, url: str, file_path: str,
expected_sha256: Optional[str]):
def download_model_to_path(
self, url: str, file_path: str, expected_sha256: Optional[str]
):
try:
downloaded = self.download_model(url, file_path, expected_sha256)
if downloaded:
self.signals.finished.emit(file_path)
except requests.RequestException:
self.signals.error.emit('A connection error occurred')
logging.exception('')
self.signals.error.emit("A connection error occurred")
logging.exception("")
except Exception as exc:
self.signals.error.emit(str(exc))
logging.exception(exc)
def download_model(self, url: str, file_path: str,
expected_sha256: Optional[str]) -> bool:
logging.debug(f'Downloading model from {url} to {file_path}')
def download_model(
self, url: str, file_path: str, expected_sha256: Optional[str]
) -> bool:
logging.debug(f"Downloading model from {url} to {file_path}")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
if os.path.exists(file_path) and not os.path.isfile(file_path):
raise RuntimeError(
f"{file_path} exists and is not a regular file")
raise RuntimeError(f"{file_path} exists and is not a regular file")
if os.path.isfile(file_path):
if expected_sha256 is None:
@@ -246,17 +259,19 @@ class ModelDownloader(QRunnable):
return True
else:
warnings.warn(
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file")
f"{file_path} exists, but the SHA256 checksum does not match; re-downloading the file"
)
tmp_file = tempfile.mktemp()
logging.debug('Downloading to temporary file = %s', tmp_file)
logging.debug("Downloading to temporary file = %s", tmp_file)
# Downloads the model using the requests module instead of urllib to
# use the certs from certifi when the app is running in frozen mode
with requests.get(url, stream=True, timeout=15) as source, open(tmp_file,
'wb') as output:
with requests.get(url, stream=True, timeout=15) as source, open(
tmp_file, "wb"
) as output:
source.raise_for_status()
total_size = float(source.headers.get('Content-Length', 0))
total_size = float(source.headers.get("Content-Length", 0))
current = 0.0
self.signals.progress.emit((current, total_size))
for chunk in source.iter_content(chunk_size=8192):
@@ -271,13 +286,14 @@ class ModelDownloader(QRunnable):
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not match. Please retry loading the "
"model.")
"model."
)
logging.debug('Downloaded model')
logging.debug("Downloaded model")
# https://github.com/chidiwilliams/buzz/issues/454
shutil.move(tmp_file, file_path)
logging.debug('Moved file from %s to %s', tmp_file, file_path)
logging.debug("Moved file from %s to %s", tmp_file, file_path)
return True
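In isolation, the streaming-download-with-checksum pattern used by download_model looks roughly like this (a minimal sketch; the function name is assumed):

import hashlib
import requests

def fetch_with_checksum(url: str, expected_sha256: str) -> bytes:
    sha = hashlib.sha256()
    chunks = []
    # Stream in 8 KiB chunks rather than loading the whole body at once.
    with requests.get(url, stream=True, timeout=15) as source:
        source.raise_for_status()
        for chunk in source.iter_content(chunk_size=8192):
            sha.update(chunk)
            chunks.append(chunk)
    if sha.hexdigest() != expected_sha256:
        raise RuntimeError("SHA256 checksum does not match")
    return b"".join(chunks)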
def cancel(self):

View file

@@ -7,4 +7,4 @@ def file_path_as_title(file_path: str):
def file_paths_as_title(file_paths: List[str]):
return ', '.join([file_path_as_title(path) for path in file_paths])
return ", ".join([file_path_as_title(path) for path in file_paths])

View file

@@ -10,19 +10,25 @@ class RecordingAmplitudeListener(QObject):
stream: Optional[sounddevice.InputStream] = None
amplitude_changed = pyqtSignal(float)
def __init__(self, input_device_index: Optional[int] = None,
parent: Optional[QObject] = None,
):
def __init__(
self,
input_device_index: Optional[int] = None,
parent: Optional[QObject] = None,
):
super().__init__(parent)
self.input_device_index = input_device_index
def start_recording(self):
try:
self.stream = sounddevice.InputStream(device=self.input_device_index, dtype='float32',
channels=1, callback=self.stream_callback)
self.stream = sounddevice.InputStream(
device=self.input_device_index,
dtype="float32",
channels=1,
callback=self.stream_callback,
)
self.stream.start()
except sounddevice.PortAudioError:
logging.exception('')
logging.exception("")
def stop_recording(self):
if self.stream is not None:
@@ -31,5 +37,5 @@ class RecordingAmplitudeListener(QObject):
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):
chunk = in_data.ravel()
amplitude = np.sqrt(np.mean(chunk ** 2)) # root-mean-square
amplitude = np.sqrt(np.mean(chunk**2)) # root-mean-square
self.amplitude_changed.emit(amplitude)
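A quick worked example of the RMS computation above:

import numpy as np

chunk = np.array([0.5, -0.5, 0.5, -0.5], dtype=np.float32)
np.sqrt(np.mean(chunk**2))  # sqrt(mean([0.25] * 4)) -> 0.5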

View file

@@ -22,9 +22,14 @@ class RecordingTranscriber(QObject):
is_running = False
MAX_QUEUE_SIZE = 10
def __init__(self, transcription_options: TranscriptionOptions,
input_device_index: Optional[int], sample_rate: int, model_path: str,
parent: Optional[QObject] = None) -> None:
def __init__(
self,
transcription_options: TranscriptionOptions,
input_device_index: Optional[int],
sample_rate: int,
model_path: str,
parent: Optional[QObject] = None,
) -> None:
super().__init__(parent)
self.transcription_options = transcription_options
self.current_stream = None
@@ -49,60 +54,91 @@ class RecordingTranscriber(QObject):
initial_prompt = self.transcription_options.initial_prompt
logging.debug('Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s',
self.transcription_options, model_path, self.sample_rate, self.input_device_index)
logging.debug(
"Recording, transcription options = %s, model path = %s, sample rate = %s, device = %s",
self.transcription_options,
model_path,
self.sample_rate,
self.input_device_index,
)
self.is_running = True
try:
with sounddevice.InputStream(samplerate=self.sample_rate,
device=self.input_device_index, dtype="float32",
channels=1, callback=self.stream_callback):
with sounddevice.InputStream(
samplerate=self.sample_rate,
device=self.input_device_index,
dtype="float32",
channels=1,
callback=self.stream_callback,
):
while self.is_running:
self.mutex.acquire()
if self.queue.size >= self.n_batch_samples:
samples = self.queue[:self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples:]
samples = self.queue[: self.n_batch_samples]
self.queue = self.queue[self.n_batch_samples :]
self.mutex.release()
logging.debug('Processing next frame, sample size = %s, queue size = %s, amplitude = %s',
samples.size, self.queue.size, self.amplitude(samples))
logging.debug(
"Processing next frame, sample size = %s, queue size = %s, amplitude = %s",
samples.size,
self.queue.size,
self.amplitude(samples),
)
time_started = datetime.datetime.now()
if self.transcription_options.model.model_type == ModelType.WHISPER:
if (
self.transcription_options.model.model_type
== ModelType.WHISPER
):
assert isinstance(model, whisper.Whisper)
result = model.transcribe(
audio=samples, language=self.transcription_options.language,
audio=samples,
language=self.transcription_options.language,
task=self.transcription_options.task.value,
initial_prompt=initial_prompt,
temperature=self.transcription_options.temperature)
elif self.transcription_options.model.model_type == ModelType.WHISPER_CPP:
temperature=self.transcription_options.temperature,
)
elif (
self.transcription_options.model.model_type
== ModelType.WHISPER_CPP
):
assert isinstance(model, WhisperCpp)
result = model.transcribe(
audio=samples,
params=whisper_cpp_params(
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value, word_level_timings=False))
if self.transcription_options.language is not None
else "en",
task=self.transcription_options.task.value,
word_level_timings=False,
),
)
else:
assert isinstance(model, TransformersWhisper)
result = model.transcribe(audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None else 'en',
task=self.transcription_options.task.value)
result = model.transcribe(
audio=samples,
language=self.transcription_options.language
if self.transcription_options.language is not None
else "en",
task=self.transcription_options.task.value,
)
next_text: str = result.get('text')
next_text: str = result.get("text")
# Update initial prompt between successive recording chunks
initial_prompt += next_text
logging.debug('Received next result, length = %s, time taken = %s',
len(next_text), datetime.datetime.now() - time_started)
logging.debug(
"Received next result, length = %s, time taken = %s",
len(next_text),
datetime.datetime.now() - time_started,
)
self.transcription.emit(next_text)
else:
self.mutex.release()
except PortAudioError as exc:
self.error.emit(str(exc))
logging.exception('')
logging.exception("")
return
self.finished.emit()
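The loop above drains the microphone buffer in fixed-size batches; the slicing pattern on its own (a sketch with a stand-in array):

import numpy as np

queue = np.arange(10, dtype=np.float32)  # stand-in for buffered samples
n_batch_samples = 4
while queue.size >= n_batch_samples:
    samples = queue[:n_batch_samples]  # take the oldest batch
    queue = queue[n_batch_samples:]    # keep the remainder
    # ... transcribe `samples` ...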
@@ -116,12 +152,13 @@ class RecordingTranscriber(QObject):
whisper_sample_rate = whisper.audio.SAMPLE_RATE
try:
sounddevice.check_input_settings(
device=device_id, samplerate=whisper_sample_rate)
device=device_id, samplerate=whisper_sample_rate
)
return whisper_sample_rate
except PortAudioError:
device_info = sounddevice.query_devices(device=device_id)
if isinstance(device_info, dict):
return int(device_info.get('default_samplerate', whisper_sample_rate))
return int(device_info.get("default_samplerate", whisper_sample_rate))
return whisper_sample_rate
def stream_callback(self, in_data: np.ndarray, frame_count, time_info, status):

View file

@@ -3,7 +3,7 @@ import typing
from PyQt6.QtCore import QSettings
APP_NAME = 'Buzz'
APP_NAME = "Buzz"
class Settings:
@@ -11,32 +11,38 @@ class Settings:
self.settings = QSettings(APP_NAME)
class Key(enum.Enum):
RECORDING_TRANSCRIBER_TASK = 'recording-transcriber/task'
RECORDING_TRANSCRIBER_MODEL = 'recording-transcriber/model'
RECORDING_TRANSCRIBER_LANGUAGE = 'recording-transcriber/language'
RECORDING_TRANSCRIBER_TEMPERATURE = 'recording-transcriber/temperature'
RECORDING_TRANSCRIBER_INITIAL_PROMPT = 'recording-transcriber/initial-prompt'
RECORDING_TRANSCRIBER_TASK = "recording-transcriber/task"
RECORDING_TRANSCRIBER_MODEL = "recording-transcriber/model"
RECORDING_TRANSCRIBER_LANGUAGE = "recording-transcriber/language"
RECORDING_TRANSCRIBER_TEMPERATURE = "recording-transcriber/temperature"
RECORDING_TRANSCRIBER_INITIAL_PROMPT = "recording-transcriber/initial-prompt"
FILE_TRANSCRIBER_TASK = 'file-transcriber/task'
FILE_TRANSCRIBER_MODEL = 'file-transcriber/model'
FILE_TRANSCRIBER_LANGUAGE = 'file-transcriber/language'
FILE_TRANSCRIBER_TEMPERATURE = 'file-transcriber/temperature'
FILE_TRANSCRIBER_INITIAL_PROMPT = 'file-transcriber/initial-prompt'
FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = 'file-transcriber/word-level-timings'
FILE_TRANSCRIBER_EXPORT_FORMATS = 'file-transcriber/export-formats'
FILE_TRANSCRIBER_TASK = "file-transcriber/task"
FILE_TRANSCRIBER_MODEL = "file-transcriber/model"
FILE_TRANSCRIBER_LANGUAGE = "file-transcriber/language"
FILE_TRANSCRIBER_TEMPERATURE = "file-transcriber/temperature"
FILE_TRANSCRIBER_INITIAL_PROMPT = "file-transcriber/initial-prompt"
FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS = "file-transcriber/word-level-timings"
FILE_TRANSCRIBER_EXPORT_FORMATS = "file-transcriber/export-formats"
DEFAULT_EXPORT_FILE_NAME = 'transcriber/default-export-file-name'
DEFAULT_EXPORT_FILE_NAME = "transcriber/default-export-file-name"
SHORTCUTS = 'shortcuts'
SHORTCUTS = "shortcuts"
def set_value(self, key: Key, value: typing.Any) -> None:
self.settings.setValue(key.value, value)
def value(self, key: Key, default_value: typing.Any,
value_type: typing.Optional[type] = None) -> typing.Any:
return self.settings.value(key.value, default_value,
value_type if value_type is not None else type(
default_value))
def value(
self,
key: Key,
default_value: typing.Any,
value_type: typing.Optional[type] = None,
) -> typing.Any:
return self.settings.value(
key.value,
default_value,
value_type if value_type is not None else type(default_value),
)
def clear(self):
self.settings.clear()
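QSettings stores values untyped, which is why `value()` above passes `type(default_value)` as the coercion hint. A minimal usage sketch:

from PyQt6.QtCore import QSettings

settings = QSettings("Buzz")
settings.setValue("file-transcriber/word-level-timings", True)
# Some storage backends hand back strings unless a type hint is given;
# with `bool` we get a proper boolean.
settings.value("file-transcriber/word-level-timings", False, bool)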

View file

@@ -13,13 +13,13 @@ class Shortcut(str, enum.Enum):
obj.description = description
return obj
OPEN_RECORD_WINDOW = ('Ctrl+R', "Open Record Window")
OPEN_IMPORT_WINDOW = ('Ctrl+O', "Import File")
OPEN_PREFERENCES_WINDOW = ('Ctrl+,', 'Open Preferences Window')
OPEN_RECORD_WINDOW = ("Ctrl+R", "Open Record Window")
OPEN_IMPORT_WINDOW = ("Ctrl+O", "Import File")
OPEN_PREFERENCES_WINDOW = ("Ctrl+,", "Open Preferences Window")
OPEN_TRANSCRIPT_EDITOR = ('Ctrl+E', "Open Transcript Viewer")
CLEAR_HISTORY = ('Ctrl+S', "Clear History")
STOP_TRANSCRIPTION = ('Ctrl+X', "Cancel Transcription")
OPEN_TRANSCRIPT_EDITOR = ("Ctrl+E", "Open Transcript Viewer")
CLEAR_HISTORY = ("Ctrl+S", "Clear History")
STOP_TRANSCRIPTION = ("Ctrl+X", "Cancel Transcription")
@staticmethod
def get_default_shortcuts() -> typing.Dict[str, str]:

View file

@@ -10,7 +10,9 @@ class ShortcutSettings:
def load(self) -> typing.Dict[str, str]:
shortcuts = Shortcut.get_default_shortcuts()
custom_shortcuts: typing.Dict[str, str] = self.settings.value(Settings.Key.SHORTCUTS, {})
custom_shortcuts: typing.Dict[str, str] = self.settings.value(
Settings.Key.SHORTCUTS, {}
)
for shortcut_name in custom_shortcuts:
shortcuts[shortcut_name] = custom_shortcuts[shortcut_name]
return shortcuts
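The override loop is equivalent to a plain dict merge, where custom entries win:

defaults = {"OPEN_RECORD_WINDOW": "Ctrl+R"}
custom = {"OPEN_RECORD_WINDOW": "Ctrl+Shift+R"}
shortcuts = {**defaults, **custom}  # -> {"OPEN_RECORD_WINDOW": "Ctrl+Shift+R"}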

View file

@@ -9,20 +9,20 @@ from buzz.settings.settings import APP_NAME
class KeyringStore:
class Key(enum.Enum):
OPENAI_API_KEY = 'OpenAI API key'
OPENAI_API_KEY = "OpenAI API key"
def get_password(self, key: Key) -> str:
try:
password = keyring.get_password(APP_NAME, username=key.value)
if password is None:
return ''
return ""
return password
except (KeyringLocked, KeyringError) as exc:
logging.error('Unable to read from keyring: %s', exc)
return ''
logging.error("Unable to read from keyring: %s", exc)
return ""
def set_password(self, username: Key, password: str) -> None:
try:
keyring.set_password(APP_NAME, username.value, password)
except (KeyringLocked, PasswordSetError) as exc:
logging.error('Unable to write to keyring: %s', exc)
logging.error("Unable to write to keyring: %s", exc)

View file

@@ -38,7 +38,7 @@ try:
LOADED_WHISPER_DLL = True
except ImportError:
logging.exception('')
logging.exception("")
DEFAULT_WHISPER_TEMPERATURE = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
@@ -65,27 +65,28 @@ class TranscriptionOptions:
model: TranscriptionModel = field(default_factory=TranscriptionModel)
word_level_timings: bool = False
temperature: Tuple[float, ...] = DEFAULT_WHISPER_TEMPERATURE
initial_prompt: str = ''
openai_access_token: str = field(default='',
metadata=config(exclude=Exclude.ALWAYS))
initial_prompt: str = ""
openai_access_token: str = field(
default="", metadata=config(exclude=Exclude.ALWAYS)
)
@dataclass()
class FileTranscriptionOptions:
file_paths: List[str]
output_formats: Set['OutputFormat'] = field(default_factory=set)
default_output_file_name: str = ''
output_formats: Set["OutputFormat"] = field(default_factory=set)
default_output_file_name: str = ""
@dataclass_json
@dataclass
class FileTranscriptionTask:
class Status(enum.Enum):
QUEUED = 'queued'
IN_PROGRESS = 'in_progress'
COMPLETED = 'completed'
FAILED = 'failed'
CANCELED = 'canceled'
QUEUED = "queued"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
CANCELED = "canceled"
file_path: str
transcription_options: TranscriptionOptions
@@ -102,9 +103,9 @@ class FileTranscriptionTask:
class OutputFormat(enum.Enum):
TXT = 'txt'
SRT = 'srt'
VTT = 'vtt'
TXT = "txt"
SRT = "srt"
VTT = "vtt"
class FileTranscriber(QObject):
@@ -113,8 +114,7 @@ class FileTranscriber(QObject):
completed = pyqtSignal(list) # List[Segment]
error = pyqtSignal(Exception)
def __init__(self, task: FileTranscriptionTask,
parent: Optional['QObject'] = None):
def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None):
super().__init__(parent)
self.transcription_task = task
@@ -128,12 +128,16 @@ class FileTranscriber(QObject):
self.completed.emit(segments)
for output_format in self.transcription_task.file_transcription_options.output_formats:
default_path = get_default_output_file_path(task=self.transcription_task,
output_format=output_format)
for (
output_format
) in self.transcription_task.file_transcription_options.output_formats:
default_path = get_default_output_file_path(
task=self.transcription_task, output_format=output_format
)
write_output(path=default_path, segments=segments,
output_format=output_format)
write_output(
path=default_path, segments=segments, output_format=output_format
)
@abstractmethod
def transcribe(self) -> List[Segment]:
@@ -150,13 +154,14 @@ class Stopped(Exception):
class WhisperCppFileTranscriber(FileTranscriber):
duration_audio_ms = sys.maxsize # max int
state: 'WhisperCppFileTranscriber.State'
state: "WhisperCppFileTranscriber.State"
class State:
running = True
def __init__(self, task: FileTranscriptionTask,
parent: Optional['QObject'] = None) -> None:
def __init__(
self, task: FileTranscriptionTask, parent: Optional["QObject"] = None
) -> None:
super().__init__(task, parent)
self.file_path = task.file_path
@@ -171,24 +176,33 @@ class WhisperCppFileTranscriber(FileTranscriber):
model_path = self.model_path
logging.debug(
'Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, '
'word level timings = %s',
self.file_path, self.language, self.task, model_path,
self.word_level_timings)
"Starting whisper_cpp file transcription, file path = %s, language = %s, task = %s, model_path = %s, "
"word level timings = %s",
self.file_path,
self.language,
self.task,
model_path,
self.word_level_timings,
)
audio = whisper.audio.load_audio(self.file_path)
self.duration_audio_ms = len(audio) * 1000 / whisper.audio.SAMPLE_RATE
whisper_params = whisper_cpp_params(
language=self.language if self.language is not None else '', task=self.task,
word_level_timings=self.word_level_timings)
language=self.language if self.language is not None else "",
task=self.task,
word_level_timings=self.word_level_timings,
)
whisper_params.encoder_begin_callback_user_data = ctypes.c_void_p(
id(self.state))
whisper_params.encoder_begin_callback = whisper_cpp.whisper_encoder_begin_callback(
self.encoder_begin_callback)
id(self.state)
)
whisper_params.encoder_begin_callback = (
whisper_cpp.whisper_encoder_begin_callback(self.encoder_begin_callback)
)
whisper_params.new_segment_callback_user_data = ctypes.c_void_p(id(self.state))
whisper_params.new_segment_callback = whisper_cpp.whisper_new_segment_callback(
self.new_segment_callback)
self.new_segment_callback
)
model = WhisperCpp(model=model_path)
result = model.transcribe(audio=self.file_path, params=whisper_params)
@@ -197,7 +211,7 @@ class WhisperCppFileTranscriber(FileTranscriber):
raise Stopped
self.state.running = False
return result['segments']
return result["segments"]
def new_segment_callback(self, ctx, _state, _n_new, user_data):
n_segments = whisper_cpp.whisper_full_n_segments(ctx)
@@ -205,15 +219,17 @@ class WhisperCppFileTranscriber(FileTranscriber):
# t1 seems to sometimes be larger than the duration when the
# audio ends in silence. Trim to fix the displayed progress.
progress = min(t1 * 10, self.duration_audio_ms)
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
ctypes.py_object).value
state: WhisperCppFileTranscriber.State = ctypes.cast(
user_data, ctypes.py_object
).value
if state.running:
self.progress.emit((progress, self.duration_audio_ms))
@staticmethod
def encoder_begin_callback(_ctx, _state, user_data):
state: WhisperCppFileTranscriber.State = ctypes.cast(user_data,
ctypes.py_object).value
state: WhisperCppFileTranscriber.State = ctypes.cast(
user_data, ctypes.py_object
).value
return state.running == 1
def stop(self):
@@ -221,18 +237,19 @@ class WhisperCppFileTranscriber(FileTranscriber):
class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
def __init__(self, task: FileTranscriptionTask, parent: Optional['QObject'] = None):
def __init__(self, task: FileTranscriptionTask, parent: Optional["QObject"] = None):
super().__init__(task=task, parent=parent)
self.file_path = task.file_path
self.task = task.transcription_options.task
def transcribe(self) -> List[Segment]:
logging.debug(
'Starting OpenAI Whisper API file transcription, file path = %s, task = %s',
"Starting OpenAI Whisper API file transcription, file path = %s, task = %s",
self.file_path,
self.task)
self.task,
)
wav_file = tempfile.mktemp() + '.wav'
wav_file = tempfile.mktemp() + ".wav"
(
ffmpeg.input(self.file_path)
.output(wav_file, acodec="pcm_s16le", ac=1, ar=whisper.audio.SAMPLE_RATE)
@@ -241,22 +258,30 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
# TODO: Check if file size is more than 25MB (2.5 minutes), then chunk
audio_file = open(wav_file, "rb")
openai.api_key = self.transcription_task.transcription_options.openai_access_token
openai.api_key = (
self.transcription_task.transcription_options.openai_access_token
)
language = self.transcription_task.transcription_options.language
response_format = "verbose_json"
if self.transcription_task.transcription_options.task == Task.TRANSLATE:
transcript = openai.Audio.translate("whisper-1", audio_file,
response_format=response_format,
language=language)
transcript = openai.Audio.translate(
"whisper-1",
audio_file,
response_format=response_format,
language=language,
)
else:
transcript = openai.Audio.transcribe("whisper-1", audio_file,
response_format=response_format,
language=language)
transcript = openai.Audio.transcribe(
"whisper-1",
audio_file,
response_format=response_format,
language=language,
)
segments = [
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"]) for
segment in
transcript["segments"]]
Segment(segment["start"] * 1000, segment["end"] * 1000, segment["text"])
for segment in transcript["segments"]
]
return segments
def stop(self):
@@ -265,15 +290,16 @@ class OpenAIWhisperAPIFileTranscriber(FileTranscriber):
class WhisperFileTranscriber(FileTranscriber):
"""WhisperFileTranscriber transcribes an audio file to text, writes the text to a file, and then opens the file
using the default program for opening txt files. """
using the default program for opening txt files."""
current_process: multiprocessing.Process
running = False
read_line_thread: Optional[Thread] = None
READ_LINE_THREAD_STOP_TOKEN = '--STOP--'
READ_LINE_THREAD_STOP_TOKEN = "--STOP--"
def __init__(self, task: FileTranscriptionTask,
parent: Optional['QObject'] = None) -> None:
def __init__(
self, task: FileTranscriptionTask, parent: Optional["QObject"] = None
) -> None:
super().__init__(task, parent)
self.segments = []
self.started_process = False
@@ -282,19 +308,19 @@ class WhisperFileTranscriber(FileTranscriber):
def transcribe(self) -> List[Segment]:
time_started = datetime.datetime.now()
logging.debug(
'Starting whisper file transcription, task = %s', self.transcription_task)
"Starting whisper file transcription, task = %s", self.transcription_task
)
recv_pipe, send_pipe = multiprocessing.Pipe(duplex=False)
self.current_process = multiprocessing.Process(target=self.transcribe_whisper,
args=(send_pipe,
self.transcription_task))
self.current_process = multiprocessing.Process(
target=self.transcribe_whisper, args=(send_pipe, self.transcription_task)
)
if not self.stopped:
self.current_process.start()
self.started_process = True
self.read_line_thread = Thread(
target=self.read_line, args=(recv_pipe,))
self.read_line_thread = Thread(target=self.read_line, args=(recv_pipe,))
self.read_line_thread.start()
self.current_process.join()
@@ -305,76 +331,96 @@ class WhisperFileTranscriber(FileTranscriber):
self.read_line_thread.join()
logging.debug(
'whisper process completed with code = %s, time taken = %s, number of segments = %s',
self.current_process.exitcode, datetime.datetime.now() - time_started,
len(self.segments))
"whisper process completed with code = %s, time taken = %s, number of segments = %s",
self.current_process.exitcode,
datetime.datetime.now() - time_started,
len(self.segments),
)
if self.current_process.exitcode != 0:
raise Exception('Unknown error')
raise Exception("Unknown error")
return self.segments
@classmethod
def transcribe_whisper(cls, stderr_conn: Connection,
task: FileTranscriptionTask) -> None:
def transcribe_whisper(
cls, stderr_conn: Connection, task: FileTranscriptionTask
) -> None:
with pipe_stderr(stderr_conn):
if task.transcription_options.model.model_type == ModelType.HUGGING_FACE:
segments = cls.transcribe_hugging_face(task)
elif task.transcription_options.model.model_type == ModelType.FASTER_WHISPER:
elif (
task.transcription_options.model.model_type == ModelType.FASTER_WHISPER
):
segments = cls.transcribe_faster_whisper(task)
elif task.transcription_options.model.model_type == ModelType.WHISPER:
segments = cls.transcribe_openai_whisper(task)
else:
raise Exception(
f"Invalid model type: {task.transcription_options.model.model_type}")
f"Invalid model type: {task.transcription_options.model.model_type}"
)
segments_json = json.dumps(
segments, ensure_ascii=True, default=vars)
sys.stderr.write(f'segments = {segments_json}\n')
sys.stderr.write(
WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + '\n')
segments_json = json.dumps(segments, ensure_ascii=True, default=vars)
sys.stderr.write(f"segments = {segments_json}\n")
sys.stderr.write(WhisperFileTranscriber.READ_LINE_THREAD_STOP_TOKEN + "\n")
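The child process ships its results to the parent as a single stderr line, serialized with `default=vars` so dataclasses become dicts; the parent's read_line thread (further down) strips the 11-character "segments = " prefix. A round-trip sketch:

import json
from dataclasses import dataclass

@dataclass
class Segment:
    start: int
    end: int
    text: str

line = "segments = " + json.dumps([Segment(0, 500, "hi")], default=vars)
json.loads(line[11:])  # -> [{"start": 0, "end": 500, "text": "hi"}]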
@classmethod
def transcribe_hugging_face(cls, task: FileTranscriptionTask) -> List[Segment]:
model = transformers_whisper.load_model(task.model_path)
language = task.transcription_options.language if task.transcription_options.language is not None else 'en'
result = model.transcribe(audio=task.file_path, language=language,
task=task.transcription_options.task.value,
verbose=False)
language = (
task.transcription_options.language
if task.transcription_options.language is not None
else "en"
)
result = model.transcribe(
audio=task.file_path,
language=language,
task=task.transcription_options.task.value,
verbose=False,
)
return [
Segment(
start=int(segment.get('start') * 1000),
end=int(segment.get('end') * 1000),
text=segment.get('text'),
) for segment in result.get('segments')]
start=int(segment.get("start") * 1000),
end=int(segment.get("end") * 1000),
text=segment.get("text"),
)
for segment in result.get("segments")
]
@classmethod
def transcribe_faster_whisper(cls, task: FileTranscriptionTask) -> List[Segment]:
model = faster_whisper.WhisperModel(
model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size())
whisper_segments, info = model.transcribe(audio=task.file_path,
language=task.transcription_options.language,
task=task.transcription_options.task.value,
temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt,
word_timestamps=task.transcription_options.word_level_timings)
model_size_or_path=task.transcription_options.model.whisper_model_size.to_faster_whisper_model_size()
)
whisper_segments, info = model.transcribe(
audio=task.file_path,
language=task.transcription_options.language,
task=task.transcription_options.task.value,
temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt,
word_timestamps=task.transcription_options.word_level_timings,
)
segments = []
with tqdm.tqdm(total=round(info.duration, 2), unit=' seconds') as pbar:
with tqdm.tqdm(total=round(info.duration, 2), unit=" seconds") as pbar:
for segment in list(whisper_segments):
# Segment will contain words if word-level timings is True
if segment.words:
for word in segment.words:
segments.append(Segment(
start=int(word.start * 1000),
end=int(word.end * 1000),
text=word.word
))
segments.append(
Segment(
start=int(word.start * 1000),
end=int(word.end * 1000),
text=word.word,
)
)
else:
segments.append(Segment(
start=int(segment.start * 1000),
end=int(segment.end * 1000),
text=segment.text
))
segments.append(
Segment(
start=int(segment.start * 1000),
end=int(segment.end * 1000),
text=segment.text,
)
)
pbar.update(segment.end - segment.start)
return segments
@@ -386,28 +432,40 @@ class WhisperFileTranscriber(FileTranscriber):
if task.transcription_options.word_level_timings:
stable_whisper.modify_model(model)
result = model.transcribe(
audio=task.file_path, language=task.transcription_options.language,
audio=task.file_path,
language=task.transcription_options.language,
task=task.transcription_options.task.value,
temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt, pbar=True)
initial_prompt=task.transcription_options.initial_prompt,
pbar=True,
)
segments = stable_whisper.group_word_timestamps(result)
return [Segment(
start=int(segment.get('start') * 1000),
end=int(segment.get('end') * 1000),
text=segment.get('text'),
) for segment in segments]
return [
Segment(
start=int(segment.get("start") * 1000),
end=int(segment.get("end") * 1000),
text=segment.get("text"),
)
for segment in segments
]
result = model.transcribe(
audio=task.file_path, language=task.transcription_options.language,
audio=task.file_path,
language=task.transcription_options.language,
task=task.transcription_options.task.value,
temperature=task.transcription_options.temperature,
initial_prompt=task.transcription_options.initial_prompt, verbose=False)
segments = result.get('segments')
return [Segment(
start=int(segment.get('start') * 1000),
end=int(segment.get('end') * 1000),
text=segment.get('text'),
) for segment in segments]
initial_prompt=task.transcription_options.initial_prompt,
verbose=False,
)
segments = result.get("segments")
return [
Segment(
start=int(segment.get("start") * 1000),
end=int(segment.get("end") * 1000),
text=segment.get("text"),
)
for segment in segments
]
def stop(self):
self.stopped = True
@@ -424,102 +482,119 @@ class WhisperFileTranscriber(FileTranscriber):
if line == self.READ_LINE_THREAD_STOP_TOKEN:
return
if line.startswith('segments = '):
if line.startswith("segments = "):
segments_dict = json.loads(line[11:])
segments = [Segment(
start=segment.get('start'),
end=segment.get('end'),
text=segment.get('text'),
) for segment in segments_dict]
segments = [
Segment(
start=segment.get("start"),
end=segment.get("end"),
text=segment.get("text"),
)
for segment in segments_dict
]
self.segments = segments
else:
try:
progress = int(line.split('|')[0].strip().strip('%'))
progress = int(line.split("|")[0].strip().strip("%"))
self.progress.emit((progress, 100))
except ValueError:
logging.debug('whisper (stderr): %s', line)
logging.debug("whisper (stderr): %s", line)
continue
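Any stderr line that is not the segments payload is treated as tqdm progress output, with the leading percentage sliced off the front; e.g.:

line = " 42%|█████     | 12.6/30.0 [00:05<00:08]"  # typical tqdm stderr line
int(line.split("|")[0].strip().strip("%"))  # -> 42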
def write_output(path: str, segments: List[Segment], output_format: OutputFormat):
logging.debug(
'Writing transcription output, path = %s, output format = %s, number of segments = %s',
path, output_format,
len(segments))
"Writing transcription output, path = %s, output format = %s, number of segments = %s",
path,
output_format,
len(segments),
)
with open(path, 'w', encoding='utf-8') as file:
with open(path, "w", encoding="utf-8") as file:
if output_format == OutputFormat.TXT:
for (i, segment) in enumerate(segments):
for i, segment in enumerate(segments):
file.write(segment.text)
file.write('\n')
file.write("\n")
elif output_format == OutputFormat.VTT:
file.write('WEBVTT\n\n')
file.write("WEBVTT\n\n")
for segment in segments:
file.write(
f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n')
file.write(f'{segment.text}\n\n')
f"{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n"
)
file.write(f"{segment.text}\n\n")
elif output_format == OutputFormat.SRT:
for (i, segment) in enumerate(segments):
file.write(f'{i + 1}\n')
for i, segment in enumerate(segments):
file.write(f"{i + 1}\n")
file.write(
f'{to_timestamp(segment.start, ms_separator=",")} --> {to_timestamp(segment.end, ms_separator=",")}\n')
file.write(f'{segment.text}\n\n')
f'{to_timestamp(segment.start, ms_separator=",")} --> {to_timestamp(segment.end, ms_separator=",")}\n'
)
file.write(f"{segment.text}\n\n")
logging.debug('Written transcription output')
logging.debug("Written transcription output")
def segments_to_text(segments: List[Segment]) -> str:
result = ''
for (i, segment) in enumerate(segments):
result += f'{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n'
result += f'{segment.text}'
result = ""
for i, segment in enumerate(segments):
result += f"{to_timestamp(segment.start)} --> {to_timestamp(segment.end)}\n"
result += f"{segment.text}"
if i < len(segments) - 1:
result += '\n\n'
result += "\n\n"
return result
def to_timestamp(ms: float, ms_separator='.') -> str:
def to_timestamp(ms: float, ms_separator=".") -> str:
hr = int(ms / (1000 * 60 * 60))
ms = ms - hr * (1000 * 60 * 60)
min = int(ms / (1000 * 60))
ms = ms - min * (1000 * 60)
sec = int(ms / 1000)
ms = int(ms - sec * 1000)
return f'{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}'
return f"{hr:02d}:{min:02d}:{sec:02d}{ms_separator}{ms:03d}"
SUPPORTED_OUTPUT_FORMATS = 'Audio files (*.mp3 *.wav *.m4a *.ogg);;\
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)'
SUPPORTED_OUTPUT_FORMATS = "Audio files (*.mp3 *.wav *.m4a *.ogg);;\
Video files (*.mp4 *.webm *.ogm *.mov);;All files (*.*)"
def get_default_output_file_path(task: FileTranscriptionTask,
output_format: OutputFormat):
def get_default_output_file_path(
task: FileTranscriptionTask, output_format: OutputFormat
):
input_file_name = os.path.splitext(task.file_path)[0]
date_time_now = datetime.datetime.now().strftime('%d-%b-%Y %H-%M-%S')
return (task.file_transcription_options.default_output_file_name
.replace('{{ input_file_name }}', input_file_name)
.replace('{{ task }}', task.transcription_options.task.value)
.replace('{{ language }}', task.transcription_options.language or '')
.replace('{{ model_type }}',
task.transcription_options.model.model_type.value)
.replace('{{ model_size }}',
task.transcription_options.model.whisper_model_size.value if
task.transcription_options.model.whisper_model_size is not None else
'')
.replace('{{ date_time }}', date_time_now)
+ f".{output_format.value}")
date_time_now = datetime.datetime.now().strftime("%d-%b-%Y %H-%M-%S")
return (
task.file_transcription_options.default_output_file_name.replace(
"{{ input_file_name }}", input_file_name
)
.replace("{{ task }}", task.transcription_options.task.value)
.replace("{{ language }}", task.transcription_options.language or "")
.replace("{{ model_type }}", task.transcription_options.model.model_type.value)
.replace(
"{{ model_size }}",
task.transcription_options.model.whisper_model_size.value
if task.transcription_options.model.whisper_model_size is not None
else "",
)
.replace("{{ date_time }}", date_time_now)
+ f".{output_format.value}"
)
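With a hypothetical template, the replacements expand like so:

# default_output_file_name = "{{ input_file_name }} ({{ task }} {{ date_time }})"
# file_path "/tmp/interview.mp3", task value "translate", at 05-Jan-2023 10-30-00,
# and OutputFormat.SRT produce:
#   "/tmp/interview (translate 05-Jan-2023 10-30-00).srt"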
def whisper_cpp_params(
language: str, task: Task, word_level_timings: bool,
print_realtime=False, print_progress=False, ):
language: str,
task: Task,
word_level_timings: bool,
print_realtime=False,
print_progress=False,
):
params = whisper_cpp.whisper_full_default_params(
whisper_cpp.WHISPER_SAMPLING_GREEDY)
whisper_cpp.WHISPER_SAMPLING_GREEDY
)
params.print_realtime = print_realtime
params.print_progress = print_progress
params.language = whisper_cpp.String(language.encode('utf-8'))
params.language = whisper_cpp.String(language.encode("utf-8"))
params.translate = task == Task.TRANSLATE
params.max_len = ctypes.c_int(1)
params.max_len = 1 if word_level_timings else 0
@@ -529,20 +604,20 @@ def whisper_cpp_params(
class WhisperCpp:
def __init__(self, model: str) -> None:
self.ctx = whisper_cpp.whisper_init_from_file(model.encode('utf-8'))
self.ctx = whisper_cpp.whisper_init_from_file(model.encode("utf-8"))
def transcribe(self, audio: Union[np.ndarray, str], params: Any):
if isinstance(audio, str):
audio = whisper.audio.load_audio(audio)
logging.debug('Loaded audio with length = %s', len(audio))
logging.debug("Loaded audio with length = %s", len(audio))
whisper_cpp_audio = audio.ctypes.data_as(
ctypes.POINTER(ctypes.c_float))
whisper_cpp_audio = audio.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
result = whisper_cpp.whisper_full(
self.ctx, params, whisper_cpp_audio, len(audio))
self.ctx, params, whisper_cpp_audio, len(audio)
)
if result != 0:
raise Exception(f'Error from whisper.cpp: {result}')
raise Exception(f"Error from whisper.cpp: {result}")
segments: List[Segment] = []
@@ -553,13 +628,17 @@ class WhisperCpp:
t1 = whisper_cpp.whisper_full_get_segment_t1((self.ctx), i)
segments.append(
Segment(start=t0 * 10, # centisecond to ms
end=t1 * 10, # centisecond to ms
text=txt.decode('utf-8')))
Segment(
start=t0 * 10, # centisecond to ms
end=t1 * 10, # centisecond to ms
text=txt.decode("utf-8"),
)
)
return {
'segments': segments,
'text': ''.join([segment.text for segment in segments])}
"segments": segments,
"text": "".join([segment.text for segment in segments]),
}
def __del__(self):
whisper_cpp.whisper_free(self.ctx)
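On the float-pointer handoff above: `data_as` views the numpy buffer in place (no copy), so it can be passed to a C function expecting (const float *, int), as whisper_full is here. A self-contained sketch:

import ctypes
import numpy as np

audio = np.zeros(16_000, dtype=np.float32)  # 1 s of silence at 16 kHz
ptr = audio.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
assert ptr[0] == 0.0  # reads straight from the array's memory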

View file

@@ -16,43 +16,65 @@ class TransformersWhisper:
SAMPLE_RATE = whisper.audio.SAMPLE_RATE
N_SAMPLES_IN_CHUNK = whisper.audio.N_SAMPLES
def __init__(self, processor: WhisperProcessor, model: WhisperForConditionalGeneration):
def __init__(
self, processor: WhisperProcessor, model: WhisperForConditionalGeneration
):
self.processor = processor
self.model = model
# Patch implementation of transcribing with transformers' WhisperProcessor until long-form transcription and
# timestamps are available. See: https://github.com/huggingface/transformers/issues/19887,
# https://github.com/huggingface/transformers/pull/20620.
def transcribe(self, audio: Union[str, np.ndarray], language: str, task: str, verbose: Optional[bool] = None):
def transcribe(
self,
audio: Union[str, np.ndarray],
language: str,
task: str,
verbose: Optional[bool] = None,
):
if isinstance(audio, str):
audio = whisper.load_audio(audio, sr=self.SAMPLE_RATE)
self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(task=task, language=language)
self.model.config.forced_decoder_ids = self.processor.get_decoder_prompt_ids(
task=task, language=language
)
segments = []
all_predicted_ids = []
num_samples = audio.size
seek = 0
with tqdm(total=num_samples, unit='samples', disable=verbose is not False) as progress_bar:
with tqdm(
total=num_samples, unit="samples", disable=verbose is not False
) as progress_bar:
while seek < num_samples:
chunk = audio[seek: seek + self.N_SAMPLES_IN_CHUNK]
input_features = self.processor(chunk, return_tensors="pt",
sampling_rate=self.SAMPLE_RATE).input_features
chunk = audio[seek : seek + self.N_SAMPLES_IN_CHUNK]
input_features = self.processor(
chunk, return_tensors="pt", sampling_rate=self.SAMPLE_RATE
).input_features
predicted_ids = self.model.generate(input_features)
all_predicted_ids.extend(predicted_ids)
text: str = self.processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
if text.strip() != '':
segments.append({
'start': seek / self.SAMPLE_RATE,
'end': min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) / self.SAMPLE_RATE,
'text': text
})
text: str = self.processor.batch_decode(
predicted_ids, skip_special_tokens=True
)[0]
if text.strip() != "":
segments.append(
{
"start": seek / self.SAMPLE_RATE,
"end": min(seek + self.N_SAMPLES_IN_CHUNK, num_samples)
/ self.SAMPLE_RATE,
"text": text,
}
)
progress_bar.update(min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek)
progress_bar.update(
min(seek + self.N_SAMPLES_IN_CHUNK, num_samples) - seek
)
seek += self.N_SAMPLES_IN_CHUNK
return {
'text': self.processor.batch_decode(all_predicted_ids, skip_special_tokens=True)[0],
'segments': segments
"text": self.processor.batch_decode(
all_predicted_ids, skip_special_tokens=True
)[0],
"segments": segments,
}
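The transcribe loop walks the waveform in fixed 30-second windows; its skeleton, with a stand-in array:

import numpy as np
from tqdm import tqdm

N_SAMPLES_IN_CHUNK = 480_000  # 30 s at 16 kHz, whisper's chunk length
audio = np.zeros(700_000, dtype=np.float32)  # stand-in waveform

seek = 0
with tqdm(total=audio.size, unit="samples") as progress_bar:
    while seek < audio.size:
        chunk = audio[seek : seek + N_SAMPLES_IN_CHUNK]
        # ... run the model on `chunk` ...
        progress_bar.update(min(seek + N_SAMPLES_IN_CHUNK, audio.size) - seek)
        seek += N_SAMPLES_IN_CHUNK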

View file

@@ -5,8 +5,15 @@ from PyQt6 import QtGui
from PyQt6.QtCore import Qt, QUrl
from PyQt6.QtGui import QIcon, QPixmap, QDesktopServices
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QLabel, QPushButton, \
QDialogButtonBox, QMessageBox
from PyQt6.QtWidgets import (
QDialog,
QWidget,
QVBoxLayout,
QLabel,
QPushButton,
QDialogButtonBox,
QMessageBox,
)
from buzz.__version__ import VERSION
from buzz.widgets.icon import BUZZ_ICON_PATH, BUZZ_LARGE_ICON_PATH
@@ -15,11 +22,16 @@ from buzz.settings.settings import APP_NAME
class AboutDialog(QDialog):
GITHUB_API_LATEST_RELEASE_URL = 'https://api.github.com/repos/chidiwilliams/buzz/releases/latest'
GITHUB_LATEST_RELEASE_URL = 'https://github.com/chidiwilliams/buzz/releases/latest'
GITHUB_API_LATEST_RELEASE_URL = (
"https://api.github.com/repos/chidiwilliams/buzz/releases/latest"
)
GITHUB_LATEST_RELEASE_URL = "https://github.com/chidiwilliams/buzz/releases/latest"
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None) -> None:
def __init__(
self,
network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None,
) -> None:
super().__init__(parent)
self.setWindowIcon(QIcon(BUZZ_ICON_PATH))
@@ -35,28 +47,42 @@ class AboutDialog(QDialog):
image_label = QLabel()
pixmap = QPixmap(BUZZ_LARGE_ICON_PATH).scaled(
80, 80, Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation)
80,
80,
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation,
)
image_label.setPixmap(pixmap)
image_label.setAlignment(Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
image_label.setAlignment(
Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter
)
)
buzz_label = QLabel(APP_NAME)
buzz_label.setAlignment(Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
buzz_label.setAlignment(
Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter
)
)
buzz_label_font = QtGui.QFont()
buzz_label_font.setBold(True)
buzz_label_font.setPointSize(20)
buzz_label.setFont(buzz_label_font)
version_label = QLabel(f"{_('Version')} {VERSION}")
version_label.setAlignment(Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter))
version_label.setAlignment(
Qt.AlignmentFlag(
Qt.AlignmentFlag.AlignVCenter | Qt.AlignmentFlag.AlignHCenter
)
)
self.check_updates_button = QPushButton(_('Check for updates'), self)
self.check_updates_button = QPushButton(_("Check for updates"), self)
self.check_updates_button.clicked.connect(self.on_click_check_for_updates)
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
QDialogButtonBox.StandardButton.Close), self)
button_box = QDialogButtonBox(
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Close), self
)
button_box.accepted.connect(self.accept)
button_box.rejected.connect(self.reject)
@@ -76,13 +102,13 @@ class AboutDialog(QDialog):
def on_latest_release_reply(self, reply: QNetworkReply):
if reply.error() == QNetworkReply.NetworkError.NoError:
response = json.loads(reply.readAll().data())
tag_name = response.get('name')
tag_name = response.get("name")
if self.is_version_lower(VERSION, tag_name[1:]):
QDesktopServices.openUrl(QUrl(self.GITHUB_LATEST_RELEASE_URL))
else:
QMessageBox.information(self, '', _("You're up to date!"))
QMessageBox.information(self, "", _("You're up to date!"))
self.check_updates_button.setEnabled(True)
@staticmethod
def is_version_lower(version_a: str, version_b: str):
return version_a.replace('.', '') < version_b.replace('.', '')
return version_a.replace(".", "") < version_b.replace(".", "")
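Worth noting: stripping the dots compares versions lexically, which misorders multi-digit components. A tuple comparison avoids that (a sketch, not part of this commit):

def is_version_lower(version_a: str, version_b: str) -> bool:
    return tuple(int(p) for p in version_a.split(".")) < tuple(
        int(p) for p in version_b.split(".")
    )

is_version_lower("0.9.0", "0.10.0")  # True; the replace(".", "") form says False here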

View file

@@ -111,7 +111,7 @@ class AudioPlayer(QWidget):
def update_time_label(self):
position_time = QTime(0, 0).addMSecs(self.position).toString()
duration_time = QTime(0, 0).addMSecs(self.duration).toString()
self.time_label.setText(f'{position_time} / {duration_time}')
self.time_label.setText(f"{position_time} / {duration_time}")
def stop(self):
self.media_player.stop()

View file

@@ -6,8 +6,8 @@ from buzz.assets import get_asset_path
# TODO: move icons to Qt resources: https://stackoverflow.com/a/52341917/9830227
class Icon(QIcon):
LIGHT_THEME_BACKGROUND = '#555'
DARK_THEME_BACKGROUND = '#EEE'
LIGHT_THEME_BACKGROUND = "#555"
DARK_THEME_BACKGROUND = "#EEE"
def __init__(self, path: str, parent: QWidget):
# Adapted from https://stackoverflow.com/questions/15123544/change-the-color-of-an-svg-in-qt
@@ -23,18 +23,20 @@ class Icon(QIcon):
super().__init__(pixmap)
def get_color(self, is_dark_theme):
return self.DARK_THEME_BACKGROUND if is_dark_theme else self.LIGHT_THEME_BACKGROUND
return (
self.DARK_THEME_BACKGROUND if is_dark_theme else self.LIGHT_THEME_BACKGROUND
)
class PlayIcon(Icon):
def __init__(self, parent: QWidget):
super().__init__(get_asset_path('assets/play_arrow_black_24dp.svg'), parent)
super().__init__(get_asset_path("assets/play_arrow_black_24dp.svg"), parent)
class PauseIcon(Icon):
def __init__(self, parent: QWidget):
super().__init__(get_asset_path('assets/pause_black_24dp.svg'), parent)
super().__init__(get_asset_path("assets/pause_black_24dp.svg"), parent)
BUZZ_ICON_PATH = get_asset_path('assets/buzz.ico')
BUZZ_LARGE_ICON_PATH = get_asset_path('assets/buzz-icon-1024.png')
BUZZ_ICON_PATH = get_asset_path("assets/buzz.ico")
BUZZ_LARGE_ICON_PATH = get_asset_path("assets/buzz-icon-1024.png")

View file

@@ -5,7 +5,7 @@ from PyQt6.QtWidgets import QLineEdit, QWidget
class LineEdit(QLineEdit):
def __init__(self, default_text: str = '', parent: Optional[QWidget] = None):
def __init__(self, default_text: str = "", parent: Optional[QWidget] = None):
super().__init__(default_text, parent)
if platform.system() == 'Darwin':
self.setStyleSheet('QLineEdit { padding: 4px }')
if platform.system() == "Darwin":
self.setStyleSheet("QLineEdit { padding: 4px }")

View file

@@ -18,16 +18,16 @@ class MenuBar(QMenuBar):
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
parent: QWidget):
def __init__(
self, shortcuts: Dict[str, str], default_export_file_name: str, parent: QWidget
):
super().__init__(parent)
self.shortcuts = shortcuts
self.default_export_file_name = default_export_file_name
self.import_action = QAction(_("Import Media File..."), self)
self.import_action.triggered.connect(
self.on_import_action_triggered)
self.import_action.triggered.connect(self.on_import_action_triggered)
about_action = QAction(f'{_("About")} {APP_NAME}', self)
about_action.triggered.connect(self.on_about_action_triggered)
@@ -56,22 +56,27 @@ class MenuBar(QMenuBar):
about_dialog.open()
def on_preferences_action_triggered(self):
preferences_dialog = PreferencesDialog(shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
parent=self)
preferences_dialog = PreferencesDialog(
shortcuts=self.shortcuts,
default_export_file_name=self.default_export_file_name,
parent=self,
)
preferences_dialog.shortcuts_changed.connect(self.shortcuts_changed)
preferences_dialog.openai_api_key_changed.connect(self.openai_api_key_changed)
preferences_dialog.default_export_file_name_changed.connect(
self.default_export_file_name_changed)
self.default_export_file_name_changed
)
preferences_dialog.open()
def on_help_action_triggered(self):
webbrowser.open('https://chidiwilliams.github.io/buzz/docs')
webbrowser.open("https://chidiwilliams.github.io/buzz/docs")
def set_shortcuts(self, shortcuts: Dict[str, str]):
self.shortcuts = shortcuts
self.import_action.setShortcut(
QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name]))
QKeySequence.fromString(shortcuts[Shortcut.OPEN_IMPORT_WINDOW.name])
)
self.preferences_action.setShortcut(
QKeySequence.fromString(shortcuts[Shortcut.OPEN_PREFERENCES_WINDOW.name]))
QKeySequence.fromString(shortcuts[Shortcut.OPEN_PREFERENCES_WINDOW.name])
)

View file

@@ -10,10 +10,17 @@ from buzz.model_loader import ModelType
class ModelDownloadProgressDialog(QProgressDialog):
def __init__(self, model_type: ModelType, parent: Optional[QWidget] = None, modality=Qt.WindowModality.WindowModal):
def __init__(
self,
model_type: ModelType,
parent: Optional[QWidget] = None,
modality=Qt.WindowModality.WindowModal,
):
super().__init__(parent)
self.cancelable = model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP
self.cancelable = (
model_type == ModelType.WHISPER or model_type == ModelType.WHISPER_CPP
)
self.start_time = datetime.now()
self.setRange(0, 100)
self.setMinimumDuration(0)
@@ -21,7 +28,7 @@ class ModelDownloadProgressDialog(QProgressDialog):
self.update_label_text(0)
if not self.cancelable:
cancel_button = QPushButton('Cancel', self)
cancel_button = QPushButton("Cancel", self)
cancel_button.setEnabled(False)
self.setCancelButton(cancel_button)
@@ -30,8 +37,8 @@ class ModelDownloadProgressDialog(QProgressDialog):
if fraction_completed > 0:
time_spent = (datetime.now() - self.start_time).total_seconds()
time_left = (time_spent / fraction_completed) - time_spent
label_text += f', {humanize.naturaldelta(time_left)} remaining'
label_text += ')'
label_text += f", {humanize.naturaldelta(time_left)} remaining"
label_text += ")"
self.setLabelText(label_text)

View file

@@ -10,8 +10,12 @@ from buzz.transcriber import LOADED_WHISPER_DLL
class ModelTypeComboBox(QComboBox):
changed = pyqtSignal(ModelType)
def __init__(self, model_types: Optional[List[ModelType]] = None, default_model: Optional[ModelType] = None,
parent: Optional[QWidget] = None):
def __init__(
self,
model_types: Optional[List[ModelType]] = None,
default_model: Optional[ModelType] = None,
parent: Optional[QWidget] = None,
):
super().__init__(parent)
if model_types is None:

View file

@@ -16,15 +16,22 @@ class OpenAIAPIKeyLineEdit(LineEdit):
self.key = key
self.visible_on_icon = Icon(get_asset_path('assets/visibility_FILL0_wght700_GRAD0_opsz48.svg'), self)
self.visible_off_icon = Icon(get_asset_path('assets/visibility_off_FILL0_wght700_GRAD0_opsz48.svg'), self)
self.visible_on_icon = Icon(
get_asset_path("assets/visibility_FILL0_wght700_GRAD0_opsz48.svg"), self
)
self.visible_off_icon = Icon(
get_asset_path("assets/visibility_off_FILL0_wght700_GRAD0_opsz48.svg"), self
)
self.setPlaceholderText('sk-...')
self.setPlaceholderText("sk-...")
self.setEchoMode(QLineEdit.EchoMode.Password)
self.textChanged.connect(self.on_openai_api_key_changed)
self.toggle_show_openai_api_key_action = self.addAction(self.visible_on_icon,
QLineEdit.ActionPosition.TrailingPosition)
self.toggle_show_openai_api_key_action.triggered.connect(self.on_toggle_show_action_triggered)
self.toggle_show_openai_api_key_action = self.addAction(
self.visible_on_icon, QLineEdit.ActionPosition.TrailingPosition
)
self.toggle_show_openai_api_key_action.triggered.connect(
self.on_toggle_show_action_triggered
)
def on_toggle_show_action_triggered(self):
if self.echoMode() == QLineEdit.EchoMode.Password:

View file

@@ -15,32 +15,40 @@ class GeneralPreferencesWidget(QWidget):
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
def __init__(self, default_export_file_name: str, keyring_store=KeyringStore(),
parent: Optional[QWidget] = None):
def __init__(
self,
default_export_file_name: str,
keyring_store=KeyringStore(),
parent: Optional[QWidget] = None,
):
super().__init__(parent)
self.openai_api_key = keyring_store.get_password(
KeyringStore.Key.OPENAI_API_KEY)
KeyringStore.Key.OPENAI_API_KEY
)
layout = QFormLayout(self)
self.openai_api_key_line_edit = OpenAIAPIKeyLineEdit(self.openai_api_key, self)
self.openai_api_key_line_edit.key_changed.connect(
self.on_openai_api_key_changed)
self.on_openai_api_key_changed
)
self.test_openai_api_key_button = QPushButton('Test')
self.test_openai_api_key_button = QPushButton("Test")
self.test_openai_api_key_button.clicked.connect(
self.on_click_test_openai_api_key_button)
self.on_click_test_openai_api_key_button
)
self.update_test_openai_api_key_button()
layout.addRow('OpenAI API Key', self.openai_api_key_line_edit)
layout.addRow('', self.test_openai_api_key_button)
layout.addRow("OpenAI API Key", self.openai_api_key_line_edit)
layout.addRow("", self.test_openai_api_key_button)
default_export_file_name_line_edit = LineEdit(default_export_file_name, self)
default_export_file_name_line_edit.textChanged.connect(
self.default_export_file_name_changed)
self.default_export_file_name_changed
)
default_export_file_name_line_edit.setMinimumWidth(200)
layout.addRow('Default export file name', default_export_file_name_line_edit)
layout.addRow("Default export file name", default_export_file_name_line_edit)
self.setLayout(layout)
@@ -60,12 +68,15 @@ class GeneralPreferencesWidget(QWidget):
def on_test_openai_api_key_success(self):
self.test_openai_api_key_button.setEnabled(True)
QMessageBox.information(self, 'OpenAI API Key Test',
'Your API key is valid. Buzz will use this key to perform Whisper API transcriptions.')
QMessageBox.information(
self,
"OpenAI API Key Test",
"Your API key is valid. Buzz will use this key to perform Whisper API transcriptions.",
)
def on_test_openai_api_key_failure(self, error: str):
self.test_openai_api_key_button.setEnabled(True)
QMessageBox.warning(self, 'OpenAI API Key Test', error)
QMessageBox.warning(self, "OpenAI API Key Test", error)
def on_openai_api_key_changed(self, key: str):
self.openai_api_key = key

View file

@@ -1,34 +1,55 @@
from typing import Optional
from PyQt6.QtCore import Qt, QThreadPool
from PyQt6.QtWidgets import QWidget, QFormLayout, QTreeWidget, QTreeWidgetItem, \
QPushButton, QMessageBox, QHBoxLayout
from PyQt6.QtWidgets import (
QWidget,
QFormLayout,
QTreeWidget,
QTreeWidgetItem,
QPushButton,
QMessageBox,
QHBoxLayout,
)
from buzz.locale import _
from buzz.model_loader import ModelType, WhisperModelSize, TranscriptionModel, ModelDownloader
from buzz.model_loader import (
ModelType,
WhisperModelSize,
TranscriptionModel,
ModelDownloader,
)
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from buzz.widgets.model_type_combo_box import ModelTypeComboBox
class ModelsPreferencesWidget(QWidget):
def __init__(self, progress_dialog_modality=Qt.WindowModality.WindowModal,
parent: Optional[QWidget] = None):
def __init__(
self,
progress_dialog_modality=Qt.WindowModality.WindowModal,
parent: Optional[QWidget] = None,
):
super().__init__(parent)
self.model_downloader: Optional[ModelDownloader] = None
self.model = TranscriptionModel(model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY)
self.model = TranscriptionModel(
model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY
)
self.progress_dialog_modality = progress_dialog_modality
self.progress_dialog: Optional[ModelDownloadProgressDialog] = None
layout = QFormLayout()
model_type_combo_box = ModelTypeComboBox(
model_types=[ModelType.WHISPER, ModelType.WHISPER_CPP,
ModelType.FASTER_WHISPER],
default_model=self.model.model_type, parent=self)
model_types=[
ModelType.WHISPER,
ModelType.WHISPER_CPP,
ModelType.FASTER_WHISPER,
],
default_model=self.model.model_type,
parent=self,
)
model_type_combo_box.changed.connect(self.on_model_type_changed)
layout.addRow('Group', model_type_combo_box)
layout.addRow("Group", model_type_combo_box)
self.model_list_widget = QTreeWidget()
self.model_list_widget.setColumnCount(1)
@@ -37,20 +58,21 @@ class ModelsPreferencesWidget(QWidget):
buttons_layout = QHBoxLayout()
self.download_button = QPushButton(_('Download'))
self.download_button.setObjectName('DownloadButton')
self.download_button = QPushButton(_("Download"))
self.download_button.setObjectName("DownloadButton")
self.download_button.clicked.connect(self.on_download_button_clicked)
buttons_layout.addWidget(self.download_button)
self.show_file_location_button = QPushButton(_('Show file location'))
self.show_file_location_button.setObjectName('ShowFileLocationButton')
self.show_file_location_button = QPushButton(_("Show file location"))
self.show_file_location_button.setObjectName("ShowFileLocationButton")
self.show_file_location_button.clicked.connect(
self.on_show_file_location_button_clicked)
self.on_show_file_location_button_clicked
)
buttons_layout.addWidget(self.show_file_location_button)
buttons_layout.addStretch(1)
self.delete_button = QPushButton(_('Delete'))
self.delete_button.setObjectName('DeleteButton')
self.delete_button = QPushButton(_("Delete"))
self.delete_button.setObjectName("DeleteButton")
self.delete_button.clicked.connect(self.on_delete_button_clicked)
buttons_layout.addWidget(self.delete_button)
@@ -71,9 +93,10 @@ class ModelsPreferencesWidget(QWidget):
@staticmethod
def can_delete_model(model: TranscriptionModel):
return ((model.model_type == ModelType.WHISPER or
model.model_type == ModelType.WHISPER_CPP) and
model.get_local_model_path() is not None)
return (
model.model_type == ModelType.WHISPER
or model.model_type == ModelType.WHISPER_CPP
) and model.get_local_model_path() is not None
def reset(self):
# reset buttons
@@ -85,20 +108,21 @@ class ModelsPreferencesWidget(QWidget):
# reset model list
self.model_list_widget.clear()
downloaded_item = QTreeWidgetItem(self.model_list_widget)
downloaded_item.setText(0, _('Downloaded'))
downloaded_item.setText(0, _("Downloaded"))
downloaded_item.setFlags(
downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
downloaded_item.flags() & ~Qt.ItemFlag.ItemIsSelectable
)
available_item = QTreeWidgetItem(self.model_list_widget)
available_item.setText(0, _('Available for Download'))
available_item.setFlags(
available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
available_item.setText(0, _("Available for Download"))
available_item.setFlags(available_item.flags() & ~Qt.ItemFlag.ItemIsSelectable)
self.model_list_widget.addTopLevelItems([downloaded_item, available_item])
self.model_list_widget.expandToDepth(2)
self.model_list_widget.setHeaderHidden(True)
self.model_list_widget.setAlternatingRowColors(True)
for model_size in WhisperModelSize:
model = TranscriptionModel(model_type=self.model.model_type,
whisper_model_size=model_size)
model = TranscriptionModel(
model_type=self.model.model_type, whisper_model_size=model_size
)
model_path = model.get_local_model_path()
parent = downloaded_item if model_path is not None else available_item
item = QTreeWidgetItem(parent)
@@ -115,7 +139,9 @@ class ModelsPreferencesWidget(QWidget):
def on_download_button_clicked(self):
self.progress_dialog = ModelDownloadProgressDialog(
model_type=self.model.model_type,
modality=self.progress_dialog_modality, parent=self)
modality=self.progress_dialog_modality,
parent=self,
)
self.progress_dialog.canceled.connect(self.on_progress_dialog_canceled)
self.download_button.setEnabled(False)
@@ -128,8 +154,10 @@ class ModelsPreferencesWidget(QWidget):
def on_delete_button_clicked(self):
reply = QMessageBox.question(
self, _('Delete Model'),
_('Are you sure you want to delete the selected model?'))
self,
_("Delete Model"),
_("Are you sure you want to delete the selected model?"),
)
if reply == QMessageBox.StandardButton.Yes:
self.model.delete_local_file()
self.reset()
@@ -147,7 +175,7 @@ class ModelsPreferencesWidget(QWidget):
self.progress_dialog = None
self.download_button.setEnabled(True)
self.reset()
QMessageBox.warning(self, _('Error'), f'Download failed: {error}')
QMessageBox.warning(self, _("Error"), f"Download failed: {error}")
def on_download_progress(self, progress: tuple):
self.progress_dialog.set_value(float(progress[0]) / progress[1])

View file

@@ -4,12 +4,15 @@ from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QDialog, QWidget, QVBoxLayout, QTabWidget, QDialogButtonBox
from buzz.locale import _
from buzz.widgets.preferences_dialog.general_preferences_widget import \
GeneralPreferencesWidget
from buzz.widgets.preferences_dialog.models_preferences_widget import \
ModelsPreferencesWidget
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import \
ShortcutsEditorPreferencesWidget
from buzz.widgets.preferences_dialog.general_preferences_widget import (
GeneralPreferencesWidget,
)
from buzz.widgets.preferences_dialog.models_preferences_widget import (
ModelsPreferencesWidget,
)
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import (
ShortcutsEditorPreferencesWidget,
)
class PreferencesDialog(QDialog):
@@ -17,31 +20,38 @@ class PreferencesDialog(QDialog):
openai_api_key_changed = pyqtSignal(str)
default_export_file_name_changed = pyqtSignal(str)
def __init__(self, shortcuts: Dict[str, str], default_export_file_name: str,
parent: Optional[QWidget] = None) -> None:
def __init__(
self,
shortcuts: Dict[str, str],
default_export_file_name: str,
parent: Optional[QWidget] = None,
) -> None:
super().__init__(parent)
self.setWindowTitle('Preferences')
self.setWindowTitle("Preferences")
layout = QVBoxLayout(self)
tab_widget = QTabWidget(self)
general_tab_widget = GeneralPreferencesWidget(
default_export_file_name=default_export_file_name, parent=self)
default_export_file_name=default_export_file_name, parent=self
)
general_tab_widget.openai_api_key_changed.connect(self.openai_api_key_changed)
general_tab_widget.default_export_file_name_changed.connect(
self.default_export_file_name_changed)
tab_widget.addTab(general_tab_widget, _('General'))
self.default_export_file_name_changed
)
tab_widget.addTab(general_tab_widget, _("General"))
models_tab_widget = ModelsPreferencesWidget(parent=self)
tab_widget.addTab(models_tab_widget, _('Models'))
tab_widget.addTab(models_tab_widget, _("Models"))
shortcuts_table_widget = ShortcutsEditorPreferencesWidget(shortcuts, self)
shortcuts_table_widget.shortcuts_changed.connect(self.shortcuts_changed)
tab_widget.addTab(shortcuts_table_widget, _('Shortcuts'))
tab_widget.addTab(shortcuts_table_widget, _("Shortcuts"))
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
QDialogButtonBox.StandardButton.Ok), self)
button_box = QDialogButtonBox(
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self
)
button_box.accepted.connect(self.accept)
button_box.rejected.connect(self.reject)

View file

@@ -18,12 +18,13 @@ class ShortcutsEditorPreferencesWidget(QWidget):
self.layout = QFormLayout(self)
for shortcut in Shortcut:
sequence_edit = SequenceEdit(shortcuts.get(shortcut.name, ''), self)
sequence_edit = SequenceEdit(shortcuts.get(shortcut.name, ""), self)
sequence_edit.keySequenceChanged.connect(
self.get_key_sequence_changed(shortcut.name))
self.get_key_sequence_changed(shortcut.name)
)
self.layout.addRow(shortcut.description, sequence_edit)
reset_to_defaults_button = QPushButton('Reset to Defaults', self)
reset_to_defaults_button = QPushButton("Reset to Defaults", self)
reset_to_defaults_button.setDefault(False)
reset_to_defaults_button.setAutoDefault(False)
reset_to_defaults_button.clicked.connect(self.reset_to_defaults)
@@ -41,8 +42,9 @@ class ShortcutsEditorPreferencesWidget(QWidget):
self.shortcuts = Shortcut.get_default_shortcuts()
for i, shortcut in enumerate(Shortcut):
sequence_edit = self.layout.itemAt(i,
QFormLayout.ItemRole.FieldRole).widget()
sequence_edit = self.layout.itemAt(
i, QFormLayout.ItemRole.FieldRole
).widget()
assert isinstance(sequence_edit, SequenceEdit)
sequence_edit.setKeySequence(QKeySequence(self.shortcuts[shortcut.name]))

View file

@@ -10,8 +10,8 @@ class SequenceEdit(QKeySequenceEdit):
def __init__(self, sequence: str, parent: Optional[QWidget] = None):
super().__init__(sequence, parent)
self.setClearButtonEnabled(True)
if platform.system() == 'Darwin':
self.setStyleSheet('QLineEdit:focus { border: 2px solid #4d90fe; }')
if platform.system() == "Darwin":
self.setStyleSheet("QLineEdit:focus { border: 2px solid #4d90fe; }")
def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:
key = event.key()
@@ -23,7 +23,12 @@ class SequenceEdit(QKeySequenceEdit):
return
# Ignore pressing *only* modifier keys
if key == Qt.Key.Key_Control or key == Qt.Key.Key_Shift or key == Qt.Key.Key_Alt or key == Qt.Key.Key_Meta:
if (
key == Qt.Key.Key_Control
or key == Qt.Key.Key_Shift
or key == Qt.Key.Key_Alt
or key == Qt.Key.Key_Meta
):
return
super().keyPressEvent(event)

View file

@@ -11,7 +11,7 @@ class ToolBar(QToolBar):
super().__init__(parent)
self.setIconSize(QSize(18, 18))
self.setStyleSheet('QToolButton{margin: 6px 3px;}')
self.setStyleSheet("QToolButton{margin: 6px 3px;}")
self.setToolButtonStyle(Qt.ToolButtonStyle.ToolButtonIconOnly)
def addAction(self, action: QtGui.QAction) -> None:
@@ -23,6 +23,7 @@ class ToolBar(QToolBar):
self.fix_spacing_on_mac()
def fix_spacing_on_mac(self):
if platform.system() == 'Darwin':
if platform.system() == "Darwin":
self.widgetForAction(self.actions()[0]).setStyleSheet(
'QToolButton { margin-left: 9px; margin-right: 1px; }')
"QToolButton { margin-left: 9px; margin-right: 1px; }"
)

View file

@@ -5,4 +5,4 @@ from PyQt6.QtWidgets import QPushButton, QWidget
class AdvancedSettingsButton(QPushButton):
def __init__(self, parent: Optional[QWidget]) -> None:
super().__init__('Advanced...', parent)
super().__init__("Advanced...", parent)

View file

@@ -1,6 +1,11 @@
from PyQt6.QtCore import pyqtSignal
from PyQt6.QtWidgets import QDialog, QWidget, QDialogButtonBox, QFormLayout, \
QPlainTextEdit
from PyQt6.QtWidgets import (
QDialog,
QWidget,
QDialogButtonBox,
QFormLayout,
QPlainTextEdit,
)
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
from buzz.locale import _
@@ -13,39 +18,48 @@ class AdvancedSettingsDialog(QDialog):
transcription_options: TranscriptionOptions
transcription_options_changed = pyqtSignal(TranscriptionOptions)
def __init__(self, transcription_options: TranscriptionOptions, parent: QWidget | None = None):
def __init__(
self, transcription_options: TranscriptionOptions, parent: QWidget | None = None
):
super().__init__(parent)
self.transcription_options = transcription_options
self.setWindowTitle(_('Advanced Settings'))
self.setWindowTitle(_("Advanced Settings"))
button_box = QDialogButtonBox(QDialogButtonBox.StandardButton(
QDialogButtonBox.StandardButton.Ok), self)
button_box = QDialogButtonBox(
QDialogButtonBox.StandardButton(QDialogButtonBox.StandardButton.Ok), self
)
button_box.accepted.connect(self.accept)
layout = QFormLayout(self)
default_temperature_text = ', '.join(
[str(temp) for temp in transcription_options.temperature])
default_temperature_text = ", ".join(
[str(temp) for temp in transcription_options.temperature]
)
self.temperature_line_edit = LineEdit(default_temperature_text, self)
self.temperature_line_edit.setPlaceholderText(
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"'))
_('Comma-separated, e.g. "0.0, 0.2, 0.4, 0.6, 0.8, 1.0"')
)
self.temperature_line_edit.setMinimumWidth(170)
self.temperature_line_edit.textChanged.connect(
self.on_temperature_changed)
self.temperature_line_edit.textChanged.connect(self.on_temperature_changed)
self.temperature_line_edit.setValidator(TemperatureValidator(self))
self.temperature_line_edit.setEnabled(transcription_options.model.model_type == ModelType.WHISPER)
self.temperature_line_edit.setEnabled(
transcription_options.model.model_type == ModelType.WHISPER
)
self.initial_prompt_text_edit = QPlainTextEdit(
transcription_options.initial_prompt, self)
transcription_options.initial_prompt, self
)
self.initial_prompt_text_edit.textChanged.connect(
self.on_initial_prompt_changed)
self.on_initial_prompt_changed
)
self.initial_prompt_text_edit.setEnabled(
transcription_options.model.model_type == ModelType.WHISPER)
transcription_options.model.model_type == ModelType.WHISPER
)
layout.addRow(_('Temperature:'), self.temperature_line_edit)
layout.addRow(_('Initial Prompt:'), self.initial_prompt_text_edit)
layout.addRow(_("Temperature:"), self.temperature_line_edit)
layout.addRow(_("Initial Prompt:"), self.initial_prompt_text_edit)
layout.addWidget(button_box)
self.setLayout(layout)
@@ -53,12 +67,14 @@ class AdvancedSettingsDialog(QDialog):
def on_temperature_changed(self, text: str):
try:
temperatures = [float(temp.strip()) for temp in text.split(',')]
temperatures = [float(temp.strip()) for temp in text.split(",")]
self.transcription_options.temperature = tuple(temperatures)
self.transcription_options_changed.emit(self.transcription_options)
except ValueError:
pass
def on_initial_prompt_changed(self):
self.transcription_options.initial_prompt = self.initial_prompt_text_edit.toPlainText()
self.transcription_options.initial_prompt = (
self.initial_prompt_text_edit.toPlainText()
)
self.transcription_options_changed.emit(self.transcription_options)

View file

@@ -2,8 +2,14 @@ from typing import Optional, List, Tuple
from PyQt6 import QtGui
from PyQt6.QtCore import pyqtSignal, Qt, QThreadPool
from PyQt6.QtWidgets import QWidget, QVBoxLayout, QCheckBox, QFormLayout, QHBoxLayout, \
QPushButton
from PyQt6.QtWidgets import (
QWidget,
QVBoxLayout,
QCheckBox,
QFormLayout,
QHBoxLayout,
QPushButton,
)
from buzz.dialogs import show_model_download_error_dialog
from buzz.locale import _
@@ -11,11 +17,17 @@ from buzz.model_loader import ModelDownloader, TranscriptionModel, ModelType
from buzz.paths import file_paths_as_title
from buzz.settings.settings import Settings
from buzz.store.keyring_store import KeyringStore
from buzz.transcriber import FileTranscriptionOptions, TranscriptionOptions, Task, \
DEFAULT_WHISPER_TEMPERATURE, OutputFormat
from buzz.transcriber import (
FileTranscriptionOptions,
TranscriptionOptions,
Task,
DEFAULT_WHISPER_TEMPERATURE,
OutputFormat,
)
from buzz.widgets.model_download_progress_dialog import ModelDownloadProgressDialog
from buzz.widgets.transcriber.transcription_options_group_box import \
TranscriptionOptionsGroupBox
from buzz.widgets.transcriber.transcription_options_group_box import (
TranscriptionOptionsGroupBox,
)
class FileTranscriberWidget(QWidget):
@@ -29,74 +41,100 @@ class FileTranscriberWidget(QWidget):
openai_access_token_changed = pyqtSignal(str)
settings = Settings()
def __init__(self, file_paths: List[str],
default_output_file_name: str,
parent: Optional[QWidget] = None,
flags: Qt.WindowType = Qt.WindowType.Widget) -> None:
def __init__(
self,
file_paths: List[str],
default_output_file_name: str,
parent: Optional[QWidget] = None,
flags: Qt.WindowType = Qt.WindowType.Widget,
) -> None:
super().__init__(parent, flags)
self.setWindowTitle(file_paths_as_title(file_paths))
openai_access_token = KeyringStore().get_password(
KeyringStore.Key.OPENAI_API_KEY)
KeyringStore.Key.OPENAI_API_KEY
)
self.file_paths = file_paths
default_language = self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value='')
key=Settings.Key.FILE_TRANSCRIBER_LANGUAGE, default_value=""
)
self.transcription_options = TranscriptionOptions(
openai_access_token=openai_access_token,
model=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_MODEL,
default_value=TranscriptionModel()),
task=self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_TASK,
default_value=Task.TRANSCRIBE),
language=default_language if default_language != '' else None,
model=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_MODEL,
default_value=TranscriptionModel(),
),
task=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_TASK, default_value=Task.TRANSCRIBE
),
language=default_language if default_language != "" else None,
initial_prompt=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=''),
key=Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT, default_value=""
),
temperature=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
default_value=DEFAULT_WHISPER_TEMPERATURE),
default_value=DEFAULT_WHISPER_TEMPERATURE,
),
word_level_timings=self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
default_value=False))
default_value=False,
),
)
default_export_format_states: List[str] = self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
default_value=[])
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS, default_value=[]
)
self.file_transcription_options = FileTranscriptionOptions(
file_paths=self.file_paths,
output_formats=set([OutputFormat(output_format) for output_format in
default_export_format_states]),
default_output_file_name=default_output_file_name)
output_formats=set(
[
OutputFormat(output_format)
for output_format in default_export_format_states
]
),
default_output_file_name=default_output_file_name,
)
layout = QVBoxLayout(self)
transcription_options_group_box = TranscriptionOptionsGroupBox(
default_transcription_options=self.transcription_options, parent=self)
default_transcription_options=self.transcription_options, parent=self
)
transcription_options_group_box.transcription_options_changed.connect(
self.on_transcription_options_changed)
self.on_transcription_options_changed
)
self.word_level_timings_checkbox = QCheckBox(_('Word-level timings'))
self.word_level_timings_checkbox = QCheckBox(_("Word-level timings"))
self.word_level_timings_checkbox.setChecked(
self.settings.value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
default_value=False))
self.settings.value(
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
default_value=False,
)
)
self.word_level_timings_checkbox.stateChanged.connect(
self.on_word_level_timings_changed)
self.on_word_level_timings_changed
)
file_transcription_layout = QFormLayout()
file_transcription_layout.addRow('', self.word_level_timings_checkbox)
file_transcription_layout.addRow("", self.word_level_timings_checkbox)
export_format_layout = QHBoxLayout()
for output_format in OutputFormat:
export_format_checkbox = QCheckBox(f'{output_format.value.upper()}',
parent=self)
export_format_checkbox = QCheckBox(
f"{output_format.value.upper()}", parent=self
)
export_format_checkbox.setChecked(
output_format in self.file_transcription_options.output_formats)
output_format in self.file_transcription_options.output_formats
)
export_format_checkbox.stateChanged.connect(
self.get_on_checkbox_state_changed_callback(output_format))
self.get_on_checkbox_state_changed_callback(output_format)
)
export_format_layout.addWidget(export_format_checkbox)
file_transcription_layout.addRow('Export:', export_format_layout)
file_transcription_layout.addRow("Export:", export_format_layout)
self.run_button = QPushButton(_('Run'), self)
self.run_button = QPushButton(_("Run"), self)
self.run_button.setDefault(True)
self.run_button.clicked.connect(self.on_click_run)
@@ -116,15 +154,19 @@ class FileTranscriberWidget(QWidget):
return on_checkbox_state_changed
def on_transcription_options_changed(self,
transcription_options: TranscriptionOptions):
def on_transcription_options_changed(
self, transcription_options: TranscriptionOptions
):
self.transcription_options = transcription_options
self.word_level_timings_checkbox.setDisabled(
self.transcription_options.model.model_type == ModelType.HUGGING_FACE or
self.transcription_options.model.model_type == ModelType.OPEN_AI_WHISPER_API)
if self.transcription_options.openai_access_token != '':
self.transcription_options.model.model_type == ModelType.HUGGING_FACE
or self.transcription_options.model.model_type
== ModelType.OPEN_AI_WHISPER_API
)
if self.transcription_options.openai_access_token != "":
self.openai_access_token_changed.emit(
self.transcription_options.openai_access_token)
self.transcription_options.openai_access_token
)
def on_click_run(self):
self.run_button.setDisabled(True)
@@ -143,8 +185,9 @@ class FileTranscriberWidget(QWidget):
def on_model_loaded(self, model_path: str):
self.reset_transcriber_controls()
self.triggered.emit((self.transcription_options,
self.file_transcription_options, model_path))
self.triggered.emit(
(self.transcription_options, self.file_transcription_options, model_path)
)
self.close()
def on_download_model_progress(self, progress: Tuple[float, float]):
@@ -152,13 +195,16 @@ class FileTranscriberWidget(QWidget):
if self.model_download_progress_dialog is None:
self.model_download_progress_dialog = ModelDownloadProgressDialog(
model_type=self.transcription_options.model.model_type, parent=self)
model_type=self.transcription_options.model.model_type, parent=self
)
self.model_download_progress_dialog.canceled.connect(
self.on_cancel_model_progress_dialog)
self.on_cancel_model_progress_dialog
)
if self.model_download_progress_dialog is not None:
self.model_download_progress_dialog.set_value(
fraction_completed=current_size / total_size)
fraction_completed=current_size / total_size
)
def on_download_model_error(self, error: str):
self.reset_model_download()
@@ -179,26 +225,41 @@ class FileTranscriberWidget(QWidget):
self.model_download_progress_dialog = None
def on_word_level_timings_changed(self, value: int):
self.transcription_options.word_level_timings = value == Qt.CheckState.Checked.value
self.transcription_options.word_level_timings = (
value == Qt.CheckState.Checked.value
)
def closeEvent(self, event: QtGui.QCloseEvent) -> None:
if self.model_loader is not None:
self.model_loader.cancel()
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_LANGUAGE,
self.transcription_options.language)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TASK,
self.transcription_options.task)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
self.transcription_options.temperature)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT,
self.transcription_options.initial_prompt)
self.settings.set_value(Settings.Key.FILE_TRANSCRIBER_MODEL,
self.transcription_options.model)
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
value=self.transcription_options.word_level_timings)
self.settings.set_value(key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
value=[export_format.value for export_format in
self.file_transcription_options.output_formats])
self.settings.set_value(
Settings.Key.FILE_TRANSCRIBER_LANGUAGE, self.transcription_options.language
)
self.settings.set_value(
Settings.Key.FILE_TRANSCRIBER_TASK, self.transcription_options.task
)
self.settings.set_value(
Settings.Key.FILE_TRANSCRIBER_TEMPERATURE,
self.transcription_options.temperature,
)
self.settings.set_value(
Settings.Key.FILE_TRANSCRIBER_INITIAL_PROMPT,
self.transcription_options.initial_prompt,
)
self.settings.set_value(
Settings.Key.FILE_TRANSCRIBER_MODEL, self.transcription_options.model
)
self.settings.set_value(
key=Settings.Key.FILE_TRANSCRIBER_WORD_LEVEL_TIMINGS,
value=self.transcription_options.word_level_timings,
)
self.settings.set_value(
key=Settings.Key.FILE_TRANSCRIBER_EXPORT_FORMATS,
value=[
export_format.value
for export_format in self.file_transcription_options.output_formats
],
)
super().closeEvent(event)

View file

@@ -2,8 +2,17 @@ import json
import logging
from typing import Optional
from PyQt6.QtCore import pyqtSignal, QTimer, Qt, QMetaObject, QUrl, QUrlQuery, QPoint, \
QObject, QEvent
from PyQt6.QtCore import (
pyqtSignal,
QTimer,
Qt,
QMetaObject,
QUrl,
QUrlQuery,
QPoint,
QObject,
QEvent,
)
from PyQt6.QtGui import QKeyEvent
from PyQt6.QtNetwork import QNetworkAccessManager, QNetworkRequest, QNetworkReply
from PyQt6.QtWidgets import QListWidget, QWidget, QAbstractItemView, QListWidgetItem
@@ -16,12 +25,15 @@ class HuggingFaceSearchLineEdit(LineEdit):
model_selected = pyqtSignal(str)
popup: QListWidget
def __init__(self, network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None):
super().__init__('', parent)
def __init__(
self,
network_access_manager: Optional[QNetworkAccessManager] = None,
parent: Optional[QWidget] = None,
):
super().__init__("", parent)
self.setMinimumWidth(150)
self.setPlaceholderText('openai/whisper-tiny')
self.setPlaceholderText("openai/whisper-tiny")
self.timer = QTimer(self)
self.timer.setSingleShot(True)
@@ -56,7 +68,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
item = self.popup.currentItem()
self.setText(item.text())
QMetaObject.invokeMethod(self, 'returnPressed')
QMetaObject.invokeMethod(self, "returnPressed")
self.model_selected.emit(item.data(Qt.ItemDataRole.UserRole))
def fetch_models(self):
@@ -79,7 +91,9 @@ class HuggingFaceSearchLineEdit(LineEdit):
def on_request_response(self, network_reply: QNetworkReply):
if network_reply.error() != QNetworkReply.NetworkError.NoError:
logging.debug('Error fetching Hugging Face models: %s', network_reply.error())
logging.debug(
"Error fetching Hugging Face models: %s", network_reply.error()
)
return
models = json.loads(network_reply.readAll().data())
@@ -88,7 +102,7 @@ class HuggingFaceSearchLineEdit(LineEdit):
self.popup.clear()
for model in models:
model_id = model.get('id')
model_id = model.get("id")
item = QListWidgetItem(self.popup)
item.setText(model_id)
@@ -96,14 +110,16 @@ class HuggingFaceSearchLineEdit(LineEdit):
self.popup.setCurrentItem(self.popup.item(0))
self.popup.setFixedWidth(self.popup.sizeHintForColumn(0) + 20)
self.popup.setFixedHeight(self.popup.sizeHintForRow(0) * min(len(models), 8)) # show max 8 models, then scroll
self.popup.setFixedHeight(
self.popup.sizeHintForRow(0) * min(len(models), 8)
) # show max 8 models, then scroll
self.popup.setUpdatesEnabled(True)
self.popup.move(self.mapToGlobal(QPoint(0, self.height())))
self.popup.setFocus()
self.popup.show()
def eventFilter(self, target: QObject, event: QEvent):
if hasattr(self, 'popup') is False or target != self.popup:
if hasattr(self, "popup") is False or target != self.popup:
return False
if event.type() == QEvent.Type.MouseButtonPress:
@@ -123,8 +139,14 @@ class HuggingFaceSearchLineEdit(LineEdit):
self.popup.hide()
return True
if key in [Qt.Key.Key_Up, Qt.Key.Key_Down, Qt.Key.Key_Home, Qt.Key.Key_End, Qt.Key.Key_PageUp,
Qt.Key.Key_PageDown]:
if key in [
Qt.Key.Key_Up,
Qt.Key.Key_Down,
Qt.Key.Key_Home,
Qt.Key.Key_End,
Qt.Key.Key_PageUp,
Qt.Key.Key_PageDown,
]:
return False
self.setFocus()

View file

@@ -9,20 +9,25 @@ from buzz.transcriber import LANGUAGES
class LanguagesComboBox(QComboBox):
"""LanguagesComboBox displays a list of languages available to use with Whisper"""
# language is a language key from whisper.tokenizer.LANGUAGES or '' for "detect language"
languageChanged = pyqtSignal(str)
def __init__(self, default_language: Optional[str], parent: Optional[QWidget] = None) -> None:
def __init__(
self, default_language: Optional[str], parent: Optional[QWidget] = None
) -> None:
super().__init__(parent)
whisper_languages = sorted(
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES], key=lambda lang: lang[1])
self.languages = [('', _('Detect Language'))] + whisper_languages
[(lang, LANGUAGES[lang].title()) for lang in LANGUAGES],
key=lambda lang: lang[1],
)
self.languages = [("", _("Detect Language"))] + whisper_languages
self.addItems([lang[1] for lang in self.languages])
self.currentIndexChanged.connect(self.on_index_changed)
default_language_key = default_language if default_language != '' else None
default_language_key = default_language if default_language != "" else None
for i, lang in enumerate(self.languages):
if lang[0] == default_language_key:
self.setCurrentIndex(i)

View file

@@ -8,6 +8,7 @@ from buzz.transcriber import Task
class TasksComboBox(QComboBox):
"""TasksComboBox displays a list of tasks available to use with Whisper"""
taskChanged = pyqtSignal(Task)
def __init__(self, default_task: Task, parent: Optional[QWidget], *args) -> None:

View file

@@ -8,10 +8,12 @@ class TemperatureValidator(QValidator):
def __init__(self, parent: Optional[QObject] = ...) -> None:
super().__init__(parent)
def validate(self, text: str, cursor_position: int) -> Tuple['QValidator.State', str, int]:
def validate(
self, text: str, cursor_position: int
) -> Tuple["QValidator.State", str, int]:
try:
temp_strings = [temp.strip() for temp in text.split(',')]
if temp_strings[-1] == '':
temp_strings = [temp.strip() for temp in text.split(",")]
if temp_strings[-1] == "":
return QValidator.State.Intermediate, text, cursor_position
_ = [float(temp) for temp in temp_strings]
return QValidator.State.Acceptable, text, cursor_position

View file

@@ -10,8 +10,9 @@ from buzz.widgets.model_type_combo_box import ModelTypeComboBox
from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
from buzz.widgets.transcriber.advanced_settings_button import AdvancedSettingsButton
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
HuggingFaceSearchLineEdit
from buzz.widgets.transcriber.hugging_face_search_line_edit import (
HuggingFaceSearchLineEdit,
)
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
from buzz.widgets.transcriber.tasks_combo_box import TasksComboBox
@@ -21,64 +22,70 @@ class TranscriptionOptionsGroupBox(QGroupBox):
transcription_options_changed = pyqtSignal(TranscriptionOptions)
def __init__(
self,
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
model_types: Optional[List[ModelType]] = None,
parent: Optional[QWidget] = None):
super().__init__(title='', parent=parent)
self,
default_transcription_options: TranscriptionOptions = TranscriptionOptions(),
model_types: Optional[List[ModelType]] = None,
parent: Optional[QWidget] = None,
):
super().__init__(title="", parent=parent)
self.transcription_options = default_transcription_options
self.form_layout = QFormLayout(self)
self.tasks_combo_box = TasksComboBox(
default_task=self.transcription_options.task,
parent=self)
default_task=self.transcription_options.task, parent=self
)
self.tasks_combo_box.taskChanged.connect(self.on_task_changed)
self.languages_combo_box = LanguagesComboBox(
default_language=self.transcription_options.language,
parent=self)
self.languages_combo_box.languageChanged.connect(
self.on_language_changed)
default_language=self.transcription_options.language, parent=self
)
self.languages_combo_box.languageChanged.connect(self.on_language_changed)
self.advanced_settings_button = AdvancedSettingsButton(self)
self.advanced_settings_button.clicked.connect(
self.open_advanced_settings)
self.advanced_settings_button.clicked.connect(self.open_advanced_settings)
self.hugging_face_search_line_edit = HuggingFaceSearchLineEdit()
self.hugging_face_search_line_edit.model_selected.connect(
self.on_hugging_face_model_changed)
self.on_hugging_face_model_changed
)
self.model_type_combo_box = ModelTypeComboBox(model_types=model_types,
default_model=default_transcription_options.model.model_type,
parent=self)
self.model_type_combo_box = ModelTypeComboBox(
model_types=model_types,
default_model=default_transcription_options.model.model_type,
parent=self,
)
self.model_type_combo_box.changed.connect(self.on_model_type_changed)
self.whisper_model_size_combo_box = QComboBox(self)
self.whisper_model_size_combo_box.addItems(
[size.value.title() for size in WhisperModelSize])
[size.value.title() for size in WhisperModelSize]
)
if default_transcription_options.model.whisper_model_size is not None:
self.whisper_model_size_combo_box.setCurrentText(
default_transcription_options.model.whisper_model_size.value.title())
default_transcription_options.model.whisper_model_size.value.title()
)
self.whisper_model_size_combo_box.currentTextChanged.connect(
self.on_whisper_model_size_changed)
self.on_whisper_model_size_changed
)
self.openai_access_token_edit = OpenAIAPIKeyLineEdit(
key=default_transcription_options.openai_access_token,
parent=self)
key=default_transcription_options.openai_access_token, parent=self
)
self.openai_access_token_edit.key_changed.connect(
self.on_openai_access_token_edit_changed)
self.on_openai_access_token_edit_changed
)
self.form_layout.addRow(_('Model:'), self.model_type_combo_box)
self.form_layout.addRow('', self.whisper_model_size_combo_box)
self.form_layout.addRow('', self.hugging_face_search_line_edit)
self.form_layout.addRow('Access Token:', self.openai_access_token_edit)
self.form_layout.addRow(_('Task:'), self.tasks_combo_box)
self.form_layout.addRow(_('Language:'), self.languages_combo_box)
self.form_layout.addRow(_("Model:"), self.model_type_combo_box)
self.form_layout.addRow("", self.whisper_model_size_combo_box)
self.form_layout.addRow("", self.hugging_face_search_line_edit)
self.form_layout.addRow("Access Token:", self.openai_access_token_edit)
self.form_layout.addRow(_("Task:"), self.tasks_combo_box)
self.form_layout.addRow(_("Language:"), self.languages_combo_box)
self.reset_visible_rows()
self.form_layout.addRow('', self.advanced_settings_button)
self.form_layout.addRow("", self.advanced_settings_button)
self.setLayout(self.form_layout)
@@ -104,26 +111,33 @@ class TranscriptionOptionsGroupBox(QGroupBox):
def open_advanced_settings(self):
dialog = AdvancedSettingsDialog(
transcription_options=self.transcription_options, parent=self)
transcription_options=self.transcription_options, parent=self
)
dialog.transcription_options_changed.connect(
self.on_transcription_options_changed)
self.on_transcription_options_changed
)
dialog.exec()
def on_transcription_options_changed(self,
transcription_options: TranscriptionOptions):
def on_transcription_options_changed(
self, transcription_options: TranscriptionOptions
):
self.transcription_options = transcription_options
self.transcription_options_changed.emit(transcription_options)
def reset_visible_rows(self):
model_type = self.transcription_options.model.model_type
self.form_layout.setRowVisible(self.hugging_face_search_line_edit,
model_type == ModelType.HUGGING_FACE)
self.form_layout.setRowVisible(self.whisper_model_size_combo_box,
(model_type == ModelType.WHISPER) or (
model_type == ModelType.WHISPER_CPP) or (
model_type == ModelType.FASTER_WHISPER))
self.form_layout.setRowVisible(self.openai_access_token_edit,
model_type == ModelType.OPEN_AI_WHISPER_API)
self.form_layout.setRowVisible(
self.hugging_face_search_line_edit, model_type == ModelType.HUGGING_FACE
)
self.form_layout.setRowVisible(
self.whisper_model_size_combo_box,
(model_type == ModelType.WHISPER)
or (model_type == ModelType.WHISPER_CPP)
or (model_type == ModelType.FASTER_WHISPER),
)
self.form_layout.setRowVisible(
self.openai_access_token_edit, model_type == ModelType.OPEN_AI_WHISPER_API
)
def on_model_type_changed(self, model_type: ModelType):
self.transcription_options.model.model_type = model_type

View file

@@ -27,9 +27,10 @@ class TranscriptionSegmentsEditorWidget(QTableWidget):
self.setColumnCount(3)
self.verticalHeader().hide()
self.setHorizontalHeaderLabels([_('Start'), _('End'), _('Text')])
self.horizontalHeader().setSectionResizeMode(2,
QHeaderView.ResizeMode.ResizeToContents)
self.setHorizontalHeaderLabels([_("Start"), _("End"), _("Text")])
self.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeMode.ResizeToContents
)
self.setSelectionMode(QTableWidget.SelectionMode.SingleSelection)
for segment in segments:
@@ -38,12 +39,18 @@ class TranscriptionSegmentsEditorWidget(QTableWidget):
start_item = QTableWidgetItem(to_timestamp(segment.start))
start_item.setFlags(
start_item.flags() & ~Qt.ItemFlag.ItemIsEditable & ~Qt.ItemFlag.ItemIsSelectable)
start_item.flags()
& ~Qt.ItemFlag.ItemIsEditable
& ~Qt.ItemFlag.ItemIsSelectable
)
self.setItem(row_index, self.Column.START.value, start_item)
end_item = QTableWidgetItem(to_timestamp(segment.end))
end_item.setFlags(
end_item.flags() & ~Qt.ItemFlag.ItemIsEditable & ~Qt.ItemFlag.ItemIsSelectable)
end_item.flags()
& ~Qt.ItemFlag.ItemIsEditable
& ~Qt.ItemFlag.ItemIsSelectable
)
self.setItem(row_index, self.Column.END.value, end_item)
text_item = QTableWidgetItem(segment.text)
@@ -61,5 +68,4 @@ class TranscriptionSegmentsEditorWidget(QTableWidget):
def on_item_selection_changed(self):
ranges = self.selectedRanges()
self.segment_index_selected.emit(
ranges[0].topRow() if len(ranges) > 0 else -1)
self.segment_index_selected.emit(ranges[0].topRow() if len(ranges) > 0 else -1)

View file

@@ -30,13 +30,12 @@ class TranscriptionTasksTableWidget(QTableWidget):
self.setColumnHidden(0, True)
self.verticalHeader().hide()
self.setHorizontalHeaderLabels([_('ID'), _('File Name'), _('Status')])
self.setHorizontalHeaderLabels([_("ID"), _("File Name"), _("Status")])
self.setColumnWidth(self.Column.FILE_NAME.value, 250)
self.setColumnWidth(self.Column.STATUS.value, 180)
self.horizontalHeader().setMinimumSectionSize(180)
self.setSelectionBehavior(
QAbstractItemView.SelectionBehavior.SelectRows)
self.setSelectionBehavior(QAbstractItemView.SelectionBehavior.SelectRows)
def upsert_task(self, task: FileTranscriptionTask):
task_row_index = self.task_row_index(task.id)
@@ -45,21 +44,19 @@ class TranscriptionTasksTableWidget(QTableWidget):
row_index = self.rowCount() - 1
task_id_widget_item = QTableWidgetItem(str(task.id))
self.setItem(row_index, self.Column.TASK_ID.value,
task_id_widget_item)
self.setItem(row_index, self.Column.TASK_ID.value, task_id_widget_item)
file_name_widget_item = QTableWidgetItem(
os.path.basename(task.file_path))
file_name_widget_item = QTableWidgetItem(os.path.basename(task.file_path))
file_name_widget_item.setFlags(
file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
self.setItem(row_index, self.Column.FILE_NAME.value,
file_name_widget_item)
file_name_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
)
self.setItem(row_index, self.Column.FILE_NAME.value, file_name_widget_item)
status_widget_item = QTableWidgetItem(self.get_status_text(task))
status_widget_item.setFlags(
status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable)
self.setItem(row_index, self.Column.STATUS.value,
status_widget_item)
status_widget_item.flags() & ~Qt.ItemFlag.ItemIsEditable
)
self.setItem(row_index, self.Column.STATUS.value, status_widget_item)
else:
status_widget = self.item(task_row_index, self.Column.STATUS.value)
status_widget.setText(self.get_status_text(task))
@@ -67,31 +64,30 @@ class TranscriptionTasksTableWidget(QTableWidget):
@staticmethod
def format_timedelta(delta: datetime.timedelta):
mm, ss = divmod(delta.seconds, 60)
result = f'{ss}s'
result = f"{ss}s"
if mm == 0:
return result
hh, mm = divmod(mm, 60)
result = f'{mm}m {result}'
result = f"{mm}m {result}"
if hh == 0:
return result
return f'{hh}h {result}'
return f"{hh}h {result}"
@staticmethod
def get_status_text(task: FileTranscriptionTask):
if task.status == FileTranscriptionTask.Status.IN_PROGRESS:
return (
f'{_("In Progress")} ({task.fraction_completed :.0%})')
return f'{_("In Progress")} ({task.fraction_completed :.0%})'
elif task.status == FileTranscriptionTask.Status.COMPLETED:
status = _('Completed')
status = _("Completed")
if task.started_at is not None and task.completed_at is not None:
status += f" ({TranscriptionTasksTableWidget.format_timedelta(task.completed_at - task.started_at)})"
return status
elif task.status == FileTranscriptionTask.Status.FAILED:
return f'{_("Failed")} ({task.error})'
elif task.status == FileTranscriptionTask.Status.CANCELED:
return _('Canceled')
return _("Canceled")
elif task.status == FileTranscriptionTask.Status.QUEUED:
return _('Queued')
return _("Queued")
def clear_task(self, task_id: int):
task_row_index = self.task_row_index(task_id)
@@ -99,15 +95,20 @@ class TranscriptionTasksTableWidget(QTableWidget):
self.removeRow(task_row_index)
def task_row_index(self, task_id: int) -> int | None:
table_items_matching_task_id = [item for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly) if
item.column() == self.Column.TASK_ID.value]
table_items_matching_task_id = [
item
for item in self.findItems(str(task_id), Qt.MatchFlag.MatchExactly)
if item.column() == self.Column.TASK_ID.value
]
if len(table_items_matching_task_id) == 0:
return None
return table_items_matching_task_id[0].row()
@staticmethod
def find_task_id(index: QModelIndex):
sibling_index = index.siblingAtColumn(TranscriptionTasksTableWidget.Column.TASK_ID.value).data()
sibling_index = index.siblingAtColumn(
TranscriptionTasksTableWidget.Column.TASK_ID.value
).data()
return int(sibling_index) if sibling_index is not None else None
def keyPressEvent(self, event: QtGui.QKeyEvent) -> None:

View file

@@ -3,20 +3,32 @@ from typing import List, Optional
from PyQt6.QtCore import Qt, pyqtSignal
from PyQt6.QtGui import QUndoCommand, QUndoStack, QKeySequence, QAction
from PyQt6.QtWidgets import QWidget, QHBoxLayout, QMenu, QPushButton, QVBoxLayout, \
QFileDialog
from PyQt6.QtWidgets import (
QWidget,
QHBoxLayout,
QMenu,
QPushButton,
QVBoxLayout,
QFileDialog,
)
from buzz.action import Action
from buzz.assets import get_asset_path
from buzz.locale import _
from buzz.paths import file_path_as_title
from buzz.transcriber import FileTranscriptionTask, Segment, OutputFormat, \
get_default_output_file_path, write_output
from buzz.transcriber import (
FileTranscriptionTask,
Segment,
OutputFormat,
get_default_output_file_path,
write_output,
)
from buzz.widgets.audio_player import AudioPlayer
from buzz.widgets.icon import Icon
from buzz.widgets.toolbar import ToolBar
from buzz.widgets.transcription_segments_editor_widget import \
TranscriptionSegmentsEditorWidget
from buzz.widgets.transcription_segments_editor_widget import (
TranscriptionSegmentsEditorWidget,
)
class TranscriptionViewerWidget(QWidget):
@@ -24,9 +36,14 @@ class TranscriptionViewerWidget(QWidget):
task_changed = pyqtSignal()
class ChangeSegmentTextCommand(QUndoCommand):
def __init__(self, table_widget: TranscriptionSegmentsEditorWidget,
segments: List[Segment],
segment_index: int, segment_text: str, task_changed: pyqtSignal):
def __init__(
self,
table_widget: TranscriptionSegmentsEditorWidget,
segments: List[Segment],
segment_index: int,
segment_text: str,
task_changed: pyqtSignal,
):
super().__init__()
self.table_widget = table_widget
@@ -52,10 +69,11 @@ class TranscriptionViewerWidget(QWidget):
self.task_changed.emit()
def __init__(
self, transcription_task: FileTranscriptionTask,
open_transcription_output=True,
parent: Optional['QWidget'] = None,
flags: Qt.WindowType = Qt.WindowType.Widget,
self,
transcription_task: FileTranscriptionTask,
open_transcription_output=True,
parent: Optional["QWidget"] = None,
flags: Qt.WindowType = Qt.WindowType.Widget,
) -> None:
super().__init__(parent, flags)
self.transcription_task = transcription_task
@@ -71,20 +89,23 @@ class TranscriptionViewerWidget(QWidget):
undo_action = self.undo_stack.createUndoAction(self, _("Undo"))
undo_action.setShortcuts(QKeySequence.StandardKey.Undo)
undo_action.setIcon(
Icon(get_asset_path('assets/undo_FILL0_wght700_GRAD0_opsz48.svg'), self))
Icon(get_asset_path("assets/undo_FILL0_wght700_GRAD0_opsz48.svg"), self)
)
undo_action.setToolTip(Action.get_tooltip(undo_action))
redo_action = self.undo_stack.createRedoAction(self, _("Redo"))
redo_action.setShortcuts(QKeySequence.StandardKey.Redo)
redo_action.setIcon(
Icon(get_asset_path('assets/redo_FILL0_wght700_GRAD0_opsz48.svg'), self))
Icon(get_asset_path("assets/redo_FILL0_wght700_GRAD0_opsz48.svg"), self)
)
redo_action.setToolTip(Action.get_tooltip(redo_action))
toolbar = ToolBar()
toolbar.addActions([undo_action, redo_action])
self.table_widget = TranscriptionSegmentsEditorWidget(
segments=transcription_task.segments, parent=self)
segments=transcription_task.segments, parent=self
)
self.table_widget.segment_text_changed.connect(self.on_segment_text_changed)
self.table_widget.segment_index_selected.connect(self.on_segment_index_selected)
@@ -96,14 +117,16 @@ class TranscriptionViewerWidget(QWidget):
buttons_layout.addStretch()
export_button_menu = QMenu()
actions = [QAction(text=output_format.value.upper(), parent=self)
for output_format in OutputFormat]
actions = [
QAction(text=output_format.value.upper(), parent=self)
for output_format in OutputFormat
]
export_button_menu.addActions(actions)
export_button_menu.triggered.connect(self.on_menu_triggered)
export_button = QPushButton(self)
export_button.setText(_('Export'))
export_button.setText(_("Export"))
export_button.setMenu(export_button_menu)
buttons_layout.addWidget(export_button)
@@ -120,11 +143,14 @@ class TranscriptionViewerWidget(QWidget):
def on_segment_text_changed(self, event: tuple):
segment_index, segment_text = event
self.undo_stack.push(
self.ChangeSegmentTextCommand(table_widget=self.table_widget,
segments=self.transcription_task.segments,
segment_index=segment_index,
segment_text=segment_text,
task_changed=self.task_changed))
self.ChangeSegmentTextCommand(
table_widget=self.table_widget,
segments=self.transcription_task.segments,
segment_index=segment_index,
segment_text=segment_text,
task_changed=self.task_changed,
)
)
def on_segment_index_selected(self, index: int):
selected_segment = self.transcription_task.segments[index]
@@ -134,15 +160,22 @@ class TranscriptionViewerWidget(QWidget):
def on_menu_triggered(self, action: QAction):
output_format = OutputFormat[action.text()]
default_path = get_default_output_file_path(task=self.transcription_task,
output_format=output_format)
default_path = get_default_output_file_path(
task=self.transcription_task, output_format=output_format
)
(output_file_path, nil) = QFileDialog.getSaveFileName(self, _('Save File'),
default_path,
_('Text files') + f' (*.{output_format.value})')
(output_file_path, nil) = QFileDialog.getSaveFileName(
self,
_("Save File"),
default_path,
_("Text files") + f" (*.{output_format.value})",
)
if output_file_path == '':
if output_file_path == "":
return
write_output(path=output_file_path, segments=self.transcription_task.segments,
output_format=output_format)
write_output(
path=output_file_path,
segments=self.transcription_task.segments,
output_format=output_format,
)

poetry.lock (generated)
View file

@@ -266,6 +266,54 @@ files = [
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
]
[[package]]
name = "black"
version = "23.7.0"
description = "The uncompromising code formatter."
category = "dev"
optional = false
python-versions = ">=3.8"
files = [
{file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"},
{file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"},
{file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"},
{file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"},
{file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"},
{file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"},
{file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"},
{file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"},
{file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"},
{file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"},
{file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"},
{file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"},
{file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"},
{file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"},
{file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"},
{file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"},
{file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"},
{file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"},
{file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"},
{file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"},
{file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"},
{file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"},
]
[package.dependencies]
aiohttp = {version = ">=3.7.4", optional = true, markers = "extra == \"d\""}
click = ">=8.0.0"
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=3.10.0.0", markers = "python_version < \"3.10\""}
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "certifi"
version = "2023.5.7"
@@ -452,6 +500,21 @@ files = [
{file = "charset_normalizer-3.1.0-py3-none-any.whl", hash = "sha256:3d9098b479e78c85080c98e1e35ff40b4a31d8953102bb0fd7d1b6f8a2111a3d"},
]
[[package]]
name = "click"
version = "8.1.7"
description = "Composable command line interface toolkit"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "click-8.1.7-py3-none-any.whl", hash = "sha256:ae74fb96c20a0277a1d615f1e4d73c8414f5a98db8b799a7931d1582f3390c28"},
{file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]]
name = "cmake"
version = "3.26.4"
@@ -1534,6 +1597,18 @@ files = [
{file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"},
]
[[package]]
name = "pathspec"
version = "0.11.2"
description = "Utility library for gitignore style pattern matching of file paths."
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "pathspec-0.11.2-py3-none-any.whl", hash = "sha256:1d6ed233af05e679efb96b1851550ea95bbb64b7c490b0f5aa52996c11e92a20"},
{file = "pathspec-0.11.2.tar.gz", hash = "sha256:e0d8d0ac2f12da61956eb2306b69f9469b42f4deb0f3cb6ed47b9cce9996ced3"},
]
[[package]]
name = "pefile"
version = "2023.2.7"
@@ -2669,4 +2744,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more
[metadata]
lock-version = "2.0"
python-versions = ">=3.9.13,<3.11"
content-hash = "ceb6ce6c7083882f1499bd36f5e98f6aa1e0a872d8268ccbda91d67ee81fdd1e"
content-hash = "fe7fae59602bd0ecdceafbfe274f6f36f0cb489b67bfc7d4bfae4998dbbe672a"

View file

@@ -33,6 +33,7 @@ pytest-xvfb = "^2.0.0"
pylint = "^2.15.5"
pre-commit = "^2.20.0"
pytest-benchmark = "^4.0.0"
black = {extras = ["d"], version = "^23.7.0"}
[tool.poetry.group.build.dependencies]
ctypesgen = "^1.1.1"
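This hunk is where Black enters the project: it adds black, with the "d" extra that pulls in the blackd HTTP daemon and its aiohttp dependency, to the pyproject.toml dev-dependency group, matching the new black, click, and pathspec entries in the poetry.lock diff above. As a minimal sketch of the rewrite applied throughout this commit (assuming only Black's public format_str and Mode API, nothing repository-specific), the quote normalization can be reproduced like this:

# Illustrative sketch, not part of this commit: reproduces Black's
# string normalization on a line taken from the diff above.
import black

src = "layout.addRow('Group', model_type_combo_box)\n"

# format_str applies Black's full style: double quotes,
# 88-character line length, magic trailing commas, and so on.
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # layout.addRow("Group", model_type_combo_box)

In practice a single invocation such as "poetry run black ." (or blackd for editor integration) applies the same rules across the whole tree, which is what produced the mechanical quote and wrapping changes in every file of this commit.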

File diff suppressed because one or more lines are too long

View file

@@ -1,16 +1,31 @@
from buzz.cache import TasksCache
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
TranscriptionOptions)
from buzz.transcriber import (
FileTranscriptionOptions,
FileTranscriptionTask,
TranscriptionOptions,
)
class TestTasksCache:
def test_should_save_and_load(self, tmp_path):
cache = TasksCache(cache_dir=str(tmp_path))
tasks = [FileTranscriptionTask(file_path='1.mp3', transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=['1.mp3']),
model_path=''),
FileTranscriptionTask(file_path='2.mp3', transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=['2.mp3']),
model_path='')]
tasks = [
FileTranscriptionTask(
file_path="1.mp3",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=["1.mp3"]
),
model_path="",
),
FileTranscriptionTask(
file_path="2.mp3",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=["2.mp3"]
),
model_path="",
),
]
cache.save(tasks)
assert cache.load() == tasks

View file

@@ -8,91 +8,100 @@ import pytest
import sounddevice
from PyQt6.QtCore import QSize, Qt
from PyQt6.QtGui import QValidator, QKeyEvent
from PyQt6.QtWidgets import QPushButton, QToolBar, QTableWidget, QApplication, QMessageBox
from PyQt6.QtWidgets import (
QPushButton,
QToolBar,
QTableWidget,
QApplication,
QMessageBox,
)
from _pytest.fixtures import SubRequest
from pytestqt.qtbot import QtBot
from buzz.__version__ import VERSION
from buzz.cache import TasksCache
from buzz.gui import (AudioDevicesComboBox, MainWindow,
RecordingTranscriberWidget)
from buzz.gui import AudioDevicesComboBox, MainWindow, RecordingTranscriberWidget
from buzz.widgets.transcriber.advanced_settings_dialog import AdvancedSettingsDialog
from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
from buzz.widgets.transcriber.hugging_face_search_line_edit import \
HuggingFaceSearchLineEdit
from buzz.widgets.transcriber.hugging_face_search_line_edit import (
HuggingFaceSearchLineEdit,
)
from buzz.widgets.transcriber.languages_combo_box import LanguagesComboBox
from buzz.widgets.transcriber.temperature_validator import TemperatureValidator
from buzz.widgets.about_dialog import AboutDialog
from buzz.model_loader import ModelType
from buzz.settings.settings import Settings
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask,
TranscriptionOptions)
from buzz.widgets.transcriber.transcription_options_group_box import \
TranscriptionOptionsGroupBox
from buzz.transcriber import (
FileTranscriptionOptions,
FileTranscriptionTask,
TranscriptionOptions,
)
from buzz.widgets.transcriber.transcription_options_group_box import (
TranscriptionOptionsGroupBox,
)
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
from tests.mock_sounddevice import MockInputStream, mock_query_devices
from .mock_qt import MockNetworkAccessManager, MockNetworkReply
if platform.system() == 'Linux':
multiprocessing.set_start_method('spawn')
if platform.system() == "Linux":
multiprocessing.set_start_method("spawn")
@pytest.fixture(scope='module', autouse=True)
@pytest.fixture(scope="module", autouse=True)
def audio_setup():
with patch('sounddevice.query_devices') as query_devices_mock, \
patch('sounddevice.InputStream', side_effect=MockInputStream), \
patch('sounddevice.check_input_settings'):
with patch("sounddevice.query_devices") as query_devices_mock, patch(
"sounddevice.InputStream", side_effect=MockInputStream
), patch("sounddevice.check_input_settings"):
query_devices_mock.return_value = mock_query_devices
sounddevice.default.device = 3, 4
yield
class TestLanguagesComboBox:
def test_should_show_sorted_whisper_languages(self, qtbot):
languages_combox_box = LanguagesComboBox('en')
languages_combox_box = LanguagesComboBox("en")
qtbot.add_widget(languages_combox_box)
assert languages_combox_box.itemText(0) == 'Detect Language'
assert languages_combox_box.itemText(10) == 'Belarusian'
assert languages_combox_box.itemText(20) == 'Dutch'
assert languages_combox_box.itemText(30) == 'Gujarati'
assert languages_combox_box.itemText(40) == 'Japanese'
assert languages_combox_box.itemText(50) == 'Lithuanian'
assert languages_combox_box.itemText(0) == "Detect Language"
assert languages_combox_box.itemText(10) == "Belarusian"
assert languages_combox_box.itemText(20) == "Dutch"
assert languages_combox_box.itemText(30) == "Gujarati"
assert languages_combox_box.itemText(40) == "Japanese"
assert languages_combox_box.itemText(50) == "Lithuanian"
def test_should_select_en_as_default_language(self, qtbot):
languages_combox_box = LanguagesComboBox('en')
languages_combox_box = LanguagesComboBox("en")
qtbot.add_widget(languages_combox_box)
assert languages_combox_box.currentText() == 'English'
assert languages_combox_box.currentText() == "English"
def test_should_select_detect_language_as_default(self, qtbot):
languages_combo_box = LanguagesComboBox(None)
qtbot.add_widget(languages_combo_box)
assert languages_combo_box.currentText() == 'Detect Language'
assert languages_combo_box.currentText() == "Detect Language"
class TestAudioDevicesComboBox:
def test_get_devices(self):
audio_devices_combo_box = AudioDevicesComboBox()
assert audio_devices_combo_box.itemText(0) == 'Background Music'
assert audio_devices_combo_box.itemText(1) == 'Background Music (UI Sounds)'
assert audio_devices_combo_box.itemText(2) == 'BlackHole 2ch'
assert audio_devices_combo_box.itemText(3) == 'MacBook Pro Microphone'
assert audio_devices_combo_box.itemText(4) == 'Null Audio Device'
assert audio_devices_combo_box.itemText(0) == "Background Music"
assert audio_devices_combo_box.itemText(1) == "Background Music (UI Sounds)"
assert audio_devices_combo_box.itemText(2) == "BlackHole 2ch"
assert audio_devices_combo_box.itemText(3) == "MacBook Pro Microphone"
assert audio_devices_combo_box.itemText(4) == "Null Audio Device"
assert audio_devices_combo_box.currentText() == 'MacBook Pro Microphone'
assert audio_devices_combo_box.currentText() == "MacBook Pro Microphone"
def test_select_default_mic_when_no_default(self):
sounddevice.default.device = -1, 1
audio_devices_combo_box = AudioDevicesComboBox()
assert audio_devices_combo_box.currentText() == 'Background Music'
assert audio_devices_combo_box.currentText() == "Background Music"
@pytest.fixture()
def tasks_cache(tmp_path, request: SubRequest):
cache = TasksCache(cache_dir=str(tmp_path))
if hasattr(request, 'param'):
if hasattr(request, "param"):
tasks: List[FileTranscriptionTask] = request.param
cache.save(tasks)
yield cache
@@ -100,28 +109,40 @@ def tasks_cache(tmp_path, request: SubRequest):
def get_test_asset(filename: str):
return os.path.join(os.path.dirname(__file__), '../testdata/', filename)
return os.path.join(os.path.dirname(__file__), "../testdata/", filename)
mock_tasks = [
FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='',
status=FileTranscriptionTask.Status.COMPLETED),
FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='',
status=FileTranscriptionTask.Status.CANCELED),
FileTranscriptionTask(file_path='', transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]), model_path='',
status=FileTranscriptionTask.Status.FAILED, error='Error'),
FileTranscriptionTask(
file_path="",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
status=FileTranscriptionTask.Status.COMPLETED,
),
FileTranscriptionTask(
file_path="",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
status=FileTranscriptionTask.Status.CANCELED,
),
FileTranscriptionTask(
file_path="",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(file_paths=[]),
model_path="",
status=FileTranscriptionTask.Status.FAILED,
error="Error",
),
]
class TestMainWindow:
def test_should_set_window_title_and_icon(self, qtbot):
window = MainWindow()
qtbot.add_widget(window)
assert window.windowTitle() == 'Buzz'
assert window.windowTitle() == "Buzz"
assert window.windowIcon().pixmap(QSize(64, 64)).isNull() is False
window.close()
@@ -132,13 +153,18 @@ class TestMainWindow:
self._start_new_transcription(window)
open_transcript_action = self._get_toolbar_action(window, 'Open Transcript')
open_transcript_action = self._get_toolbar_action(window, "Open Transcript")
assert open_transcript_action.isEnabled() is False
table_widget: QTableWidget = window.findChild(QTableWidget)
qtbot.wait_until(self._assert_task_status(table_widget, 0, 'Completed'), timeout=2 * 60 * 1000)
qtbot.wait_until(
self._assert_task_status(table_widget, 0, "Completed"),
timeout=2 * 60 * 1000,
)
table_widget.setCurrentIndex(table_widget.indexFromItem(table_widget.item(0, 1)))
table_widget.setCurrentIndex(
table_widget.indexFromItem(table_widget.item(0, 1))
)
assert open_transcript_action.isEnabled()
# @pytest.mark.skip(reason='Timing out or crashing')
@@ -152,8 +178,8 @@ class TestMainWindow:
def assert_task_in_progress():
assert table_widget.rowCount() > 0
assert table_widget.item(0, 1).text() == 'whisper-french.mp3'
assert 'In Progress' in table_widget.item(0, 2).text()
assert table_widget.item(0, 1).text() == "whisper-french.mp3"
assert "In Progress" in table_widget.item(0, 2).text()
qtbot.wait_until(assert_task_in_progress, timeout=2 * 60 * 1000)
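qtbot.wait_until, used throughout these tests, comes from pytest-qt: it invokes the callback repeatedly, treating an AssertionError as "not ready yet", until the callback passes or the timeout expires. A minimal sketch under that assumption:

def test_wait_until_retries(qtbot):
    state = {"calls": 0}

    def eventually_passes():
        state["calls"] += 1
        # Fails on the first two invocations, passes on the third.
        assert state["calls"] >= 3

    qtbot.wait_until(eventually_passes, timeout=1_000)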
@@ -161,7 +187,9 @@ class TestMainWindow:
table_widget.selectRow(0)
window.toolbar.stop_transcription_action.trigger()
qtbot.wait_until(self._assert_task_status(table_widget, 0, 'Canceled'), timeout=60 * 1000)
qtbot.wait_until(
self._assert_task_status(table_widget, 0, "Canceled"), timeout=60 * 1000
)
table_widget.selectRow(0)
assert window.toolbar.stop_transcription_action.isEnabled() is False
@@ -169,7 +197,7 @@ class TestMainWindow:
window.close()
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_load_tasks_from_cache(self, qtbot, tasks_cache):
window = MainWindow(tasks_cache=tasks_cache)
qtbot.add_widget(window)
@@ -177,43 +205,47 @@ class TestMainWindow:
table_widget: QTableWidget = window.findChild(QTableWidget)
assert table_widget.rowCount() == 3
assert table_widget.item(0, 2).text() == 'Completed'
assert table_widget.item(0, 2).text() == "Completed"
table_widget.selectRow(0)
assert window.toolbar.open_transcript_action.isEnabled()
assert table_widget.item(1, 2).text() == 'Canceled'
assert table_widget.item(1, 2).text() == "Canceled"
table_widget.selectRow(1)
assert window.toolbar.open_transcript_action.isEnabled() is False
assert table_widget.item(2, 2).text() == 'Failed (Error)'
assert table_widget.item(2, 2).text() == "Failed (Error)"
table_widget.selectRow(2)
assert window.toolbar.open_transcript_action.isEnabled() is False
window.close()
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_clear_history_with_rows_selected(self, qtbot, tasks_cache):
window = MainWindow(tasks_cache=tasks_cache)
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget.selectAll()
with patch('PyQt6.QtWidgets.QMessageBox.question') as question_message_box_mock:
with patch("PyQt6.QtWidgets.QMessageBox.question") as question_message_box_mock:
question_message_box_mock.return_value = QMessageBox.StandardButton.Yes
window.toolbar.clear_history_action.trigger()
assert table_widget.rowCount() == 0
window.close()
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
def test_should_have_clear_history_action_disabled_with_no_rows_selected(self, qtbot, tasks_cache):
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_have_clear_history_action_disabled_with_no_rows_selected(
self, qtbot, tasks_cache
):
window = MainWindow(tasks_cache=tasks_cache)
qtbot.add_widget(window)
assert window.toolbar.clear_history_action.isEnabled() is False
window.close()
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
def test_should_open_transcription_viewer_when_menu_action_is_clicked(self, qtbot, tasks_cache):
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_open_transcription_viewer_when_menu_action_is_clicked(
self, qtbot, tasks_cache
):
window = MainWindow(tasks_cache=tasks_cache)
qtbot.add_widget(window)
@@ -228,23 +260,33 @@ class TestMainWindow:
window.close()
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
def test_should_open_transcription_viewer_when_return_clicked(self, qtbot, tasks_cache):
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_open_transcription_viewer_when_return_clicked(
self, qtbot, tasks_cache
):
window = MainWindow(tasks_cache=tasks_cache)
qtbot.add_widget(window)
table_widget: QTableWidget = window.findChild(QTableWidget)
table_widget.selectRow(0)
table_widget.keyPressEvent(
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Return, Qt.KeyboardModifier.NoModifier, '\r'))
QKeyEvent(
QKeyEvent.Type.KeyPress,
Qt.Key.Key_Return,
Qt.KeyboardModifier.NoModifier,
"\r",
)
)
transcription_viewer = window.findChild(TranscriptionViewerWidget)
assert transcription_viewer is not None
window.close()
@pytest.mark.parametrize('tasks_cache', [mock_tasks], indirect=True)
def test_should_have_open_transcript_action_disabled_with_no_rows_selected(self, qtbot, tasks_cache):
@pytest.mark.parametrize("tasks_cache", [mock_tasks], indirect=True)
def test_should_have_open_transcript_action_disabled_with_no_rows_selected(
self, qtbot, tasks_cache
):
window = MainWindow(tasks_cache=tasks_cache)
qtbot.add_widget(window)
@@ -253,20 +295,31 @@ class TestMainWindow:
@staticmethod
def _start_new_transcription(window: MainWindow):
with patch('PyQt6.QtWidgets.QFileDialog.getOpenFileNames') as open_file_names_mock:
open_file_names_mock.return_value = ([get_test_asset('whisper-french.mp3')], '')
new_transcription_action = TestMainWindow._get_toolbar_action(window, 'New Transcription')
with patch(
"PyQt6.QtWidgets.QFileDialog.getOpenFileNames"
) as open_file_names_mock:
open_file_names_mock.return_value = (
[get_test_asset("whisper-french.mp3")],
"",
)
new_transcription_action = TestMainWindow._get_toolbar_action(
window, "New Transcription"
)
new_transcription_action.trigger()
file_transcriber_widget: FileTranscriberWidget = window.findChild(FileTranscriberWidget)
file_transcriber_widget: FileTranscriberWidget = window.findChild(
FileTranscriberWidget
)
run_button: QPushButton = file_transcriber_widget.findChild(QPushButton)
run_button.click()
@staticmethod
def _assert_task_status(table_widget: QTableWidget, row_index: int, expected_status: str):
def _assert_task_status(
table_widget: QTableWidget, row_index: int, expected_status: str
):
def assert_task_canceled():
assert table_widget.rowCount() > 0
assert table_widget.item(row_index, 1).text() == 'whisper-french.mp3'
assert table_widget.item(row_index, 1).text() == "whisper-french.mp3"
assert expected_status in table_widget.item(row_index, 2).text()
return assert_task_canceled
@@ -277,7 +330,7 @@ class TestMainWindow:
return [action for action in toolbar.actions() if action.text() == text][0]
@pytest.fixture(scope='module', autouse=True)
@pytest.fixture(scope="module", autouse=True)
def clear_settings():
settings = Settings()
settings.clear()
@@ -285,7 +338,7 @@ def clear_settings():
class TestAboutDialog:
def test_should_check_for_updates(self, qtbot: QtBot):
reply = MockNetworkReply(data={'name': 'v' + VERSION})
reply = MockNetworkReply(data={"name": "v" + VERSION})
manager = MockNetworkAccessManager(reply=reply)
dialog = AboutDialog(network_access_manager=manager)
qtbot.add_widget(dialog)
@@ -296,41 +349,45 @@ class TestAboutDialog:
with qtbot.wait_signal(dialog.network_access_manager.finished):
dialog.check_updates_button.click()
mock_message_box_information.assert_called_with(dialog, '', "You're up to date!")
mock_message_box_information.assert_called_with(
dialog, "", "You're up to date!"
)
class TestAdvancedSettingsDialog:
def test_should_update_advanced_settings(self, qtbot: QtBot):
dialog = AdvancedSettingsDialog(
transcription_options=TranscriptionOptions(temperature=(0.0, 0.8), initial_prompt='prompt'))
transcription_options=TranscriptionOptions(
temperature=(0.0, 0.8), initial_prompt="prompt"
)
)
qtbot.add_widget(dialog)
transcription_options_mock = Mock()
dialog.transcription_options_changed.connect(
transcription_options_mock)
dialog.transcription_options_changed.connect(transcription_options_mock)
assert dialog.windowTitle() == 'Advanced Settings'
assert dialog.temperature_line_edit.text() == '0.0, 0.8'
assert dialog.initial_prompt_text_edit.toPlainText() == 'prompt'
assert dialog.windowTitle() == "Advanced Settings"
assert dialog.temperature_line_edit.text() == "0.0, 0.8"
assert dialog.initial_prompt_text_edit.toPlainText() == "prompt"
dialog.temperature_line_edit.setText('0.0, 0.8, 1.0')
dialog.initial_prompt_text_edit.setPlainText('new prompt')
dialog.temperature_line_edit.setText("0.0, 0.8, 1.0")
dialog.initial_prompt_text_edit.setPlainText("new prompt")
assert transcription_options_mock.call_args[0][0].temperature == (
0.0, 0.8, 1.0)
assert transcription_options_mock.call_args[0][0].initial_prompt == 'new prompt'
assert transcription_options_mock.call_args[0][0].temperature == (0.0, 0.8, 1.0)
assert transcription_options_mock.call_args[0][0].initial_prompt == "new prompt"
class TestTemperatureValidator:
validator = TemperatureValidator(None)
@pytest.mark.parametrize(
'text,state',
"text,state",
[
('0.0,0.5,1.0', QValidator.State.Acceptable),
('0.0,0.5,', QValidator.State.Intermediate),
('0.0,0.5,p', QValidator.State.Invalid),
])
("0.0,0.5,1.0", QValidator.State.Acceptable),
("0.0,0.5,", QValidator.State.Intermediate),
("0.0,0.5,p", QValidator.State.Invalid),
],
)
def test_should_validate_temperature(self, text: str, state: QValidator.State):
assert self.validator.validate(text, 0)[0] == state
@@ -339,9 +396,9 @@ class TestRecordingTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = RecordingTranscriberWidget()
qtbot.add_widget(widget)
assert widget.windowTitle() == 'Live Recording'
assert widget.windowTitle() == "Live Recording"
@pytest.mark.skip(reason='Seg faults on CI')
@pytest.mark.skip(reason="Seg faults on CI")
def test_should_transcribe(self, qtbot):
widget = RecordingTranscriberWidget()
qtbot.add_widget(widget)
@@ -355,31 +412,37 @@ class TestRecordingTranscriberWidget:
with qtbot.wait_signal(widget.transcription_thread.finished, timeout=60 * 1000):
widget.stop_recording()
assert 'Welcome to Passe' in widget.text_box.toPlainText()
assert "Welcome to Passe" in widget.text_box.toPlainText()
class TestHuggingFaceSearchLineEdit:
def test_should_update_selected_model_on_type(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
widget = HuggingFaceSearchLineEdit(
network_access_manager=self.network_access_manager()
)
qtbot.add_widget(widget)
mock_model_selected = Mock()
widget.model_selected.connect(mock_model_selected)
self._set_text_and_wait_response(qtbot, widget)
mock_model_selected.assert_called_with('openai/whisper-tiny')
mock_model_selected.assert_called_with("openai/whisper-tiny")
def test_should_show_list_of_models(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
widget = HuggingFaceSearchLineEdit(
network_access_manager=self.network_access_manager()
)
qtbot.add_widget(widget)
self._set_text_and_wait_response(qtbot, widget)
assert widget.popup.count() > 0
assert 'openai/whisper-tiny' in widget.popup.item(0).text()
assert "openai/whisper-tiny" in widget.popup.item(0).text()
def test_should_select_model_from_list(self, qtbot: QtBot):
widget = HuggingFaceSearchLineEdit(network_access_manager=self.network_access_manager())
widget = HuggingFaceSearchLineEdit(
network_access_manager=self.network_access_manager()
)
qtbot.add_widget(widget)
mock_model_selected = Mock()
@@ -388,23 +451,35 @@ class TestHuggingFaceSearchLineEdit:
self._set_text_and_wait_response(qtbot, widget)
# press down arrow and enter to select next item
QApplication.sendEvent(widget.popup,
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier))
QApplication.sendEvent(widget.popup,
QKeyEvent(QKeyEvent.Type.KeyPress, Qt.Key.Key_Enter, Qt.KeyboardModifier.NoModifier))
QApplication.sendEvent(
widget.popup,
QKeyEvent(
QKeyEvent.Type.KeyPress, Qt.Key.Key_Down, Qt.KeyboardModifier.NoModifier
),
)
QApplication.sendEvent(
widget.popup,
QKeyEvent(
QKeyEvent.Type.KeyPress,
Qt.Key.Key_Enter,
Qt.KeyboardModifier.NoModifier,
),
)
mock_model_selected.assert_called_with('openai/whisper-tiny.en')
mock_model_selected.assert_called_with("openai/whisper-tiny.en")
@staticmethod
def network_access_manager():
reply = MockNetworkReply(data=[{'id': 'openai/whisper-tiny'}, {'id': 'openai/whisper-tiny.en'}])
reply = MockNetworkReply(
data=[{"id": "openai/whisper-tiny"}, {"id": "openai/whisper-tiny.en"}]
)
return MockNetworkAccessManager(reply=reply)
@staticmethod
def _set_text_and_wait_response(qtbot: QtBot, widget: HuggingFaceSearchLineEdit):
with qtbot.wait_signal(widget.network_manager.finished):
widget.setText('openai/whisper-tiny')
widget.textEdited.emit('openai/whisper-tiny')
widget.setText("openai/whisper-tiny")
widget.textEdited.emit("openai/whisper-tiny")
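One subtlety in the helper above: Qt emits textEdited only for user edits, not for programmatic changes, which is why the test calls setText() and then fires the signal by hand. A small sketch of the distinction (assumes a running QApplication, e.g. the qapp fixture from pytest-qt):

from PyQt6.QtWidgets import QLineEdit


def test_set_text_does_not_emit_text_edited(qapp):
    line_edit = QLineEdit()
    events = []
    line_edit.textEdited.connect(lambda text: events.append(("edited", text)))
    line_edit.textChanged.connect(lambda text: events.append(("changed", text)))

    # setText() triggers textChanged but not textEdited; only user
    # keystrokes (or a manual emit, as above) produce textEdited.
    line_edit.setText("hello")
    assert events == [("changed", "hello")]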
class TestTranscriptionOptionsGroupBox:
@@ -417,5 +492,7 @@ class TestTranscriptionOptionsGroupBox:
widget.model_type_combo_box.setCurrentIndex(1)
transcription_options: TranscriptionOptions = mock_transcription_options_changed.call_args[0][0]
transcription_options: TranscriptionOptions = (
mock_transcription_options_changed.call_args[0][0]
)
assert transcription_options.model.model_type == ModelType.WHISPER_CPP

View file

@@ -1,4 +1,3 @@
import json
from typing import Optional
@@ -10,10 +9,10 @@ class MockNetworkReply(QNetworkReply):
def __init__(self, data: object, _: Optional[QObject] = None) -> None:
self.data = data
def readAll(self) -> 'QByteArray':
return QByteArray(json.dumps(self.data).encode('utf-8'))
def readAll(self) -> "QByteArray":
return QByteArray(json.dumps(self.data).encode("utf-8"))
def error(self) -> 'QNetworkReply.NetworkError':
def error(self) -> "QNetworkReply.NetworkError":
return QNetworkReply.NetworkError.NoError
@@ -21,10 +20,12 @@ class MockNetworkAccessManager(QNetworkAccessManager):
finished = pyqtSignal(object)
reply: MockNetworkReply
def __init__(self, reply: MockNetworkReply, parent: Optional[QObject] = None) -> None:
def __init__(
self, reply: MockNetworkReply, parent: Optional[QObject] = None
) -> None:
super().__init__(parent)
self.reply = reply
def get(self, _: 'QNetworkRequest') -> 'QNetworkReply':
def get(self, _: "QNetworkRequest") -> "QNetworkReply":
self.finished.emit(self.reply)
return self.reply
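Note that this mock emits finished synchronously from inside get(), whereas the real QNetworkAccessManager completes requests asynchronously; that is what lets the tests above assert immediately after triggering a request. A sketch of that behaviour, assuming it sits alongside the mock classes so no import is needed (qapp again from pytest-qt):

def test_mock_manager_emits_synchronously(qapp):
    reply = MockNetworkReply(data={"name": "v1.2.3"})
    manager = MockNetworkAccessManager(reply=reply)

    received = []
    manager.finished.connect(received.append)

    # finished fires before get() returns, so no event-loop spin or
    # qtbot.wait_signal is needed here.
    manager.get(None)
    assert received == [reply]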

View file

@@ -9,34 +9,90 @@ import sounddevice
import whisper
mock_query_devices = [
{'name': 'Background Music', 'index': 0, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
'default_low_input_latency': 0.01,
'default_low_output_latency': 0.008, 'default_high_input_latency': 0.1, 'default_high_output_latency': 0.064,
'default_samplerate': 8000.0},
{'name': 'Background Music (UI Sounds)', 'index': 1, 'hostapi': 0, 'max_input_channels': 2,
'max_output_channels': 2, 'default_low_input_latency': 0.01,
'default_low_output_latency': 0.008, 'default_high_input_latency': 0.1, 'default_high_output_latency': 0.064,
'default_samplerate': 8000.0},
{'name': 'BlackHole 2ch', 'index': 2, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
'default_low_input_latency': 0.01,
'default_low_output_latency': 0.0013333333333333333, 'default_high_input_latency': 0.1,
'default_high_output_latency': 0.010666666666666666, 'default_samplerate': 48000.0},
{'name': 'MacBook Pro Microphone', 'index': 3, 'hostapi': 0, 'max_input_channels': 1, 'max_output_channels': 0,
'default_low_input_latency': 0.034520833333333334,
'default_low_output_latency': 0.01, 'default_high_input_latency': 0.043854166666666666,
'default_high_output_latency': 0.1, 'default_samplerate': 48000.0},
{'name': 'MacBook Pro Speakers', 'index': 4, 'hostapi': 0, 'max_input_channels': 0, 'max_output_channels': 2,
'default_low_input_latency': 0.01,
'default_low_output_latency': 0.0070416666666666666, 'default_high_input_latency': 0.1,
'default_high_output_latency': 0.016375, 'default_samplerate': 48000.0},
{'name': 'Null Audio Device', 'index': 5, 'hostapi': 0, 'max_input_channels': 2, 'max_output_channels': 2,
'default_low_input_latency': 0.01,
'default_low_output_latency': 0.0014512471655328798, 'default_high_input_latency': 0.1,
'default_high_output_latency': 0.011609977324263039, 'default_samplerate': 44100.0},
{'name': 'Multi-Output Device', 'index': 6, 'hostapi': 0, 'max_input_channels': 0, 'max_output_channels': 2,
'default_low_input_latency': 0.01,
'default_low_output_latency': 0.0033333333333333335, 'default_high_input_latency': 0.1,
'default_high_output_latency': 0.012666666666666666, 'default_samplerate': 48000.0},
{
"name": "Background Music",
"index": 0,
"hostapi": 0,
"max_input_channels": 2,
"max_output_channels": 2,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.008,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.064,
"default_samplerate": 8000.0,
},
{
"name": "Background Music (UI Sounds)",
"index": 1,
"hostapi": 0,
"max_input_channels": 2,
"max_output_channels": 2,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.008,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.064,
"default_samplerate": 8000.0,
},
{
"name": "BlackHole 2ch",
"index": 2,
"hostapi": 0,
"max_input_channels": 2,
"max_output_channels": 2,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.0013333333333333333,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.010666666666666666,
"default_samplerate": 48000.0,
},
{
"name": "MacBook Pro Microphone",
"index": 3,
"hostapi": 0,
"max_input_channels": 1,
"max_output_channels": 0,
"default_low_input_latency": 0.034520833333333334,
"default_low_output_latency": 0.01,
"default_high_input_latency": 0.043854166666666666,
"default_high_output_latency": 0.1,
"default_samplerate": 48000.0,
},
{
"name": "MacBook Pro Speakers",
"index": 4,
"hostapi": 0,
"max_input_channels": 0,
"max_output_channels": 2,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.0070416666666666666,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.016375,
"default_samplerate": 48000.0,
},
{
"name": "Null Audio Device",
"index": 5,
"hostapi": 0,
"max_input_channels": 2,
"max_output_channels": 2,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.0014512471655328798,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.011609977324263039,
"default_samplerate": 44100.0,
},
{
"name": "Multi-Output Device",
"index": 6,
"hostapi": 0,
"max_input_channels": 0,
"max_output_channels": 2,
"default_low_input_latency": 0.01,
"default_low_output_latency": 0.0033333333333333335,
"default_high_input_latency": 0.1,
"default_high_output_latency": 0.012666666666666666,
"default_samplerate": 48000.0,
},
]
@@ -44,7 +100,12 @@ class MockInputStream(MagicMock):
running = False
thread: Thread
def __init__(self, callback: Callable[[np.ndarray, int, Any, sounddevice.CallbackFlags], None], *args, **kwargs):
def __init__(
self,
callback: Callable[[np.ndarray, int, Any, sounddevice.CallbackFlags], None],
*args,
**kwargs
):
super().__init__(spec=sounddevice.InputStream)
self.thread = Thread(target=self.target)
self.callback = callback
@@ -54,7 +115,9 @@ class MockInputStream(MagicMock):
def target(self):
sample_rate = whisper.audio.SAMPLE_RATE
file_path = os.path.join(os.path.dirname(__file__), '../testdata/whisper-french.mp3')
file_path = os.path.join(
os.path.dirname(__file__), "../testdata/whisper-french.mp3"
)
audio = whisper.load_audio(file_path, sr=sample_rate)
chunk_duration_secs = 1
@@ -65,7 +128,7 @@ class MockInputStream(MagicMock):
while self.running:
time.sleep(chunk_duration_secs)
chunk = audio[seek:seek + num_samples_in_chunk]
chunk = audio[seek : seek + num_samples_in_chunk]
self.callback(chunk, 0, None, sounddevice.CallbackFlags())
seek += num_samples_in_chunk
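The reflowed slice above (audio[seek : seek + num_samples_in_chunk]) is black applying PEP 8's rule that a slice colon acts like a binary operator: simple bounds stay tight, but once an expression appears on either side, the colon gets symmetric spacing. Illustrative only:

items = list(range(10))
start = 2

head = items[:5]  # simple bounds stay compact under black
window = items[start : start + 3]  # complex bounds gain spaces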

View file

@@ -1,14 +1,13 @@
from buzz.model_loader import TranscriptionModel, ModelDownloader
def get_model_path(transcription_model: TranscriptionModel) -> str:
path = transcription_model.get_local_model_path()
if path is not None:
return path
model_loader = ModelDownloader(model=transcription_model)
model_path = ''
model_path = ""
def on_load_model(path: str):
nonlocal model_path

View file

@@ -4,20 +4,32 @@ from unittest.mock import Mock
import pytest
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, Task, WhisperCppFileTranscriber,
TranscriptionOptions, WhisperFileTranscriber, FileTranscriber)
from buzz.transcriber import (
FileTranscriptionOptions,
FileTranscriptionTask,
Task,
WhisperCppFileTranscriber,
TranscriptionOptions,
WhisperFileTranscriber,
FileTranscriber,
)
from tests.model_loader import get_model_path
def get_task(model: TranscriptionModel):
file_transcription_options = FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3'])
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
word_level_timings=False,
model=model)
file_paths=["testdata/whisper-french.mp3"]
)
transcription_options = TranscriptionOptions(
language="fr", task=Task.TRANSCRIBE, word_level_timings=False, model=model
)
model_path = get_model_path(transcription_options.model)
return FileTranscriptionTask(file_path='testdata/audio-long.mp3', transcription_options=transcription_options,
file_transcription_options=file_transcription_options, model_path=model_path)
return FileTranscriptionTask(
file_path="testdata/audio-long.mp3",
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
model_path=model_path,
)
def transcribe(qtbot, transcriber: FileTranscriber):
@@ -31,24 +43,53 @@ def transcribe(qtbot, transcriber: FileTranscriber):
@pytest.mark.parametrize(
'transcriber',
"transcriber",
[
pytest.param(
WhisperCppFileTranscriber(task=(get_task(
TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY)))),
id="Whisper.cpp - Tiny"),
pytest.param(
WhisperFileTranscriber(task=(get_task(
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY)))),
id="Whisper - Tiny"),
pytest.param(
WhisperFileTranscriber(task=(get_task(
TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY)))),
id="Faster Whisper - Tiny",
marks=pytest.mark.skipif(platform.system() == 'Darwin',
reason='Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087')
WhisperCppFileTranscriber(
task=(
get_task(
TranscriptionModel(
model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY,
)
)
)
),
id="Whisper.cpp - Tiny",
),
])
pytest.param(
WhisperFileTranscriber(
task=(
get_task(
TranscriptionModel(
model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY,
)
)
)
),
id="Whisper - Tiny",
),
pytest.param(
WhisperFileTranscriber(
task=(
get_task(
TranscriptionModel(
model_type=ModelType.FASTER_WHISPER,
whisper_model_size=WhisperModelSize.TINY,
)
)
)
),
id="Faster Whisper - Tiny",
marks=pytest.mark.skipif(
platform.system() == "Darwin",
reason="Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087",
),
),
],
)
def test_should_transcribe_and_benchmark(qtbot, benchmark, transcriber):
segments = benchmark(transcribe, qtbot, transcriber)
assert len(segments) > 0
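The benchmark fixture here comes from pytest-benchmark: benchmark(fn, *args) times repeated calls of fn and returns fn's return value, so the timed result can still be asserted on, as above. A minimal sketch under that assumption:

def fib(n: int) -> int:
    return n if n < 2 else fib(n - 1) + fib(n - 2)


def test_fib_benchmark(benchmark):
    # benchmark() reports timing statistics and hands back the
    # function's result for ordinary assertions.
    result = benchmark(fib, 10)
    assert result == 55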

View file

@@ -11,26 +11,42 @@ from PyQt6.QtCore import QThread
from pytestqt.qtbot import QtBot
from buzz.model_loader import WhisperModelSize, ModelType, TranscriptionModel
from buzz.transcriber import (FileTranscriptionOptions, FileTranscriptionTask, OutputFormat, Segment, Task, WhisperCpp, WhisperCppFileTranscriber,
WhisperFileTranscriber,
get_default_output_file_path, to_timestamp,
whisper_cpp_params, write_output, TranscriptionOptions)
from buzz.transcriber import (
FileTranscriptionOptions,
FileTranscriptionTask,
OutputFormat,
Segment,
Task,
WhisperCpp,
WhisperCppFileTranscriber,
WhisperFileTranscriber,
get_default_output_file_path,
to_timestamp,
whisper_cpp_params,
write_output,
TranscriptionOptions,
)
from buzz.recording_transcriber import RecordingTranscriber
from tests.mock_sounddevice import MockInputStream
from tests.model_loader import get_model_path
class TestRecordingTranscriber:
@pytest.mark.skip(reason='Hanging')
@pytest.mark.skip(reason="Hanging")
def test_should_transcribe(self, qtbot):
thread = QThread()
transcription_model = TranscriptionModel(model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY)
transcription_model = TranscriptionModel(
model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY
)
transcriber = RecordingTranscriber(transcription_options=TranscriptionOptions(
model=transcription_model, language='fr', task=Task.TRANSCRIBE),
input_device_index=0, sample_rate=16_000)
transcriber = RecordingTranscriber(
transcription_options=TranscriptionOptions(
model=transcription_model, language="fr", task=Task.TRANSCRIBE
),
input_device_index=0,
sample_rate=16_000,
)
transcriber.moveToThread(thread)
thread.finished.connect(thread.deleteLater)
@@ -41,39 +57,55 @@ class TestRecordingTranscriber:
transcriber.finished.connect(thread.quit)
transcriber.finished.connect(transcriber.deleteLater)
with patch('sounddevice.InputStream', side_effect=MockInputStream), patch(
'sounddevice.check_input_settings'), qtbot.wait_signal(transcriber.transcription, timeout=60 * 1000):
with patch("sounddevice.InputStream", side_effect=MockInputStream), patch(
"sounddevice.check_input_settings"
), qtbot.wait_signal(transcriber.transcription, timeout=60 * 1000):
thread.start()
with qtbot.wait_signal(thread.finished, timeout=60 * 1000):
transcriber.stop_recording()
text = mock_transcription.call_args[0][0]
assert 'Bienvenue dans Passe' in text
assert "Bienvenue dans Passe" in text
class TestWhisperCppFileTranscriber:
@pytest.mark.parametrize(
'word_level_timings,expected_segments',
"word_level_timings,expected_segments",
[
(False, [Segment(0, 6560,
'Bienvenue dans Passe-Relle. Un podcast pensé pour')]),
(True, [Segment(30, 330, 'Bien'), Segment(330, 740, 'venue')])
])
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]):
(
False,
[Segment(0, 6560, "Bienvenue dans Passe-Relle. Un podcast pensé pour")],
),
(True, [Segment(30, 330, "Bien"), Segment(330, 740, "venue")]),
],
)
def test_transcribe(
self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment]
):
file_transcription_options = FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3'])
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
word_level_timings=word_level_timings,
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY))
file_paths=["testdata/whisper-french.mp3"]
)
transcription_options = TranscriptionOptions(
language="fr",
task=Task.TRANSCRIBE,
word_level_timings=word_level_timings,
model=TranscriptionModel(
model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY,
),
)
model_path = get_model_path(transcription_options.model)
transcriber = WhisperCppFileTranscriber(
task=FileTranscriptionTask(file_path='testdata/whisper-french.mp3',
transcription_options=transcription_options,
file_transcription_options=file_transcription_options, model_path=model_path))
mock_progress = Mock(side_effect=lambda value: print('progress: ', value))
task=FileTranscriptionTask(
file_path="testdata/whisper-french.mp3",
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
model_path=model_path,
)
)
mock_progress = Mock(side_effect=lambda value: print("progress: ", value))
mock_completed = Mock()
transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
@@ -81,7 +113,11 @@ class TestWhisperCppFileTranscriber:
transcriber.run()
mock_progress.assert_called()
segments = [segment for segment in mock_completed.call_args[0][0] if len(segment.text) > 0]
segments = [
segment
for segment in mock_completed.call_args[0][0]
if len(segment.text) > 0
]
for i, expected_segment in enumerate(expected_segments):
assert expected_segment.start == segments[i].start
assert expected_segment.end == segments[i].end
@@ -90,82 +126,164 @@ class TestWhisperCppFileTranscriber:
class TestWhisperFileTranscriber:
@pytest.mark.parametrize(
'output_format,expected_file_path,default_output_file_name',
"output_format,expected_file_path,default_output_file_name",
[
(OutputFormat.SRT, '/a/b/c-translate--Whisper-tiny.srt', '{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}'),
])
def test_default_output_file2(self, output_format: OutputFormat, expected_file_path: str, default_output_file_name: str):
(
OutputFormat.SRT,
"/a/b/c-translate--Whisper-tiny.srt",
"{{ input_file_name }}-{{ task }}-{{ language }}-{{ model_type }}-{{ model_size }}",
),
],
)
def test_default_output_file2(
self,
output_format: OutputFormat,
expected_file_path: str,
default_output_file_name: str,
):
file_path = get_default_output_file_path(
task=FileTranscriptionTask(
file_path='/a/b/c.mp4',
file_path="/a/b/c.mp4",
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name=default_output_file_name),
model_path=''),
output_format=output_format)
file_transcription_options=FileTranscriptionOptions(
file_paths=[], default_output_file_name=default_output_file_name
),
model_path="",
),
output_format=output_format,
)
assert file_path == expected_file_path
def test_default_output_file(self):
srt = get_default_output_file_path(
task=FileTranscriptionTask(
file_path='/a/b/c.mp4',
file_path="/a/b/c.mp4",
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
model_path=''),
output_format=OutputFormat.TXT)
assert srt.startswith('/a/b/c (Translated on ')
assert srt.endswith('.txt')
file_transcription_options=FileTranscriptionOptions(
file_paths=[],
default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})",
),
model_path="",
),
output_format=OutputFormat.TXT,
)
assert srt.startswith("/a/b/c (Translated on ")
assert srt.endswith(".txt")
srt = get_default_output_file_path(
task=FileTranscriptionTask(
file_path='/a/b/c.mp4',
file_path="/a/b/c.mp4",
transcription_options=TranscriptionOptions(task=Task.TRANSLATE),
file_transcription_options=FileTranscriptionOptions(file_paths=[], default_output_file_name='{{ input_file_name }} (Translated on {{ date_time }})'),
model_path=''),
output_format=OutputFormat.SRT)
assert srt.startswith('/a/b/c (Translated on ')
assert srt.endswith('.srt')
file_transcription_options=FileTranscriptionOptions(
file_paths=[],
default_output_file_name="{{ input_file_name }} (Translated on {{ date_time }})",
),
model_path="",
),
output_format=OutputFormat.SRT,
)
assert srt.startswith("/a/b/c (Translated on ")
assert srt.endswith(".srt")
@pytest.mark.parametrize(
'word_level_timings,expected_segments,model,check_progress',
"word_level_timings,expected_segments,model,check_progress",
[
(False, [Segment(0, 6560,
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances')],
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
(True, [Segment(40, 299, ' Bien'), Segment(299, 329, 'venue dans')],
TranscriptionModel(model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY), True),
(False, [Segment(0, 8517,
' Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances '
'et des apprenances de français.')],
TranscriptionModel(model_type=ModelType.HUGGING_FACE,
hugging_face_model_id='openai/whisper-tiny'), False),
(
False,
[
Segment(
0,
6560,
" Bienvenue dans Passe-Relle. Un podcast pensé pour évêiller la curiosité des apprenances",
)
],
TranscriptionModel(
model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY,
),
True,
),
(
True,
[Segment(40, 299, " Bien"), Segment(299, 329, "venue dans")],
TranscriptionModel(
model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY,
),
True,
),
(
False,
[
Segment(
0,
8517,
" Bienvenue dans Passe-Relle. Un podcast pensé pour évêyer la curiosité des apprenances "
"et des apprenances de français.",
)
],
TranscriptionModel(
model_type=ModelType.HUGGING_FACE,
hugging_face_model_id="openai/whisper-tiny",
),
False,
),
pytest.param(
False, [Segment(start=0, end=8400,
text=' Bienvenue dans Passrel, un podcast pensé pour éveiller la curiosité des apprenances et des apprenances de français.')],
TranscriptionModel(model_type=ModelType.FASTER_WHISPER, whisper_model_size=WhisperModelSize.TINY), True,
marks=pytest.mark.skipif(platform.system() == 'Darwin',
reason='Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087')
)
])
def test_transcribe(self, qtbot: QtBot, word_level_timings: bool, expected_segments: List[Segment],
model: TranscriptionModel, check_progress):
False,
[
Segment(
start=0,
end=8400,
text=" Bienvenue dans Passrel, un podcast pensé pour éveiller la curiosité des apprenances et des apprenances de français.",
)
],
TranscriptionModel(
model_type=ModelType.FASTER_WHISPER,
whisper_model_size=WhisperModelSize.TINY,
),
True,
marks=pytest.mark.skipif(
platform.system() == "Darwin",
reason="Error with libiomp5 already initialized on GH action runner: https://github.com/chidiwilliams/buzz/actions/runs/4657331262/jobs/8241832087",
),
),
],
)
def test_transcribe(
self,
qtbot: QtBot,
word_level_timings: bool,
expected_segments: List[Segment],
model: TranscriptionModel,
check_progress,
):
mock_progress = Mock()
mock_completed = Mock()
transcription_options = TranscriptionOptions(language='fr', task=Task.TRANSCRIBE,
word_level_timings=word_level_timings,
model=model)
transcription_options = TranscriptionOptions(
language="fr",
task=Task.TRANSCRIBE,
word_level_timings=word_level_timings,
model=model,
)
model_path = get_model_path(transcription_options.model)
file_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'testdata/whisper-french.mp3'))
file_transcription_options = FileTranscriptionOptions(
file_paths=[file_path])
file_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "testdata/whisper-french.mp3")
)
file_transcription_options = FileTranscriptionOptions(file_paths=[file_path])
transcriber = WhisperFileTranscriber(
task=FileTranscriptionTask(transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
file_path=file_path, model_path=model_path))
task=FileTranscriptionTask(
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
file_path=file_path,
model_path=model_path,
)
)
transcriber.progress.connect(mock_progress)
transcriber.completed.connect(mock_completed)
with qtbot.wait_signal(transcriber.progress, timeout=10 * 6000), qtbot.wait_signal(transcriber.completed,
timeout=10 * 6000):
with qtbot.wait_signal(
transcriber.progress, timeout=10 * 6000
), qtbot.wait_signal(transcriber.completed, timeout=10 * 6000):
transcriber.run()
# Skip checking progress...
@@ -182,26 +300,37 @@ class TestWhisperFileTranscriber:
mock_completed.assert_called()
segments = mock_completed.call_args[0][0]
assert len(segments) >= len(expected_segments)
for (i, expected_segment) in enumerate(expected_segments):
for i, expected_segment in enumerate(expected_segments):
assert segments[i] == expected_segment
@pytest.mark.skip()
def test_transcribe_stop(self):
output_file_path = os.path.join(tempfile.gettempdir(), 'whisper.txt')
output_file_path = os.path.join(tempfile.gettempdir(), "whisper.txt")
if os.path.exists(output_file_path):
os.remove(output_file_path)
file_transcription_options = FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3'])
file_paths=["testdata/whisper-french.mp3"]
)
transcription_options = TranscriptionOptions(
language='fr', task=Task.TRANSCRIBE, word_level_timings=False,
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
language="fr",
task=Task.TRANSCRIBE,
word_level_timings=False,
model=TranscriptionModel(
model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY,
),
)
model_path = get_model_path(transcription_options.model)
transcriber = WhisperFileTranscriber(
task=FileTranscriptionTask(model_path=model_path, transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
file_path='testdata/whisper-french.mp3'))
task=FileTranscriptionTask(
model_path=model_path,
transcription_options=transcription_options,
file_transcription_options=file_transcription_options,
file_path="testdata/whisper-french.mp3",
)
)
transcriber.run()
time.sleep(1)
transcriber.stop()
@@ -212,40 +341,54 @@ class TestWhisperFileTranscriber:
class TestToTimestamp:
def test_to_timestamp(self):
assert to_timestamp(0) == '00:00:00.000'
assert to_timestamp(123456789) == '34:17:36.789'
assert to_timestamp(0) == "00:00:00.000"
assert to_timestamp(123456789) == "34:17:36.789"
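The second expectation is easy to verify by hand: 123,456,789 ms is 34 h (122,400,000 ms) plus 17 min (1,020,000 ms) plus 36,789 ms, i.e. 34:17:36.789. The same arithmetic in a few lines:

ms = 123_456_789
hours, rem = divmod(ms, 3_600_000)
minutes, rem = divmod(rem, 60_000)
seconds, millis = divmod(rem, 1_000)
assert (hours, minutes, seconds, millis) == (34, 17, 36, 789)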
class TestWhisperCpp:
def test_transcribe(self):
transcription_options = TranscriptionOptions(
model=TranscriptionModel(model_type=ModelType.WHISPER_CPP, whisper_model_size=WhisperModelSize.TINY))
model=TranscriptionModel(
model_type=ModelType.WHISPER_CPP,
whisper_model_size=WhisperModelSize.TINY,
)
)
model_path = get_model_path(transcription_options.model)
whisper_cpp = WhisperCpp(model=model_path)
params = whisper_cpp_params(
language='fr', task=Task.TRANSCRIBE, word_level_timings=False)
language="fr", task=Task.TRANSCRIBE, word_level_timings=False
)
result = whisper_cpp.transcribe(
audio='testdata/whisper-french.mp3', params=params)
audio="testdata/whisper-french.mp3", params=params
)
assert 'Bienvenue dans Passe' in result['text']
assert "Bienvenue dans Passe" in result["text"]
@pytest.mark.parametrize(
'output_format,output_text',
"output_format,output_text",
[
(OutputFormat.TXT, 'Bien\nvenue dans\n'),
(OutputFormat.TXT, "Bien\nvenue dans\n"),
(
OutputFormat.SRT,
'1\n00:00:00,040 --> 00:00:00,299\nBien\n\n2\n00:00:00,299 --> 00:00:00,329\nvenue dans\n\n'),
(OutputFormat.VTT,
'WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n'),
])
def test_write_output(tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str):
output_file_path = tmp_path / 'whisper.txt'
segments = [Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')]
OutputFormat.SRT,
"1\n00:00:00,040 --> 00:00:00,299\nBien\n\n2\n00:00:00,299 --> 00:00:00,329\nvenue dans\n\n",
),
(
OutputFormat.VTT,
"WEBVTT\n\n00:00:00.040 --> 00:00:00.299\nBien\n\n00:00:00.299 --> 00:00:00.329\nvenue dans\n\n",
),
],
)
def test_write_output(
tmp_path: pathlib.Path, output_format: OutputFormat, output_text: str
):
output_file_path = tmp_path / "whisper.txt"
segments = [Segment(40, 299, "Bien"), Segment(299, 329, "venue dans")]
write_output(path=str(output_file_path), segments=segments, output_format=output_format)
write_output(
path=str(output_file_path), segments=segments, output_format=output_format
)
output_file = open(output_file_path, 'r', encoding='utf-8')
output_file = open(output_file_path, "r", encoding="utf-8")
assert output_text == output_file.read()
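A small aside on the last two lines: the file handle opened here is never closed, which is harmless in a short-lived test process but easy to tighten with a context manager. An equivalent variant using the same names:

with open(output_file_path, "r", encoding="utf-8") as output_file:
    assert output_text == output_file.read()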

View file

@@ -3,8 +3,9 @@ from buzz.transformers_whisper import load_model
class TestTransformersWhisper:
def test_should_transcribe(self):
model = load_model('openai/whisper-tiny')
model = load_model("openai/whisper-tiny")
result = model.transcribe(
audio='testdata/whisper-french.mp3', language='fr', task='transcribe')
audio="testdata/whisper-french.mp3", language="fr", task="transcribe"
)
assert 'Bienvenue dans Passe' in result['text']
assert "Bienvenue dans Passe" in result["text"]

View file

@@ -9,13 +9,19 @@ from buzz.widgets.transcriber.file_transcriber_widget import FileTranscriberWidget
class TestFileTranscriberWidget:
def test_should_set_window_title(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
file_paths=["testdata/whisper-french.mp3"],
default_output_file_name="",
parent=None,
)
qtbot.add_widget(widget)
assert widget.windowTitle() == 'whisper-french.mp3'
assert widget.windowTitle() == "whisper-french.mp3"
def test_should_emit_triggered_event(self, qtbot: QtBot):
widget = FileTranscriberWidget(
file_paths=['testdata/whisper-french.mp3'], default_output_file_name='', parent=None)
file_paths=["testdata/whisper-french.mp3"],
default_output_file_name="",
parent=None,
)
qtbot.add_widget(widget)
mock_triggered = Mock()
@@ -24,9 +30,11 @@ class TestFileTranscriberWidget:
with qtbot.wait_signal(widget.triggered, timeout=30 * 1000):
qtbot.mouseClick(widget.run_button, Qt.MouseButton.LeftButton)
transcription_options, file_transcription_options, model_path = mock_triggered.call_args[
0][0]
(
transcription_options,
file_transcription_options,
model_path,
) = mock_triggered.call_args[0][0]
assert transcription_options.language is None
assert file_transcription_options.file_paths == [
'testdata/whisper-french.mp3']
assert file_transcription_options.file_paths == ["testdata/whisper-french.mp3"]
assert len(model_path) > 0

View file

@@ -8,7 +8,7 @@ class TestModelDownloadProgressDialog:
def test_should_show_dialog(self, qtbot):
dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
qtbot.add_widget(dialog)
assert dialog.labelText() == 'Downloading model (0%)'
assert dialog.labelText() == "Downloading model (0%)"
def test_should_update_label_on_progress(self, qtbot):
dialog = ModelDownloadProgressDialog(model_type=ModelType.WHISPER, parent=None)
@@ -16,12 +16,10 @@ class TestModelDownloadProgressDialog:
dialog.set_value(0.0)
dialog.set_value(0.01)
assert dialog.labelText().startswith(
'Downloading model (1%')
assert dialog.labelText().startswith("Downloading model (1%")
dialog.set_value(0.1)
assert dialog.labelText().startswith(
'Downloading model (10%')
assert dialog.labelText().startswith("Downloading model (10%")
# Other windows should not be processing while models are being downloaded
def test_should_be_an_application_modal(self, qtbot):

View file

@@ -7,8 +7,8 @@ class TestModelTypeComboBox:
qtbot.add_widget(widget)
assert widget.count() == 5
assert widget.itemText(0) == 'Whisper'
assert widget.itemText(1) == 'Whisper.cpp'
assert widget.itemText(2) == 'Hugging Face'
assert widget.itemText(3) == 'Faster Whisper'
assert widget.itemText(4) == 'OpenAI Whisper API'
assert widget.itemText(0) == "Whisper"
assert widget.itemText(1) == "Whisper.cpp"
assert widget.itemText(2) == "Hugging Face"
assert widget.itemText(3) == "Faster Whisper"
assert widget.itemText(4) == "OpenAI Whisper API"

View file

@@ -3,14 +3,14 @@ from buzz.widgets.openai_api_key_line_edit import OpenAIAPIKeyLineEdit
class TestOpenAIAPIKeyLineEdit:
def test_should_emit_key_changed(self, qtbot):
line_edit = OpenAIAPIKeyLineEdit(key='')
line_edit = OpenAIAPIKeyLineEdit(key="")
qtbot.add_widget(line_edit)
with qtbot.wait_signal(line_edit.key_changed):
line_edit.setText('abcdefg')
line_edit.setText("abcdefg")
def test_should_toggle_visibility(self, qtbot):
line_edit = OpenAIAPIKeyLineEdit(key='')
line_edit = OpenAIAPIKeyLineEdit(key="")
qtbot.add_widget(line_edit)
assert line_edit.echoMode() == OpenAIAPIKeyLineEdit.EchoMode.Password

View file

@@ -4,32 +4,35 @@ import pytest
from PyQt6.QtWidgets import QPushButton, QMessageBox, QLineEdit
from buzz.store.keyring_store import KeyringStore
from buzz.widgets.preferences_dialog.general_preferences_widget import \
GeneralPreferencesWidget
from buzz.widgets.preferences_dialog.general_preferences_widget import (
GeneralPreferencesWidget,
)
class TestGeneralPreferencesWidget:
def test_should_disable_test_button_if_no_api_key(self, qtbot):
widget = GeneralPreferencesWidget(keyring_store=self.get_keyring_store(''),
default_export_file_name='')
widget = GeneralPreferencesWidget(
keyring_store=self.get_keyring_store(""), default_export_file_name=""
)
qtbot.add_widget(widget)
test_button = widget.findChild(QPushButton)
assert isinstance(test_button, QPushButton)
assert test_button.text() == 'Test'
assert test_button.text() == "Test"
assert not test_button.isEnabled()
line_edit = widget.findChild(QLineEdit)
assert isinstance(line_edit, QLineEdit)
line_edit.setText('123')
line_edit.setText("123")
assert test_button.isEnabled()
def test_should_test_openai_api_key(self, qtbot):
widget = GeneralPreferencesWidget(
keyring_store=self.get_keyring_store('wrong-api-key'),
default_export_file_name='')
keyring_store=self.get_keyring_store("wrong-api-key"),
default_export_file_name="",
)
qtbot.add_widget(widget)
test_button = widget.findChild(QPushButton)
@@ -42,9 +45,11 @@ class TestGeneralPreferencesWidget:
def mock_called():
mock.assert_called()
assert mock.call_args[0][1] == 'OpenAI API Key Test'
assert mock.call_args[0][
2] == 'Incorrect API key provided: wrong-ap*-key. You can find your API key at https://platform.openai.com/account/api-keys.'
assert mock.call_args[0][1] == "OpenAI API Key Test"
assert (
mock.call_args[0][2]
== "Incorrect API key provided: wrong-ap*-key. You can find your API key at https://platform.openai.com/account/api-keys."
)
qtbot.waitUntil(mock_called)

View file

@@ -5,16 +5,20 @@ from PyQt6.QtCore import Qt
from PyQt6.QtWidgets import QComboBox, QPushButton
from pytestqt.qtbot import QtBot
from buzz.model_loader import get_whisper_file_path, WhisperModelSize, \
TranscriptionModel, \
ModelType
from buzz.widgets.preferences_dialog.models_preferences_widget import \
ModelsPreferencesWidget
from buzz.model_loader import (
get_whisper_file_path,
WhisperModelSize,
TranscriptionModel,
ModelType,
)
from buzz.widgets.preferences_dialog.models_preferences_widget import (
ModelsPreferencesWidget,
)
from tests.model_loader import get_model_path
class TestModelsPreferencesWidget:
@pytest.fixture(scope='class')
@pytest.fixture(scope="class")
def clear_model_cache(self):
file_path = get_whisper_file_path(size=WhisperModelSize.TINY)
if os.path.isfile(file_path):
@@ -25,10 +29,10 @@ class TestModelsPreferencesWidget:
qtbot.add_widget(widget)
first_item = widget.model_list_widget.topLevelItem(0)
assert first_item.text(0) == 'Downloaded'
assert first_item.text(0) == "Downloaded"
second_item = widget.model_list_widget.topLevelItem(1)
assert second_item.text(0) == 'Available for Download'
assert second_item.text(0) == "Available for Download"
def test_should_change_model_type(self, qtbot):
widget = ModelsPreferencesWidget()
@@ -36,36 +40,38 @@ class TestModelsPreferencesWidget:
combo_box = widget.findChild(QComboBox)
assert isinstance(combo_box, QComboBox)
combo_box.setCurrentText('Faster Whisper')
combo_box.setCurrentText("Faster Whisper")
first_item = widget.model_list_widget.topLevelItem(0)
assert first_item.text(0) == 'Downloaded'
assert first_item.text(0) == "Downloaded"
second_item = widget.model_list_widget.topLevelItem(1)
assert second_item.text(0) == 'Available for Download'
assert second_item.text(0) == "Available for Download"
def test_should_download_model(self, qtbot: QtBot, clear_model_cache):
# make progress dialog non-modal to unblock qtbot.wait_until
widget = ModelsPreferencesWidget(
progress_dialog_modality=Qt.WindowModality.NonModal)
progress_dialog_modality=Qt.WindowModality.NonModal
)
qtbot.add_widget(widget)
model = TranscriptionModel(model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY)
model = TranscriptionModel(
model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY
)
assert model.get_local_model_path() is None
available_item = widget.model_list_widget.topLevelItem(1)
assert available_item.text(0) == 'Available for Download'
assert available_item.text(0) == "Available for Download"
tiny_item = available_item.child(0)
assert tiny_item.text(0) == 'Tiny'
assert tiny_item.text(0) == "Tiny"
tiny_item.setSelected(True)
download_button = widget.findChild(QPushButton, 'DownloadButton')
download_button = widget.findChild(QPushButton, "DownloadButton")
assert isinstance(download_button, QPushButton)
assert download_button.text() == 'Download'
assert download_button.text() == "Download"
download_button.click()
def downloaded_model():
@@ -73,22 +79,26 @@ class TestModelsPreferencesWidget:
_downloaded_item = widget.model_list_widget.topLevelItem(0)
assert _downloaded_item.childCount() > 0
assert _downloaded_item.child(0).text(0) == 'Tiny'
assert _downloaded_item.child(0).text(0) == "Tiny"
_available_item = widget.model_list_widget.topLevelItem(1)
assert _available_item.childCount() == 0 or _available_item.child(0).text(
0) != 'Tiny'
assert (
_available_item.childCount() == 0
or _available_item.child(0).text(0) != "Tiny"
)
# model file exists
assert os.path.isfile(get_whisper_file_path(size=model.whisper_model_size))
qtbot.wait_until(callback=downloaded_model, timeout=60_000)
@pytest.fixture(scope='class')
@pytest.fixture(scope="class")
def whisper_tiny_model_path(self) -> str:
return get_model_path(transcription_model=TranscriptionModel(
model_type=ModelType.WHISPER,
whisper_model_size=WhisperModelSize.TINY))
return get_model_path(
transcription_model=TranscriptionModel(
model_type=ModelType.WHISPER, whisper_model_size=WhisperModelSize.TINY
)
)
def test_should_show_downloaded_model(self, qtbot, whisper_tiny_model_path):
widget = ModelsPreferencesWidget()
@@ -96,15 +106,16 @@ class TestModelsPreferencesWidget:
qtbot.add_widget(widget)
available_item = widget.model_list_widget.topLevelItem(0)
assert available_item.text(0) == 'Downloaded'
assert available_item.text(0) == "Downloaded"
tiny_item = available_item.child(0)
assert tiny_item.text(0) == 'Tiny'
assert tiny_item.text(0) == "Tiny"
tiny_item.setSelected(True)
delete_button = widget.findChild(QPushButton, 'DeleteButton')
delete_button = widget.findChild(QPushButton, "DeleteButton")
assert delete_button.isVisible()
show_file_location_button = widget.findChild(QPushButton,
'ShowFileLocationButton')
show_file_location_button = widget.findChild(
QPushButton, "ShowFileLocationButton"
)
assert show_file_location_button.isVisible()

View file

@@ -6,14 +6,14 @@ from buzz.widgets.preferences_dialog.preferences_dialog import PreferencesDialog
class TestPreferencesDialog:
def test_create(self, qtbot: QtBot):
dialog = PreferencesDialog(shortcuts={}, default_export_file_name='')
dialog = PreferencesDialog(shortcuts={}, default_export_file_name="")
qtbot.add_widget(dialog)
assert dialog.windowTitle() == 'Preferences'
assert dialog.windowTitle() == "Preferences"
tab_widget = dialog.findChild(QTabWidget)
assert isinstance(tab_widget, QTabWidget)
assert tab_widget.count() == 3
assert tab_widget.tabText(0) == 'General'
assert tab_widget.tabText(1) == 'Models'
assert tab_widget.tabText(2) == 'Shortcuts'
assert tab_widget.tabText(0) == "General"
assert tab_widget.tabText(1) == "Models"
assert tab_widget.tabText(2) == "Shortcuts"

View file

@@ -1,14 +1,17 @@
from PyQt6.QtWidgets import QPushButton, QLabel
from buzz.settings.shortcut import Shortcut
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import \
ShortcutsEditorPreferencesWidget
from buzz.widgets.preferences_dialog.shortcuts_editor_preferences_widget import (
ShortcutsEditorPreferencesWidget,
)
from buzz.widgets.sequence_edit import SequenceEdit
class TestShortcutsEditorWidget:
def test_should_reset_to_defaults(self, qtbot):
widget = ShortcutsEditorPreferencesWidget(shortcuts=Shortcut.get_default_shortcuts())
widget = ShortcutsEditorPreferencesWidget(
shortcuts=Shortcut.get_default_shortcuts()
)
qtbot.add_widget(widget)
reset_button = widget.findChild(QPushButton)
@@ -19,12 +22,13 @@ class TestShortcutsEditorWidget:
sequence_edits = widget.findChildren(SequenceEdit)
expected = (
('Open Record Window', 'Ctrl+R'),
('Import File', 'Ctrl+O'),
('Open Preferences Window', 'Ctrl+,'),
('Open Transcript Viewer', 'Ctrl+E'),
('Clear History', 'Ctrl+S'),
('Cancel Transcription', 'Ctrl+X'))
("Open Record Window", "Ctrl+R"),
("Import File", "Ctrl+O"),
("Open Preferences Window", "Ctrl+,"),
("Open Transcript Viewer", "Ctrl+E"),
("Clear History", "Ctrl+S"),
("Cancel Transcription", "Ctrl+X"),
)
for i, (label, sequence_edit) in enumerate(zip(labels, sequence_edits)):
assert isinstance(label, QLabel)

View file

@@ -2,57 +2,70 @@ import datetime
from pytestqt.qtbot import QtBot
from buzz.transcriber import FileTranscriptionTask, TranscriptionOptions, FileTranscriptionOptions
from buzz.transcriber import (
FileTranscriptionTask,
TranscriptionOptions,
FileTranscriptionOptions,
)
from buzz.widgets.transcription_tasks_table_widget import TranscriptionTasksTableWidget
class TestTranscriptionTasksTableWidget:
def test_upsert_task(self, qtbot: QtBot):
widget = TranscriptionTasksTableWidget()
qtbot.add_widget(widget)
task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3']), model_path='',
status=FileTranscriptionTask.Status.QUEUED)
task = FileTranscriptionTask(
id=0,
file_path="testdata/whisper-french.mp3",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=["testdata/whisper-french.mp3"]
),
model_path="",
status=FileTranscriptionTask.Status.QUEUED,
)
task.queued_at = datetime.datetime(2023, 4, 12, 0, 0, 0)
task.started_at = datetime.datetime(2023, 4, 12, 0, 0, 5)
widget.upsert_task(task)
assert widget.rowCount() == 1
assert widget.item(0, 1).text() == 'whisper-french.mp3'
assert widget.item(0, 2).text() == 'Queued'
assert widget.item(0, 1).text() == "whisper-french.mp3"
assert widget.item(0, 2).text() == "Queued"
task.status = FileTranscriptionTask.Status.IN_PROGRESS
task.fraction_completed = 0.3524
widget.upsert_task(task)
assert widget.rowCount() == 1
assert widget.item(0, 1).text() == 'whisper-french.mp3'
assert widget.item(0, 2).text() == 'In Progress (35%)'
assert widget.item(0, 1).text() == "whisper-french.mp3"
assert widget.item(0, 2).text() == "In Progress (35%)"
task.status = FileTranscriptionTask.Status.COMPLETED
task.completed_at = datetime.datetime(2023, 4, 12, 0, 0, 10)
widget.upsert_task(task)
assert widget.rowCount() == 1
assert widget.item(0, 1).text() == 'whisper-french.mp3'
assert widget.item(0, 2).text() == 'Completed (5s)'
assert widget.item(0, 1).text() == "whisper-french.mp3"
assert widget.item(0, 2).text() == "Completed (5s)"
def test_upsert_task_no_timings(self, qtbot: QtBot):
widget = TranscriptionTasksTableWidget()
qtbot.add_widget(widget)
task = FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3']), model_path='',
status=FileTranscriptionTask.Status.COMPLETED)
task = FileTranscriptionTask(
id=0,
file_path="testdata/whisper-french.mp3",
transcription_options=TranscriptionOptions(),
file_transcription_options=FileTranscriptionOptions(
file_paths=["testdata/whisper-french.mp3"]
),
model_path="",
status=FileTranscriptionTask.Status.COMPLETED,
)
widget.upsert_task(task)
assert widget.rowCount() == 1
assert widget.item(0, 1).text() == 'whisper-french.mp3'
assert widget.item(0, 2).text() == 'Completed'
assert widget.item(0, 1).text() == "whisper-french.mp3"
assert widget.item(0, 2).text() == "Completed"


@@ -6,69 +6,85 @@ from PyQt6.QtGui import QKeyEvent
from PyQt6.QtWidgets import QPushButton, QToolBar, QToolButton
from pytestqt.qtbot import QtBot
from buzz.transcriber import FileTranscriptionTask, FileTranscriptionOptions, TranscriptionOptions, Segment
from buzz.widgets.transcription_segments_editor_widget import TranscriptionSegmentsEditorWidget
from buzz.transcriber import (
FileTranscriptionTask,
FileTranscriptionOptions,
TranscriptionOptions,
Segment,
)
from buzz.widgets.transcription_segments_editor_widget import (
TranscriptionSegmentsEditorWidget,
)
from buzz.widgets.transcription_viewer_widget import TranscriptionViewerWidget
class TestTranscriptionViewerWidget:
@pytest.fixture()
def task(self) -> FileTranscriptionTask:
return FileTranscriptionTask(id=0, file_path='testdata/whisper-french.mp3',
file_transcription_options=FileTranscriptionOptions(
file_paths=['testdata/whisper-french.mp3']),
transcription_options=TranscriptionOptions(),
segments=[Segment(40, 299, 'Bien'), Segment(299, 329, 'venue dans')],
model_path='')
return FileTranscriptionTask(
id=0,
file_path="testdata/whisper-french.mp3",
file_transcription_options=FileTranscriptionOptions(
file_paths=["testdata/whisper-french.mp3"]
),
transcription_options=TranscriptionOptions(),
segments=[Segment(40, 299, "Bien"), Segment(299, 329, "venue dans")],
model_path="",
)
def test_should_display_segments(self, qtbot: QtBot, task):
widget = TranscriptionViewerWidget(
transcription_task=task, open_transcription_output=False)
transcription_task=task, open_transcription_output=False
)
qtbot.add_widget(widget)
assert widget.windowTitle() == 'whisper-french.mp3'
assert widget.windowTitle() == "whisper-french.mp3"
editor = widget.findChild(TranscriptionSegmentsEditorWidget)
assert isinstance(editor, TranscriptionSegmentsEditorWidget)
assert editor.item(0, 0).text() == '00:00:00.040'
assert editor.item(0, 1).text() == '00:00:00.299'
assert editor.item(0, 2).text() == 'Bien'
assert editor.item(0, 0).text() == "00:00:00.040"
assert editor.item(0, 1).text() == "00:00:00.299"
assert editor.item(0, 2).text() == "Bien"
def test_should_update_segment_text(self, qtbot, task):
widget = TranscriptionViewerWidget(
transcription_task=task, open_transcription_output=False)
transcription_task=task, open_transcription_output=False
)
qtbot.add_widget(widget)
editor = widget.findChild(TranscriptionSegmentsEditorWidget)
assert isinstance(editor, TranscriptionSegmentsEditorWidget)
# Change text
editor.item(0, 2).setText('Biens')
assert task.segments[0].text == 'Biens'
editor.item(0, 2).setText("Biens")
assert task.segments[0].text == "Biens"
# Undo
toolbar = widget.findChild(QToolBar)
undo_action, redo_action = toolbar.actions()
undo_action.trigger()
assert task.segments[0].text == 'Bien'
assert task.segments[0].text == "Bien"
redo_action.trigger()
assert task.segments[0].text == 'Biens'
assert task.segments[0].text == "Biens"
def test_should_export_segments(self, tmp_path: pathlib.Path, qtbot: QtBot, task):
widget = TranscriptionViewerWidget(
transcription_task=task, open_transcription_output=False)
transcription_task=task, open_transcription_output=False
)
qtbot.add_widget(widget)
export_button = widget.findChild(QPushButton)
assert isinstance(export_button, QPushButton)
output_file_path = tmp_path / 'whisper.txt'
with patch('PyQt6.QtWidgets.QFileDialog.getSaveFileName') as save_file_name_mock:
save_file_name_mock.return_value = (str(output_file_path), '')
output_file_path = tmp_path / "whisper.txt"
with patch(
"PyQt6.QtWidgets.QFileDialog.getSaveFileName"
) as save_file_name_mock:
save_file_name_mock.return_value = (str(output_file_path), "")
export_button.menu().actions()[0].trigger()
output_file = open(output_file_path, 'r', encoding='utf-8')
assert 'Bien\nvenue dans' in output_file.read()
output_file = open(output_file_path, "r", encoding="utf-8")
assert "Bien\nvenue dans" in output_file.read()