Skip to content

Commit

Permalink
New class Xml
Browse files Browse the repository at this point in the history
  • Loading branch information
Josef-Friedrich committed Jan 20, 2024
1 parent e203a14 commit c91132d
Show file tree
Hide file tree
Showing 2 changed files with 335 additions and 0 deletions.
266 changes: 266 additions & 0 deletions mscxyz/xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
from __future__ import annotations

import typing
from io import TextIOWrapper
from pathlib import Path
from typing import Literal, Optional

import lxml
import lxml.etree
from lxml.etree import _Element, _ElementTree

if typing.TYPE_CHECKING:
from lxml.etree import _DictAnyStr, _XPathObject


ListExtension = Literal["mscz", "mscx", "both"]

ElementLike = _Element | _ElementTree | None


class Xml:
"""A wrapper around lxml.etree"""

element: _Element

def __init__(self, element: _Element) -> None:
self.element = element

def __get_element(self, element: ElementLike = None) -> _Element:
if isinstance(element, _ElementTree):
return element.getroot()
if element is None:
return self.element
return element

def __normalize_element(self, element: ElementLike = None) -> _Element | None:
if isinstance(element, _ElementTree):
return element.getroot()
return element

@staticmethod
def read(path: str | Path | TextIOWrapper) -> _Element:
"""
Read an XML file and return the root element.
:param path: The path to the XML file.
:return: The root element of the XML file.
"""
return lxml.etree.parse(path).getroot()

@staticmethod
def from_file(path: str | Path | TextIOWrapper) -> Xml:
return Xml(Xml.read(path))

@staticmethod
def parse(string: str) -> _Element:
return lxml.etree.XML(string)

def tostring(self, element: ElementLike = None) -> str:
"""
Convert the XML element or tree to a string.
:param element: The XML element or tree to write.
"""
element = self.__get_element(element)
# maybe use: xml_declaration=True, pretty_print=True
# TestFileCompare not passing ...
return (
'<?xml version="1.0" encoding="UTF-8"?>\n'
+ lxml.etree.tostring(element, encoding="UTF-8").decode("utf-8")
+ "\n"
)

def write(self, path: str | Path, element: ElementLike = None) -> None:
"""
Write the XML element or tree to the specified file.
:param path: The path to the file.
:param element: The XML element or tree to write.
:return: None
"""
element = self.__get_element(element)
with open(path, "w") as document:
document.write(self.tostring(element))

def find_safe(self, path: str, element: ElementLike = None) -> _Element:
"""
Find an element in the given XML element using the specified element path.
:param path: The path to the desired element.
:param element: The XML element to search within.
:return: The found element.
:raises ValueError: If the element is not found.
"""

element = self.__get_element(element)
result: _Element | None = element.find(path)
if result is None:
raise ValueError(f"Path {path} not found in element {element}!")
return result

def xpath(self, path: str, element: ElementLike = None) -> _Element | None:
"""
Find the first matching element in the XML tree using XPath.
:param path: The XPath expression to search for.
:param element: The root element of the XML tree.
:return: The first matching element or None if no match is found.
"""
element = self.__get_element(element)
output: list[_Element] | None = self.xpathall(path, element)
if output and len(output) > 0:
return output[0]

return None

def xpath_safe(self, path: str, element: ElementLike = None) -> _Element:
"""
Safely retrieves the first matching XML element using the given XPath expression.
:param path: The XPath expression to match elements.
:param element: The XML element to search within.
:return: The first matching XML element.XPath
:raises ValueError: If more than one element is found matching the XPath expression.
"""
element = self.__get_element(element)
output: list[_Element] = self.xpathall_safe(
path,
element,
)
if len(output) > 1:
raise ValueError(
f"XPath “{path}” found more than one element in {element}!"
)
return output[0]

def xpathall(self, path: str, element: ElementLike = None) -> list[_Element] | None:
"""
Returns a list of elements matching the given XPath expression.
:param path: The XPath expression to match elements.
:param element: The XML element to search within.
:return: A list of elements matching the XPath expression, or None if no
elements are found.
"""
element = self.__get_element(element)
result: _XPathObject = element.xpath(path)
output: list[_Element] = []

if isinstance(result, list):
for item in result:
if isinstance(item, _Element):
output.append(item)

