diff --git a/mscxyz/score.py b/mscxyz/score.py index 7d8d049..b96d23c 100644 --- a/mscxyz/score.py +++ b/mscxyz/score.py @@ -20,6 +20,7 @@ from mscxyz.meta import Meta from mscxyz.settings import get_args from mscxyz.style import Style +from mscxyz.xml import Xml if typing.TYPE_CHECKING: from lxml.etree import _XPathObject @@ -46,6 +47,8 @@ class Score: xml_root: _Element """The root element of the XML tree. See the `lxml API `_.""" + xml: Xml + version: float """The MuseScore version, for example 2.03 or 3.01""" @@ -74,19 +77,23 @@ def __init__(self, src: str | Path) -> None: self.xml_file = str(self.path) self.errors = [] - try: - self.xml_root = lxml.etree.parse(self.xml_file).getroot() - except lxml.etree.XMLSyntaxError as e: - self.errors.append(e) - else: + + element, e = Xml.parse_file_try(self.xml_file) + + if element is not None: + self.xml_root = element + self.xml = Xml(element) self.version = self.get_version() + if e is not None: + self.errors.append(e) + if self.extension == "mscz" and self.version_major == 4 and self.zip_container: self.style_file = self.zip_container.score_style_file @property def xml_string(self) -> str: - return utils.xml.tostring(self.xml_root) + return self.xml.tostring(self.xml_root) @property def version_major(self) -> int: diff --git a/mscxyz/xml.py b/mscxyz/xml.py index 1374f43..1dfaf65 100644 --- a/mscxyz/xml.py +++ b/mscxyz/xml.py @@ -21,16 +21,16 @@ class Xml: """A wrapper around lxml.etree""" - element: _Element + root: _Element def __init__(self, element: _Element) -> None: - self.element = element + self.root = element def __get_element(self, element: ElementLike = None) -> _Element: if isinstance(element, _ElementTree): return element.getroot() if element is None: - return self.element + return self.root return element def __normalize_element(self, element: ElementLike = None) -> _Element | None: @@ -39,7 +39,7 @@ def __normalize_element(self, element: ElementLike = None) -> _Element | None: return element @staticmethod - def read(path: str | Path | TextIOWrapper) -> _Element: + def parse_file(path: str | Path | TextIOWrapper) -> _Element: """ Read an XML file and return the root element. @@ -50,12 +50,24 @@ def read(path: str | Path | TextIOWrapper) -> _Element: return lxml.etree.parse(path).getroot() @staticmethod - def from_file(path: str | Path | TextIOWrapper) -> Xml: - return Xml(Xml.read(path)) + def parse_string(xml_markup: str) -> _Element: + return lxml.etree.XML(xml_markup) @staticmethod - def parse(string: str) -> _Element: - return lxml.etree.XML(string) + def parse_file_try( + path: str | Path | TextIOWrapper, + ) -> tuple[_Element | None, Exception | None]: + element: _Element | None = None + error: Exception | None = None + try: + element = lxml.etree.parse(path).getroot() + except lxml.etree.XMLSyntaxError as e: + error = e + return (element, error) + + @staticmethod + def new(path: str | Path | TextIOWrapper) -> Xml: + return Xml(Xml.parse_file(path)) def tostring(self, element: ElementLike = None) -> str: """ diff --git a/tests/test_xml.py b/tests/test_xml.py index 07058d0..889f476 100644 --- a/tests/test_xml.py +++ b/tests/test_xml.py @@ -13,17 +13,17 @@ root = helper.get_xml_root("score.mscz", 4) -xml = Xml.from_file(xml_file) +xml = Xml.new(xml_file) def test_read() -> None: - element = xml.read(xml_file) + element = xml.parse_file(xml_file) assert element.tag == "museScore" def test_from_file() -> None: - x = xml.from_file(xml_file) - assert x.element.tag == "museScore" + x = xml.new(xml_file) + assert x.root.tag == "museScore" def test_find_safe() -> None: @@ -60,7 +60,7 @@ def test_xpathall_safe() -> None: def test_xml_write(tmp_path: Path) -> None: dest = tmp_path / "test.xml" - element = Xml.parse("") + element = Xml.parse_string("") xml.write(dest, element) result: str = utils.read_file(dest) assert result == (