diff --git a/README.md b/README.md index 900d7bf..9954764 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,6 @@ python -m graphviz2drawio test/directed/hello.gv.txt ## Roadmap * Migrate to uv/hatch for packaging and dep mgmt -* Support for fill gradient * Support compatible [arrows](https://graphviz.org/docs/attr-types/arrowType/) * Support [multiple edges](https://graphviz.org/Gallery/directed/switch.html) * Support [edges with links](https://graphviz.org/Gallery/directed/pprof.html) diff --git a/graphviz2drawio/models/SVG.py b/graphviz2drawio/models/SVG.py index 5b8ccfc..6386196 100644 --- a/graphviz2drawio/models/SVG.py +++ b/graphviz2drawio/models/SVG.py @@ -11,8 +11,8 @@ def get_first(g: Element, tag: str) -> Element | None: return g.find(f"./{NS_SVG}{tag}") -def count_tags(g: Element, tag: str) -> int: - return len(g.findall(f"./{NS_SVG}{tag}")) +def findall(g: Element, tag: str) -> list[Element]: + return g.findall(f"./{NS_SVG}{tag}") def get_title(g: Element) -> str | None: diff --git a/graphviz2drawio/models/SvgParser.py b/graphviz2drawio/models/SvgParser.py index 53639c0..3081bdf 100644 --- a/graphviz2drawio/models/SvgParser.py +++ b/graphviz2drawio/models/SvgParser.py @@ -1,11 +1,16 @@ +import re from collections import OrderedDict +from collections.abc import Iterable +from math import isclose from xml.etree import ElementTree from graphviz2drawio.mx.Edge import Edge from graphviz2drawio.mx.EdgeFactory import EdgeFactory -from graphviz2drawio.mx.Node import Node +from graphviz2drawio.mx.Node import Gradient, Node from graphviz2drawio.mx.NodeFactory import NodeFactory +from ..mx.Curve import LINE_TOLERANCE +from ..mx.utils import adjust_color_opacity from . import SVG from .commented_tree_builder import COMMENT, CommentedTreeBuilder from .CoordsTranslate import CoordsTranslate @@ -29,17 +34,28 @@ def parse_nodes_edges_clusters( nodes: OrderedDict[str, Node] = OrderedDict() edges: OrderedDict[str, Edge] = OrderedDict() clusters: OrderedDict[str, Node] = OrderedDict() + gradients = dict[str, Gradient]() prev_comment = None for g in root: if g.tag == COMMENT: prev_comment = g.text + elif SVG.is_tag(g, "defs"): + for gradient in _extract_gradients(g): + gradients[gradient[0]] = gradient[1:] elif SVG.is_tag(g, "g"): title = prev_comment or SVG.get_title(g) if title is None: raise MissingTitleError(g) + if (defs := SVG.get_first(g, "defs")) is not None: + for gradient in _extract_gradients(defs): + gradients[gradient[0]] = gradient[1:] if g.attrib["class"] == "node": - nodes[title] = node_factory.from_svg(g, labelloc="c") + nodes[title] = node_factory.from_svg( + g, + labelloc="c", + gradients=gradients, + ) elif g.attrib["class"] == "edge": # We need to merge edges with the same source and target # GV represents multiple labels with multiple edges @@ -52,6 +68,64 @@ def parse_nodes_edges_clusters( else: edges[edge.key_for_label] = edge elif g.attrib["class"] == "cluster": - clusters[title] = node_factory.from_svg(g, labelloc="t") + clusters[title] = node_factory.from_svg( + g, + labelloc="t", + gradients=gradients, + ) return nodes, list(edges.values()), clusters + + +_stop_color_re = re.compile(r"stop-color:([^;]+);") +_stop_opacity_re = re.compile(r"stop-opacity:([^;]+);") + + +def _extract_stop_color(stop: ElementTree.Element) -> str | None: + style = stop.attrib.get("style", "") + if (color := _stop_color_re.search(style)) is not None: + if (opacity := _stop_opacity_re.search(style)) is not None: + return adjust_color_opacity(color.group(1), float(opacity.group(1))) + return None + + +def _extract_gradients( + defs: ElementTree.Element, +) -> Iterable[tuple[str, str, str, str]]: + for radial_gradient in SVG.findall(defs, "radialGradient"): + stops = SVG.findall(radial_gradient, "stop") + start_color = _extract_stop_color(stops[0]) + end_color = _extract_stop_color(stops[-1]) + if start_color is None or end_color is None: + continue + yield ( + radial_gradient.attrib["id"], + start_color, + end_color, + "radial", + ) + for linear_gradient in SVG.findall(defs, "linearGradient"): + stops = SVG.findall(linear_gradient, "stop") + + start_color = _extract_stop_color(stops[0]) + end_color = _extract_stop_color(stops[-1]) + if start_color is None or end_color is None: + continue + + y1 = float(linear_gradient.attrib["y1"]) + y2 = float(linear_gradient.attrib["y2"]) + + gradient_direction = "north" + if isclose(y1, y2, rel_tol=LINE_TOLERANCE): + x1 = float(linear_gradient.attrib["y1"]) + x2 = float(linear_gradient.attrib["y2"]) + gradient_direction = "east" if x1 < x2 else "west" + elif y1 < y2: + gradient_direction = "south" + + yield ( + linear_gradient.attrib["id"], + start_color, + end_color, + gradient_direction, + ) diff --git a/graphviz2drawio/mx/Node.py b/graphviz2drawio/mx/Node.py index 8e6c8cb..64e74cc 100644 --- a/graphviz2drawio/mx/Node.py +++ b/graphviz2drawio/mx/Node.py @@ -1,9 +1,13 @@ +from typing import TypeAlias + from ..models.Rect import Rect from .GraphObj import GraphObj from .MxConst import VERTICAL_ALIGN from .Styles import Styles from .Text import Text +Gradient: TypeAlias = tuple[str, str | None, str] + class Node(GraphObj): def __init__( @@ -12,7 +16,7 @@ def __init__( gid: str, rect: Rect | None, texts: list[Text], - fill: str, + fill: str | Gradient, stroke: str, shape: str, labelloc: str, @@ -46,13 +50,21 @@ def texts_to_mx_value(self) -> str: def get_node_style(self) -> str: style_for_shape = Styles.get_for_shape(self.shape) dashed = 1 if self.dashed else 0 + additional_styling = "" attributes = { - "fill": self.fill, "stroke": self.stroke, "stroke_width": self.stroke_width, "dashed": dashed, } + if isinstance(self.fill, str): + attributes["fill"] = self.fill + elif type(self.fill) is tuple: + attributes["fill"] = self.fill[0] + additional_styling += ( + f"gradientColor={self.fill[1]};gradientDirection={self.fill[2]};" + ) + if (rect := self.rect) is not None and (image_path := rect.image) is not None: from graphviz2drawio.mx.image import image_data_for_path @@ -60,7 +72,7 @@ def get_node_style(self) -> str: attributes["vertical_align"] = VERTICAL_ALIGN.get(self.labelloc, "middle") - return style_for_shape.format(**attributes) + return style_for_shape.format(**attributes) + additional_styling def __repr__(self) -> str: return ( diff --git a/graphviz2drawio/mx/NodeFactory.py b/graphviz2drawio/mx/NodeFactory.py index 100c377..6795493 100644 --- a/graphviz2drawio/mx/NodeFactory.py +++ b/graphviz2drawio/mx/NodeFactory.py @@ -1,3 +1,4 @@ +import re from xml.etree.ElementTree import Element from graphviz2drawio.models import SVG @@ -6,7 +7,7 @@ from ..models.Errors import MissingIdentifiersError from . import MxConst, Shape from .MxConst import DEFAULT_STROKE_WIDTH -from .Node import Node +from .Node import Gradient, Node from .RectFactory import rect_from_ellipse_svg, rect_from_image, rect_from_svg_points from .Text import Text from .utils import adjust_color_opacity @@ -17,11 +18,16 @@ def __init__(self, coords: CoordsTranslate) -> None: super().__init__() self.coords = coords - def from_svg(self, g: Element, labelloc: str) -> Node: + def from_svg( + self, + g: Element, + labelloc: str, + gradients: dict[str, Gradient], + ) -> Node: sid = g.attrib["id"] gid = SVG.get_title(g) rect = None - fill = MxConst.NONE + fill: str | Gradient = MxConst.NONE stroke = MxConst.NONE stroke_width = DEFAULT_STROKE_WIDTH dashed = False @@ -36,7 +42,8 @@ def from_svg(self, g: Element, labelloc: str) -> Node: if (polygon := SVG.get_first(g, "polygon")) is not None: rect = rect_from_svg_points(self.coords, polygon.attrib["points"]) shape = Shape.RECT - fill, stroke = self._extract_fill_and_stroke(polygon) + fill = self._extract_fill(polygon, gradients) + stroke = self._extract_stroke(polygon) stroke_width = polygon.attrib.get("stroke-width", DEFAULT_STROKE_WIDTH) if "stroke-dasharray" in polygon.attrib: dashed = True @@ -50,10 +57,11 @@ def from_svg(self, g: Element, labelloc: str) -> Node: ) shape = ( Shape.ELLIPSE - if SVG.count_tags(g, "ellipse") == 1 + if len(SVG.findall(g, "ellipse")) == 1 else Shape.DOUBLE_CIRCLE ) - fill, stroke = self._extract_fill_and_stroke(ellipse) + fill = self._extract_fill(ellipse, gradients) + stroke = self._extract_stroke(ellipse) stroke_width = ellipse.attrib.get("stroke-width", DEFAULT_STROKE_WIDTH) if "stroke-dasharray" in ellipse.attrib: dashed = True @@ -76,15 +84,25 @@ def from_svg(self, g: Element, labelloc: str) -> Node: dashed=dashed, ) + _fill_url_re = re.compile(r"url\(#([^)]+)\)") + @staticmethod - def _extract_fill_and_stroke(g: Element) -> tuple[str, str]: + def _extract_fill(g: Element, gradients: dict[str, Gradient]) -> str | Gradient: fill = g.attrib.get("fill", MxConst.NONE) - stroke = g.attrib.get("stroke", MxConst.NONE) + if fill.startswith("url"): + match = NodeFactory._fill_url_re.search(fill) + if match is not None: + return gradients[match.group(1)] if "fill-opacity" in g.attrib and fill != MxConst.NONE: fill = adjust_color_opacity(fill, float(g.attrib["fill-opacity"])) + return fill + + @staticmethod + def _extract_stroke(g: Element) -> str: + stroke = g.attrib.get("stroke", MxConst.NONE) if "stroke-opacity" in g.attrib and stroke != MxConst.NONE: stroke = adjust_color_opacity(stroke, float(g.attrib["stroke-opacity"])) - return fill, stroke + return stroke def _extract_texts(self, g: Element) -> tuple[list[Text], complex | None]: texts = [] diff --git a/graphviz2drawio/mx/utils.py b/graphviz2drawio/mx/utils.py index 6350f73..147b111 100644 --- a/graphviz2drawio/mx/utils.py +++ b/graphviz2drawio/mx/utils.py @@ -4,7 +4,10 @@ def adjust_color_opacity(hex_color: str, opacity: float) -> str: hex_color = hex_color.lstrip("#") # Convert hex to RGB - r, g, b = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4)) + try: + r, g, b = tuple(int(hex_color[i : i + 2], 16) for i in (0, 2, 4)) + except ValueError: + return hex_color # Apply opacity over white background r = int(r * opacity + 255 * (1 - opacity)) diff --git a/pyproject.toml b/pyproject.toml index b3768b7..d2558c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ line-ending = "auto" "graphviz2drawio/version.py" = ["T201"] "doc/source/conf.py" = ["A001", "ERA001", "INP001"] "graphviz2drawio/models/commented_tree_builder.py" = ["ANN001", "ANN201", "ANN204"] +"graphviz2drawio/models/SvgParser.py" = ["C901", "PLR0912"] [tool.pytest.ini_options] pythonpath = ". venv/lib/python3.12/site-packages"