Skip to content

Commit

Permalink
Add some type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
Josef-Friedrich committed Dec 30, 2023
1 parent f6e137a commit 262cf05
Show file tree
Hide file tree
Showing 6 changed files with 258 additions and 250 deletions.
12 changes: 11 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,15 @@
"**/.ropeproject": true,
"**/_build/": true,
"**/.tex/": true
}
},
"python.defaultInterpreterPath": ".venv/bin/python",
"python.testing.unittestArgs": [
"-v",
"-s",
"./tests",
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true,
}
99 changes: 69 additions & 30 deletions mscxyz/meta.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
"""Class for metadata maniplation"""

from __future__ import annotations

import json
import re
import typing

import lxml
import lxml.etree
import tmep

from mscxyz.score_file_classes import MscoreXmlTree
from mscxyz.utils import color, get_args

if typing.TYPE_CHECKING:
from lxml.etree import _Element


class ReadOnlyFieldError(Exception):
def __init__(self, field: str):
Expand All @@ -18,7 +24,7 @@ def __init__(self, field: str):


class UnkownFieldError(Exception):
def __init__(self, field: str, valid_fields: typing.Sequence):
def __init__(self, field: str, valid_fields: typing.Sequence[str]):
self.msg = "Unkown field of name “{}”! Valid field names are: {}".format(
field, ", ".join(valid_fields)
)
Expand All @@ -40,7 +46,16 @@ def __init__(self, format_string: str):
Exception.__init__(self, self.msg)


def distribute_field(source, format_string: str):
def distribute_field(source: str, format_string: str) -> dict[str, str]:
"""
Distributes the values from the source string into a dictionary based on the format string.
:param source: The source string from which values will be extracted.
:param format_string: The format string that specifies the pattern of the values to be extracted.
:return: A dictionary mapping field names to their corresponding values.
:raises FormatStringNoFieldError: If the format string does not contain any field markers.
:raises UnmatchedFormatStringError: If the format string does not match the source string.
"""
fields = re.findall(r"\$([a-z_]*)", format_string)
if not fields:
raise FormatStringNoFieldError(format_string)
Expand All @@ -52,12 +67,26 @@ def distribute_field(source, format_string: str):
return dict(zip(fields, values))


def to_underscore(field):
def to_underscore(field: str) -> str:
"""
Convert a camel case string to snake case.
:param field: The camel case string to be converted.
:return: The snake case representation of the input string.
"""
return re.sub("([A-Z]+)", r"_\1", field).lower()


def export_to_dict(obj, fields):
out = {}
def export_to_dict(obj: object, fields: typing.Iterable[str]) -> dict[str, str]:
"""
Export the specified fields of an object to a dictionary.
:param obj: The object to export.
:param fields: The fields to include in the dictionary.
:return: A dictionary containing the specified fields and their values.
"""
out: dict[str, str] = {}
for field in fields:
value = getattr(obj, field)
if not value:
Expand Down Expand Up @@ -102,30 +131,32 @@ class MetaTag:
"workTitle",
)

xml_root: _Element

@staticmethod
def _to_camel_case(field):
def _to_camel_case(field: str) -> str:
return re.sub(r"(?!^)_([a-zA-Z])", lambda match: match.group(1).upper(), field)

def __init__(self, xml_root):
def __init__(self, xml_root: _Element) -> None:
self.xml_root = xml_root

def _get_element(self, field: str):
def _get_element(self, field: str) -> _Element | None:
for element in self.xml_root.xpath('//metaTag[@name="' + field + '"]'):
return element

def _get_text(self, field: str) -> str:
element = self._get_element(field)
def _get_text(self, field: str) -> str | None:
element: _Element | None = self._get_element(field)
if hasattr(element, "text"):
return element.text

def __getattr__(self, field):
def __getattr__(self, field: str):
field = self._to_camel_case(field)
if field not in self.fields:
raise UnkownFieldError(field, self.fields)
else:
return self._get_text(field)