if len(output) > 0:
return output

return None

def xpathall_safe(self, path: str, element: ElementLike = None) -> list[_Element]:
"""
Safely retrieves a list of elements matching the given XPath expression within
the specified element.
:param element: The XML element to search within.
:param path: The XPath expression to match elements.
:return: A list of elements matching the XPath expression.
:raises ValueError: If the XPath expression is not found in the element.
"""
element = self.__get_element(element)
output: list[_Element] | None = self.xpathall(path, element)
if output is None:
raise ValueError(f"XPath “{path}” not found in element {element}!")
return output

def get_text(self, element: ElementLike = None) -> str | None:
"""
Get the text content of an XML element.
:param element: The XML element.
:return: The text content of the XML element, or None if the element is None.
"""
element = self.__normalize_element(element)
if element is None:
return None
if element.text is None:
return None
return element.text

def get_text_safe(self, element: ElementLike = None) -> str:
"""
Safely retrieves the text content from an XML element.
:param element: The XML element to retrieve the text from.
:return: The text content of the element.
:raises ValueError: If the element is None or has no text content.
"""
element = self.__get_element(element)
if element.text is None:
raise ValueError(f"Element {element} has no text!")
return element.text

def set_text(
self, path: str, value: str | int | float, element: ElementLike = None
) -> None:
"""
Set the text value of an XML element at the specified element path.
:param element: The XML element to modify.
:param path: The element path expression to locate the target element.
:param value: The new value to set for the element's text.
:return: None
"""
self.find_safe(path, element).text = str(value)

@staticmethod
def replace(old: _Element, new: _Element) -> None:
parent: _Element | None = old.getparent()
if parent is not None:
parent.replace(old, new)

@staticmethod
def remove(element: _Element | None) -> None:
"""
Remove the given element from its parent.
:param element: The element to be removed.
"""
if element is None:
return None

parent: _Element | None = element.getparent()
if parent is None:
return None

parent.remove(element)

@staticmethod
def create_element(tag_name: str) -> _Element:
return lxml.etree.Element(tag_name)

@staticmethod
def create_sub_element(
parent: _Element,
tag_name: str,
text: Optional[str] = None,
attrib: Optional[_DictAnyStr] = None,
) -> _Element:
element: _Element = lxml.etree.SubElement(parent, tag_name, attrib=attrib)
if text:
element.text = text
return element
69 changes: 69 additions & 0 deletions tests/test_xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Test submodule “utils.py”."""
from __future__ import annotations

from pathlib import Path

import pytest

from mscxyz import utils
from mscxyz.xml import Xml
from tests import helper

xml_file = helper.get_file("simple.mscx", 4)

root = helper.get_xml_root("score.mscz", 4)

xml = Xml.from_file(xml_file)


def test_read() -> None:
element = xml.read(xml_file)
assert element.tag == "museScore"


def test_from_file() -> None:
x = xml.from_file(xml_file)
assert x.element.tag == "museScore"


def test_find_safe() -> None:
element = xml.find_safe(".//Score")
assert element.tag == "Score"


def test_xpath() -> None:
element = xml.xpath(".//xxxxxxx")
assert element is None


class TestXpathSave:
def test_xpath_safe(self) -> None:
element = xml.xpath_safe(".//Score")
assert element.tag == "Score"

def test_xpath_safe_raise(self) -> None:
with pytest.raises(ValueError) as e:
xml.xpath_safe(".//metaTag")
assert "XPath “.//metaTag” found more than one element in" in e.value.args[0]


def test_xpathall() -> None:
element = xml.xpathall(".//xxxxxxx")
assert element is None


def test_xpathall_safe() -> None:
element = xml.xpathall_safe(".//metaTag")
assert isinstance(element, list)
assert len(element) == 16


def test_xml_write(tmp_path: Path) -> None:
dest = tmp_path / "test.xml"
element = Xml.parse("<root><a><b/><c/></a><d><e/></d></root>")
xml.write(dest, element)
result: str = utils.read_file(dest)
assert result == (
'<?xml version="1.0" encoding="UTF-8"?>\n'
"<root><a><b/><c/></a><d><e/></d></root>\n"
)

0 comments on commit c91132d

Please sign in to comment.