def __setattr__(self, field, value):
def __setattr__(self, field: str, value: str) -> None:
if field == "xml_root" or field == "fields":
self.__dict__[field] = value
else:
Expand Down Expand Up @@ -183,7 +214,9 @@ class Vbox:
"Title",
)

def __init__(self, xml_root):
xml_root: _Element

def __init__(self, xml_root: _Element):
self.xml_root = xml_root
xpath = '/museScore/Score/Staff[@id="1"]'
if not xml_root.xpath(xpath + "/VBox"):
Expand Down Expand Up @@ -256,7 +289,9 @@ class Combined(MscoreXmlTree):
"title",
)

def __init__(self, xml_root):
xml_root: _Element

def __init__(self, xml_root: _Element):
self.xml_root = xml_root
self.metatag = MetaTag(xml_root)
self.vbox = Vbox(xml_root)
Expand Down Expand Up @@ -302,15 +337,15 @@ def lyricist(self, value):
class InterfaceReadWrite:
objects = ("metatag", "vbox", "combined")

def __init__(self, xml_root):
def __init__(self, xml_root: _Element) -> None:
self.metatag = MetaTag(xml_root)
self.vbox = Vbox(xml_root)
self.combined = Combined(xml_root)
self.fields = self.get_all_fields()

@staticmethod
def get_all_fields():
fields = []
def get_all_fields() -> list[str]:
fields: list[str] = []
for field in MetaTag.fields:
fields.append("metatag_" + to_underscore(field))
for field in Vbox.fields:
Expand All @@ -320,7 +355,7 @@ def get_all_fields():
return sorted(fields)

@staticmethod
def _split(field):
def _split(field: str):
match = re.search(r"([^_]*)_(.*)", field)
if not match:
raise ValueError("Field “" + field + "” can’t be splitted!")
Expand Down Expand Up @@ -358,40 +393,44 @@ class InterfaceReadOnly:
"readonly_relpath_backup",
]

def __init__(self, tree):
xml_tree: MscoreXmlTree

def __init__(self, tree: MscoreXmlTree):
self.xml_tree = tree

@property
def readonly_abspath(self):
def readonly_abspath(self) -> str:
return self.xml_tree.abspath

@property
def readonly_basename(self):
def readonly_basename(self) -> str:
return self.xml_tree.basename

@property
def readonly_dirname(self):
def readonly_dirname(self) -> str:
return self.xml_tree.dirname

@property
def readonly_extension(self):
def readonly_extension(self) -> str:
return self.xml_tree.extension

@property
def readonly_filename(self):
def readonly_filename(self) -> str:
return self.xml_tree.filename

@property
def readonly_relpath(self):
def readonly_relpath(self) -> str:
return self.xml_tree.relpath

@property
def readonly_relpath_backup(self):
def readonly_relpath_backup(self) -> str:
return self.xml_tree.relpath_backup


class Interface:
def __init__(self, tree):
xml_tree: MscoreXmlTree

def __init__(self, tree: MscoreXmlTree):
self.xml_tree = tree
self.read_only = InterfaceReadOnly(tree)
self.read_write = InterfaceReadWrite(tree.xml_root)
Expand All @@ -404,13 +443,13 @@ def get_all_fields():
def export_to_dict(self):
return export_to_dict(self, self.fields)

def __getattr__(self, field):
def __getattr__(self, field: str):
if re.match(r"^readonly_", field):
return getattr(self.read_only, field)
else:
return getattr(self.read_write, field)

def __setattr__(self, field, value):
def __setattr__(self, field: str, value):
if field in ("xml_tree", "read_only", "read_write", "fields"):
self.__dict__[field] = value
elif not re.match(r"^readonly_", field):
Expand All @@ -420,7 +459,7 @@ def __setattr__(self, field, value):


class Meta(MscoreXmlTree):
def __init__(self, relpath):
def __init__(self, relpath: str):
super(Meta, self).__init__(relpath)

if not self.errors:
Expand Down
Loading

0 comments on commit 262cf05

Please sign in to comment.