diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6357aee --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[build-system] +requires = ["setuptools>=68"] +build-backend = "setuptools.build_meta" + +[project] +name = "r2s" +version = "0.1.0" +description = "Raster to SVG conversion studio (CLI + GUI)" +requires-python = ">=3.10" +dependencies = [ + "numpy>=1.24", + "opencv-python>=4.8", + "PySide6>=6.6", + "svgwrite>=1.4", +] + +[project.optional-dependencies] +preview = ["cairosvg>=2.7"] + +[project.scripts] +r2s = "r2s.cli:main" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] diff --git a/src/r2s/__init__.py b/src/r2s/__init__.py new file mode 100644 index 0000000..4178476 --- /dev/null +++ b/src/r2s/__init__.py @@ -0,0 +1 @@ +__all__ = ["pipeline"] diff --git a/src/r2s/cli.py b/src/r2s/cli.py new file mode 100644 index 0000000..a57afe3 --- /dev/null +++ b/src/r2s/cli.py @@ -0,0 +1,52 @@ +from __future__ import annotations +import argparse +import json +from pathlib import Path + +from .pipeline import load_bgr, run_style, available_styles + +def cmd_list(args: argparse.Namespace) -> int: + styles = available_styles() + print(json.dumps(styles, indent=2)) + return 0 + +def cmd_convert(args: argparse.Namespace) -> int: + params = {} + if args.params: + params = json.loads(args.params) + + bgr = load_bgr(args.in_path) + res = run_style(bgr, args.style, params) + + out = Path(args.out_path) + out.write_text(res.svg, encoding="utf-8") + print(f"Wrote SVG: {out}") + if args.meta: + print("Meta:", json.dumps(res.meta, indent=2)) + return 0 + +def cmd_gui(args: argparse.Namespace) -> int: + from .ui.app import run_app + run_app() + return 0 + +def main() -> None: + p = argparse.ArgumentParser(prog="r2s", description="Raster2SVG Studio") + sub = p.add_subparsers(dest="cmd", required=True) + + s_list = sub.add_parser("list-styles", help="List styles and default parameters") + s_list.set_defaults(func=cmd_list) + + s_conv = sub.add_parser("convert", help="Convert an image to SVG") + s_conv.add_argument("--in", dest="in_path", required=True) + s_conv.add_argument("--out", dest="out_path", required=True) + s_conv.add_argument("--style", required=True) + s_conv.add_argument("--params", help="JSON string of parameter overrides") + s_conv.add_argument("--meta", action="store_true") + s_conv.set_defaults(func=cmd_convert) + + s_gui = sub.add_parser("gui", help="Launch GUI") + s_gui.set_defaults(func=cmd_gui) + + args = p.parse_args() + raise SystemExit(args.func(args)) diff --git a/src/r2s/pipeline.py b/src/r2s/pipeline.py new file mode 100644 index 0000000..41fc00a --- /dev/null +++ b/src/r2s/pipeline.py @@ -0,0 +1,29 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict +import cv2 +import numpy as np + +from .styles import ALL_STYLES +from .styles.base import StyleResult + +def load_bgr(path: str) -> np.ndarray: + img = cv2.imread(path, cv2.IMREAD_COLOR) + if img is None: + raise ValueError(f"Could not read image: {path}") + return img + +def run_style(bgr: np.ndarray, style_name: str, params: Dict[str, Any]) -> StyleResult: + if style_name not in ALL_STYLES: + raise ValueError(f"Unknown style '{style_name}'. Available: {list(ALL_STYLES.keys())}") + style = ALL_STYLES[style_name] + # Merge defaults with provided params + p = style.default_params() + p.update(params or {}) + return style.run(bgr, p) + +def available_styles() -> Dict[str, Dict[str, Any]]: + out = {} + for k, s in ALL_STYLES.items(): + out[k] = s.default_params() + return out diff --git a/src/r2s/preview.py b/src/r2s/preview.py new file mode 100644 index 0000000..06b5dac --- /dev/null +++ b/src/r2s/preview.py @@ -0,0 +1,22 @@ +from __future__ import annotations +from typing import Optional +import numpy as np +import cv2 + +def svg_to_bgr(svg_text: str, width_px: int, height_px: int) -> Optional[np.ndarray]: + """ + Returns a BGR uint8 image rendered from SVG, or None if rendering is unavailable. + Requires cairosvg. + """ + try: + import cairosvg # type: ignore + except Exception: + return None + + try: + png_bytes = cairosvg.svg2png(bytestring=svg_text.encode("utf-8"), output_width=width_px, output_height=height_px) + png = np.frombuffer(png_bytes, dtype=np.uint8) + img = cv2.imdecode(png, cv2.IMREAD_COLOR) + return img + except Exception: + return None diff --git a/src/r2s/styles/__init__.py b/src/r2s/styles/__init__.py new file mode 100644 index 0000000..b548ce5 --- /dev/null +++ b/src/r2s/styles/__init__.py @@ -0,0 +1,11 @@ +from .posterized import PosterizedStyle +from .lineart import LineArtStyle +from .woodcut import WoodcutStyle +from .pontillist import PontillistStyle + +ALL_STYLES = { + PosterizedStyle().name: PosterizedStyle(), + LineArtStyle().name: LineArtStyle(), + WoodcutStyle().name: WoodcutStyle(), + PontillistStyle().name: PontillistStyle(), +} diff --git a/src/r2s/styles/base.py b/src/r2s/styles/base.py new file mode 100644 index 0000000..fdcf71c --- /dev/null +++ b/src/r2s/styles/base.py @@ -0,0 +1,50 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, Protocol, Tuple, Literal, Optional +import numpy as np + + +ParamType = Literal["int", "float", "bool", "str", "choice"] + +@dataclass(frozen=True) +class ParamSpec: + key: str + label: str + ptype: ParamType + default: Any + min: Optional[float] = None + max: Optional[float] = None + step: Optional[float] = None + choices: Optional[list[str]] = None + help: str = "" + + +@dataclass +class StyleResult: + # A raster preview image (uint8 BGR) of the processed stage + preview_bgr: np.ndarray + # SVG text output + svg: str + # Optional debug artifacts + meta: Dict[str, Any] + +class Style(Protocol): + name: str + + def default_params(self) -> Dict[str, Any]: ... + def param_specs(self) -> list[ParamSpec]: ... + def run(self, bgr: np.ndarray, params: Dict[str, Any]) -> StyleResult: ... + +def clamp_int(v: Any, lo: int, hi: int, default: int) -> int: + try: + iv = int(v) + return max(lo, min(hi, iv)) + except Exception: + return default + +def clamp_float(v: Any, lo: float, hi: float, default: float) -> float: + try: + fv = float(v) + return max(lo, min(hi, fv)) + except Exception: + return default diff --git a/src/r2s/styles/lineart.py b/src/r2s/styles/lineart.py new file mode 100644 index 0000000..c00beea --- /dev/null +++ b/src/r2s/styles/lineart.py @@ -0,0 +1,101 @@ +from __future__ import annotations +from typing import Any, Dict, List, Tuple +import numpy as np +import cv2 + +from .base import StyleResult, clamp_int, clamp_float +from ..svg_render import svg_from_stroked_contours + +from .base import ParamSpec + + + +class LineArtStyle: + name = "lineart" + + def default_params(self) -> Dict[str, Any]: + return { + "mode": "adaptive", # "adaptive" or "fixed" + "threshold": 128, # for fixed + "block_size": 31, # adaptive: odd >= 3 + "c": 7, # adaptive C + "invert": True, # ink vs paper + "min_area": 40, + "simplify": 1.0, + "stroke_width": 1.2, + "scale": 1.0, + } + + def param_specs(self) -> list[ParamSpec]: + d = self.default_params() + return [ + ParamSpec("mode", "Threshold mode", "choice", d["mode"], choices=["adaptive", "fixed"], + help="Adaptive: local threshold; Fixed: global threshold."), + ParamSpec("threshold", "Fixed threshold", "int", d["threshold"], 0, 255, 1, + help="Used only in Fixed mode."), + ParamSpec("block_size", "Adaptive block size", "int", d["block_size"], 3, 201, 2, + help="Odd integer; larger = smoother threshold."), + ParamSpec("c", "Adaptive C", "int", d["c"], -50, 50, 1, + help="Subtracted from local mean in adaptive threshold."), + ParamSpec("invert", "Invert (ink on white)", "bool", d["invert"]), + ParamSpec("min_area", "Min component area", "int", d["min_area"], 0, 1000000, 10), + ParamSpec("simplify", "Path simplify ε", "float", d["simplify"], 0.0, 20.0, 0.1), + ParamSpec("stroke_width", "Stroke width", "float", d["stroke_width"], 0.1, 50.0, 0.1), + ] + + def run(self, bgr: np.ndarray, params: Dict[str, Any]) -> StyleResult: + mode = str(params.get("mode", "adaptive")).lower().strip() + threshold = clamp_int(params.get("threshold"), 0, 255, 128) + block_size = clamp_int(params.get("block_size"), 3, 201, 31) + if block_size % 2 == 0: + block_size += 1 + c = clamp_int(params.get("c"), -50, 50, 7) + invert = bool(params.get("invert", True)) + min_area = clamp_int(params.get("min_area"), 0, 10_000_000, 40) + simplify = clamp_float(params.get("simplify"), 0.0, 20.0, 1.0) + stroke_width = clamp_float(params.get("stroke_width"), 0.0, 50.0, 1.2) + scale = clamp_float(params.get("scale"), 0.05, 10.0, 1.0) + + gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) + gray = cv2.GaussianBlur(gray, (5, 5), 0) + + if mode == "fixed": + _, bw = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY) + else: + bw = cv2.adaptiveThreshold( + gray, 255, + cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY, + block_size, + c + ) + + # Invert so "ink" is white for contour extraction if requested + if invert: + bw = 255 - bw + + # Clean specks and thicken slightly for nicer tracing + bw = cv2.morphologyEx(bw, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1) + bw = cv2.morphologyEx(bw, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8), iterations=1) + + contours, _ = cv2.findContours(bw, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + good = [] + for cnt in contours: + area = cv2.contourArea(cnt) + if area >= min_area: + good.append(cnt) + + # Preview: show ink as black on white + preview = 255 - bw if invert else bw + preview_bgr = cv2.cvtColor(preview, cv2.COLOR_GRAY2BGR) + + svg = svg_from_stroked_contours( + width_px=bgr.shape[1], + height_px=bgr.shape[0], + contours=good, + simplify_eps=simplify, + stroke_width=stroke_width, + scale=scale, + ) + meta = {"contours": len(good)} + return StyleResult(preview_bgr=preview_bgr, svg=svg, meta=meta) diff --git a/src/r2s/styles/pontillist.py b/src/r2s/styles/pontillist.py new file mode 100644 index 0000000..83dfaeb --- /dev/null +++ b/src/r2s/styles/pontillist.py @@ -0,0 +1,231 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Tuple +import numpy as np +import cv2 + +from .base import StyleResult, clamp_int, clamp_float, ParamSpec +from ..svg_render import svg_from_pontillist + + +class PontillistStyle: + name = "pontillist" + + def default_params(self) -> Dict[str, Any]: + return { + # Palette / segmentation (posterized backbone) + "n_colors": 6, # 2..20 + "blur": 1, # 0..10 (median blur) + "min_region_area": 200, # drop tiny regions for boundary + color stability + + # Tone mapping + "clahe": False, + "clahe_clip": 2.0, + "clahe_grid": 8, + "tone_gamma": 1.6, # higher -> more dots in darks, fewer in lights + + # Stippling / dots + "grid_step": 8.0, # base sampling pitch (px); smaller => more candidates + "jitter": 0.8, # 0..1 fraction of grid_step + "accept_scale": 1.0, # scales acceptance probability; >1 more dots, <1 fewer + "max_points": 30000, # hard cap for performance / SVG size + "dot_radius_min": 0.6, + "dot_radius_max": 2.2, + "radius_mode": "by_tone", # "fixed" | "by_tone" + "fixed_radius": 1.2, + + # Boundaries + "draw_boundaries": True, + "boundary_width": 0.8, + + # Output scale (keep global if you prefer; otherwise include) + "scale": 1.0, + + # Determinism + "seed": 1, + } + + def param_specs(self) -> list[ParamSpec]: + d = self.default_params() + return [ + ParamSpec("n_colors", "Number of colors", "int", d["n_colors"], min=2, max=20, step=1, + help="Palette size (k-means in Lab). Dots use region palette colors."), + ParamSpec("blur", "Pre-blur radius", "int", d["blur"], min=0, max=10, step=1, + help="Median blur before segmentation; reduces speckle regions."), + ParamSpec("min_region_area", "Min region area", "int", d["min_region_area"], min=0, max=1_000_000, step=50, + help="Regions smaller than this are ignored for boundary drawing and dot coloring stability."), + + ParamSpec("clahe", "Use CLAHE contrast", "bool", d["clahe"], + help="Local contrast enhancement before tone mapping."), + ParamSpec("clahe_clip", "CLAHE clip limit", "float", d["clahe_clip"], min=0.1, max=10.0, step=0.1, + help="Higher values increase local contrast but can amplify noise."), + ParamSpec("clahe_grid", "CLAHE grid size", "int", d["clahe_grid"], min=2, max=64, step=1, + help="Tile grid size for CLAHE."), + ParamSpec("tone_gamma", "Tone gamma", "float", d["tone_gamma"], min=0.3, max=4.0, step=0.05, + help="Controls dot density vs tone. Higher => denser in dark regions."), + + ParamSpec("grid_step", "Grid step (px)", "float", d["grid_step"], min=2.0, max=50.0, step=0.5, + help="Candidate sampling pitch. Smaller => more candidates and more dots (slower)."), + ParamSpec("jitter", "Jitter (0..1)", "float", d["jitter"], min=0.0, max=1.0, step=0.05, + help="Random jitter fraction of grid step to avoid a mechanical pattern."), + ParamSpec("accept_scale", "Acceptance scale", "float", d["accept_scale"], min=0.1, max=3.0, step=0.05, + help="Scales acceptance probability. >1 yields more dots overall."), + ParamSpec("max_points", "Max points", "int", d["max_points"], min=1000, max=500_000, step=1000, + help="Hard cap on number of dots for performance and SVG size."), + + ParamSpec("radius_mode", "Radius mode", "choice", d["radius_mode"], choices=["fixed", "by_tone"], + help="Fixed: all dots same size. By_tone: radius varies with tone."), + ParamSpec("fixed_radius", "Fixed radius", "float", d["fixed_radius"], min=0.1, max=10.0, step=0.1, + help="Dot radius when Radius mode is fixed."), + ParamSpec("dot_radius_min", "Min radius", "float", d["dot_radius_min"], min=0.1, max=10.0, step=0.1, + help="Minimum dot radius for by-tone mode."), + ParamSpec("dot_radius_max", "Max radius", "float", d["dot_radius_max"], min=0.1, max=20.0, step=0.1, + help="Maximum dot radius for by-tone mode."), + + ParamSpec("draw_boundaries", "Draw boundaries", "bool", d["draw_boundaries"], + help="Draw posterized region boundaries as strokes on top."), + ParamSpec("boundary_width", "Boundary width", "float", d["boundary_width"], min=0.1, max=10.0, step=0.1, + help="Stroke width for region boundaries."), + ParamSpec("seed", "Random seed", "int", d["seed"], min=0, max=2_000_000_000, step=1, + help="Seed for deterministic dot placement."), + ] + + def run(self, bgr: np.ndarray, params: Dict[str, Any]) -> StyleResult: + # --- Parse params + n_colors = clamp_int(params.get("n_colors"), 2, 20, 6) + blur = clamp_int(params.get("blur"), 0, 10, 1) + min_region_area = clamp_int(params.get("min_region_area"), 0, 10_000_000, 200) + + clahe_on = bool(params.get("clahe", False)) + clahe_clip = clamp_float(params.get("clahe_clip"), 0.1, 10.0, 2.0) + clahe_grid = clamp_int(params.get("clahe_grid"), 2, 64, 8) + tone_gamma = clamp_float(params.get("tone_gamma"), 0.3, 4.0, 1.6) + + grid_step = clamp_float(params.get("grid_step"), 2.0, 50.0, 8.0) + jitter = clamp_float(params.get("jitter"), 0.0, 1.0, 0.8) + accept_scale = clamp_float(params.get("accept_scale"), 0.1, 3.0, 1.0) + max_points = clamp_int(params.get("max_points"), 1000, 500_000, 30000) + + radius_mode = str(params.get("radius_mode", "by_tone")).strip().lower() + fixed_radius = clamp_float(params.get("fixed_radius"), 0.1, 10.0, 1.2) + rmin = clamp_float(params.get("dot_radius_min"), 0.1, 10.0, 0.6) + rmax = clamp_float(params.get("dot_radius_max"), 0.1, 20.0, 2.2) + if rmax < rmin: + rmax, rmin = rmin, rmax + + draw_boundaries = bool(params.get("draw_boundaries", True)) + boundary_width = clamp_float(params.get("boundary_width"), 0.1, 10.0, 0.8) + + scale = clamp_float(params.get("scale"), 0.05, 10.0, 1.0) + seed = clamp_int(params.get("seed"), 0, 2_000_000_000, 1) + + # --- Preprocess for segmentation + work = bgr.copy() + if blur > 0: + work = cv2.medianBlur(work, 2 * blur + 1) + + # K-means in Lab (same as posterized) + lab = cv2.cvtColor(work, cv2.COLOR_BGR2LAB) + Z = lab.reshape((-1, 3)).astype(np.float32) + criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 1.0) + _, labels, centers = cv2.kmeans(Z, n_colors, None, criteria, 3, cv2.KMEANS_PP_CENTERS) + centers = centers.astype(np.uint8) + labels2d = labels.reshape((lab.shape[0], lab.shape[1])) + + # Palette in BGR + palette_bgr = cv2.cvtColor(centers.reshape(1, -1, 3), cv2.COLOR_LAB2BGR).reshape(-1, 3) + + # Optionally compute region boundaries (contours) + boundaries: List[np.ndarray] = [] + if draw_boundaries: + for k in range(n_colors): + mask = (labels2d == k).astype(np.uint8) * 255 + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3, 3), np.uint8), iterations=1) + cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + for c in cnts: + if cv2.contourArea(c) >= min_region_area: + boundaries.append(c) + + # --- Tone map for acceptance + gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) + gray = cv2.GaussianBlur(gray, (5, 5), 0) + + if clahe_on: + clahe = cv2.createCLAHE(clipLimit=float(clahe_clip), tileGridSize=(clahe_grid, clahe_grid)) + gray = clahe.apply(gray) + + # Normalize tone to [0,1], where darkness drives dot probability + tone = gray.astype(np.float32) / 255.0 + demand = (1.0 - tone) ** float(tone_gamma) # 0..1, higher in darks + + # --- Candidate points on jittered grid + rng = np.random.default_rng(seed) + h, w = gray.shape + + xs = np.arange(0.0, w, grid_step, dtype=np.float32) + ys = np.arange(0.0, h, grid_step, dtype=np.float32) + + dots: List[Tuple[float, float, float, Tuple[int, int, int]]] = [] + + # Shuffle traversal so max_points cuts fairly + coords = [(float(x), float(y)) for y in ys for x in xs] + rng.shuffle(coords) + + for x0, y0 in coords: + # jitter around grid point + jx = (rng.random() * 2.0 - 1.0) * jitter * grid_step * 0.5 + jy = (rng.random() * 2.0 - 1.0) * jitter * grid_step * 0.5 + x = x0 + jx + y = y0 + jy + xi = int(round(x)) + yi = int(round(y)) + if xi < 0 or xi >= w or yi < 0 or yi >= h: + continue + + p = float(demand[yi, xi]) * float(accept_scale) + if p <= 0: + continue + if rng.random() > min(1.0, p): + continue + + # Determine dot radius + if radius_mode == "fixed": + r = fixed_radius + else: + # Darker => larger dots (common stipple aesthetic); invert if desired + r = rmin + (rmax - rmin) * float(demand[yi, xi]) + + # Dot color from posterized region label + k = int(labels2d[yi, xi]) + bgrc = palette_bgr[k] + color = (int(bgrc[2]), int(bgrc[1]), int(bgrc[0])) # SVG expects RGB tuple + + dots.append((float(x), float(y), float(r), color)) + + if len(dots) >= max_points: + break + + # --- Preview raster: render dots over white (fast preview) + preview = np.full((h, w, 3), 255, dtype=np.uint8) + for x, y, r, (rr, gg, bb) in dots: + cv2.circle(preview, (int(round(x)), int(round(y))), int(max(1, round(r))), (bb, gg, rr), -1) + + if draw_boundaries and boundaries: + cv2.drawContours(preview, boundaries, -1, (0, 0, 0), 1) + + svg = svg_from_pontillist( + width_px=w, + height_px=h, + dots=dots, # (x,y,r,(r,g,b)) + boundaries=boundaries if draw_boundaries else [], + boundary_width=boundary_width, + scale=scale, + ) + + meta = { + "dots": len(dots), + "boundaries": len(boundaries), + "n_colors": n_colors, + } + return StyleResult(preview_bgr=preview, svg=svg, meta=meta) + diff --git a/src/r2s/styles/posterized.py b/src/r2s/styles/posterized.py new file mode 100644 index 0000000..2f9d1e8 --- /dev/null +++ b/src/r2s/styles/posterized.py @@ -0,0 +1,153 @@ +from __future__ import annotations +from typing import Any, Dict, List, Tuple +import numpy as np +import cv2 + +from .base import StyleResult, clamp_int, clamp_float +from ..svg_render import svg_from_filled_regions + +from .base import ParamSpec + + + + +class PosterizedStyle: + name = "posterized" + + def default_params(self) -> Dict[str, Any]: + return { + "n_colors": 6, # 2..20 + "blur": 1, # 0..10 (median blur kernel = 2*blur+1) + "min_area": 80, # remove tiny regions + "simplify": 1.2, # polygon approx epsilon (pixels) + "add_stroke": False, + "stroke_width": 0.8, + "scale": 1.0, # output scaling + } + + def param_specs(self) -> list[ParamSpec]: + d = self.default_params() + return [ + ParamSpec( + "n_colors", + "Number of colors", + "int", + d["n_colors"], + min=2, + max=20, + step=1, + help="Number of color clusters used for posterization (k-means)." + ), + ParamSpec( + "blur", + "Pre-blur radius", + "int", + d["blur"], + min=0, + max=10, + step=1, + help="Median blur radius applied before color quantization. " + "Helps suppress noise and small color speckles." + ), + ParamSpec( + "min_area", + "Minimum region area", + "int", + d["min_area"], + min=0, + max=1_000_000, + step=10, + help="Discard connected regions smaller than this area (in pixels)." + ), + ParamSpec( + "simplify", + "Path simplify ε", + "float", + d["simplify"], + min=0.0, + max=20.0, + step=0.1, + help="Douglas–Peucker tolerance for polygon simplification. " + "Higher values produce fewer vertices." + ), + ParamSpec( + "add_stroke", + "Outline regions", + "bool", + d["add_stroke"], + help="Add a black stroke around each filled region." + ), + ParamSpec( + "stroke_width", + "Stroke width", + "float", + d["stroke_width"], + min=0.1, + max=20.0, + step=0.1, + help="Stroke width for region outlines (if enabled)." + ), + ] + + + def run(self, bgr: np.ndarray, params: Dict[str, Any]) -> StyleResult: + n_colors = clamp_int(params.get("n_colors"), 2, 20, 6) + blur = clamp_int(params.get("blur"), 0, 10, 1) + min_area = clamp_int(params.get("min_area"), 0, 10_000_000, 80) + simplify = clamp_float(params.get("simplify"), 0.0, 20.0, 1.2) + add_stroke = bool(params.get("add_stroke", False)) + stroke_width = clamp_float(params.get("stroke_width"), 0.0, 20.0, 0.8) + scale = clamp_float(params.get("scale"), 0.05, 10.0, 1.0) + + work = bgr.copy() + if blur > 0: + k = 2 * blur + 1 + work = cv2.medianBlur(work, k) + + # K-means quantization in Lab (more perceptual) + lab = cv2.cvtColor(work, cv2.COLOR_BGR2LAB) + Z = lab.reshape((-1, 3)).astype(np.float32) + + # Criteria: (type, max_iter, epsilon) + criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 1.0) + attempts = 3 + _, labels, centers = cv2.kmeans(Z, n_colors, None, criteria, attempts, cv2.KMEANS_PP_CENTERS) + centers = centers.astype(np.uint8) + q = centers[labels.flatten()].reshape(lab.shape) + q_bgr = cv2.cvtColor(q, cv2.COLOR_LAB2BGR) + + # Extract regions by each label value (connected components) + labels2d = labels.reshape((lab.shape[0], lab.shape[1])) + regions: List[Tuple[np.ndarray, Tuple[int, int, int]]] = [] + + # Compute BGR palette per label for fill + palette_bgr = cv2.cvtColor(centers.reshape(1, -1, 3), cv2.COLOR_LAB2BGR).reshape(-1, 3) + + for k in range(n_colors): + mask = (labels2d == k).astype(np.uint8) * 255 + # Clean tiny holes/noise + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1) + contours, _hier = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + + color = tuple(int(c) for c in palette_bgr[k]) + for cnt in contours: + area = cv2.contourArea(cnt) + if area < min_area: + continue + regions.append((cnt, color)) + + svg = svg_from_filled_regions( + width_px=bgr.shape[1], + height_px=bgr.shape[0], + regions=regions, + simplify_eps=simplify, + add_stroke=add_stroke, + stroke_width=stroke_width, + scale=scale, + ) + + meta = { + "n_colors": n_colors, + "regions": len(regions), + } + return StyleResult(preview_bgr=q_bgr, svg=svg, meta=meta) diff --git a/src/r2s/styles/woodcut.py b/src/r2s/styles/woodcut.py new file mode 100644 index 0000000..f6ea822 --- /dev/null +++ b/src/r2s/styles/woodcut.py @@ -0,0 +1,298 @@ +from __future__ import annotations +from typing import Any, Dict, List, Tuple +import numpy as np +import cv2 + +from .base import StyleResult, clamp_int, clamp_float +from ..svg_render import svg_from_woodcut + +from .base import ParamSpec + + + +class WoodcutStyle: + name = "woodcut" + + def param_specs(self) -> list[ParamSpec]: + d = self.default_params() + return [ + # --- Tone / contrast + ParamSpec( + "clahe", "Use CLAHE contrast", "bool", d["clahe"], + help="Applies local contrast enhancement (CLAHE) before tone banding." + ), + ParamSpec( + "clahe_clip", "CLAHE clip limit", "float", d["clahe_clip"], + min=0.1, max=10.0, step=0.1, + help="Higher values increase local contrast but can amplify noise." + ), + ParamSpec( + "clahe_grid", "CLAHE grid size", "int", d["clahe_grid"], + min=2, max=64, step=1, + help="Tile grid size for CLAHE; larger values use coarser local regions." + ), + ParamSpec( + "tone_bands", "Tone bands", "int", d["tone_bands"], + min=2, max=12, step=1, + help="Number of tone bands for hatching. More bands gives more tonal nuance but increases linework." + ), + + # --- Edges / keylines + ParamSpec( + "edge_low", "Canny low threshold", "int", d["edge_low"], + min=0, max=255, step=1, + help="Lower hysteresis threshold for Canny edge detection." + ), + ParamSpec( + "edge_high", "Canny high threshold", "int", d["edge_high"], + min=0, max=255, step=1, + help="Upper hysteresis threshold for Canny edge detection." + ), + ParamSpec( + "edge_dilate", "Edge dilation", "int", d["edge_dilate"], + min=0, max=5, step=1, + help="Dilate detected edges to create bolder keylines." + ), + ParamSpec( + "edge_min_area", "Min edge component area", "int", d["edge_min_area"], + min=0, max=1_000_000, step=10, + help="Discard small edge fragments below this area (in pixels)." + ), + ParamSpec( + "edge_simplify", "Edge path simplify ε", "float", d["edge_simplify"], + min=0.0, max=20.0, step=0.1, + help="Douglas–Peucker tolerance for simplifying edge polylines." + ), + ParamSpec( + "edge_stroke_width", "Edge stroke width", "float", d["edge_stroke_width"], + min=0.1, max=20.0, step=0.1, + help="Stroke width for keyline/edge layer." + ), + + # --- Hatching + ParamSpec( + "hatch_base_spacing", "Hatch base spacing (px)", "float", d["hatch_base_spacing"], + min=2.0, max=200.0, step=1.0, + help="Base spacing between hatch lines for the lightest hatched band (pixels)." + ), + ParamSpec( + "hatch_spacing_factor", "Hatch spacing factor", "float", d["hatch_spacing_factor"], + min=0.2, max=0.95, step=0.01, + help="Per-band multiplier applied to spacing; darker bands get tighter spacing." + ), + ParamSpec( + "hatch_angle_deg", "Hatch angle (deg)", "float", d["hatch_angle_deg"], + min=-89.0, max=89.0, step=1.0, + help="Angle of hatch lines. Try -25, 25, or 45 degrees for woodcut-like texture." + ), + ParamSpec( + "hatch_stroke_width", "Hatch stroke width", "float", d["hatch_stroke_width"], + min=0.1, max=20.0, step=0.1, + help="Stroke width for hatch lines." + ), + ParamSpec( + "hatch_min_seg_len", "Min hatch segment length", "float", d["hatch_min_seg_len"], + min=0.0, max=200.0, step=1.0, + help="Drop hatch dashes shorter than this (helps laser friendliness)." + ), + + # If you prefer per-style scale instead of global: + # ParamSpec("scale", "Output scale", "float", d["scale"], min=0.05, max=10.0, step=0.05, + # help="Scale factor applied to SVG output size.") + ] + + + def default_params(self) -> Dict[str, Any]: + return { + # Overall + "scale": 1.0, + + # Preprocess / tone + "clahe": True, + "clahe_clip": 2.0, + "clahe_grid": 8, # 4..16 typical + "tone_bands": 5, # 2..9 typical + + # Edges + "edge_low": 40, + "edge_high": 120, + "edge_dilate": 1, # 0..3 + "edge_simplify": 1.0, + "edge_stroke_width": 1.4, + "edge_min_area": 20, + + # Hatching + "hatch_base_spacing": 18.0, # px; darker bands get smaller spacing + "hatch_spacing_factor": 0.70, # per band multiplier + "hatch_angle_deg": -25.0, + "hatch_stroke_width": 1.0, + "hatch_simplify": 0.0, # segments are already simple + "hatch_min_seg_len": 6.0, # drop tiny hatch dashes + } + + def run(self, bgr: np.ndarray, params: Dict[str, Any]) -> StyleResult: + scale = clamp_float(params.get("scale"), 0.05, 10.0, 1.0) + + clahe_on = bool(params.get("clahe", True)) + clahe_clip = clamp_float(params.get("clahe_clip"), 0.1, 10.0, 2.0) + clahe_grid = clamp_int(params.get("clahe_grid"), 2, 64, 8) + + tone_bands = clamp_int(params.get("tone_bands"), 2, 12, 5) + + edge_low = clamp_int(params.get("edge_low"), 0, 255, 40) + edge_high = clamp_int(params.get("edge_high"), 0, 255, 120) + edge_dilate = clamp_int(params.get("edge_dilate"), 0, 5, 1) + edge_simplify = clamp_float(params.get("edge_simplify"), 0.0, 20.0, 1.0) + edge_stroke_width = clamp_float(params.get("edge_stroke_width"), 0.1, 50.0, 1.4) + edge_min_area = clamp_int(params.get("edge_min_area"), 0, 10_000_000, 20) + + hatch_base_spacing = clamp_float(params.get("hatch_base_spacing"), 2.0, 200.0, 18.0) + hatch_spacing_factor = clamp_float(params.get("hatch_spacing_factor"), 0.2, 0.95, 0.70) + hatch_angle_deg = clamp_float(params.get("hatch_angle_deg"), -89.0, 89.0, -25.0) + hatch_stroke_width = clamp_float(params.get("hatch_stroke_width"), 0.1, 50.0, 1.0) + hatch_min_seg_len = clamp_float(params.get("hatch_min_seg_len"), 0.0, 10_000.0, 6.0) + + # ---- preprocess to grayscale + gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY) + gray = cv2.GaussianBlur(gray, (5, 5), 0) + + if clahe_on: + clahe = cv2.createCLAHE(clipLimit=float(clahe_clip), tileGridSize=(clahe_grid, clahe_grid)) + gray2 = clahe.apply(gray) + else: + gray2 = gray + + # ---- tone quantization into bands (0..tone_bands-1), dark=high index + # Using uniform bins over 0..255 after contrast step + bins = np.linspace(0, 256, tone_bands + 1, dtype=np.int32) + band_idx = np.digitize(gray2, bins) - 1 + band_idx = np.clip(band_idx, 0, tone_bands - 1) + + # preview: show quantized tones as grayscale blocks + preview_levels = (band_idx * (255 // max(1, tone_bands - 1))).astype(np.uint8) + preview_bgr = cv2.cvtColor(preview_levels, cv2.COLOR_GRAY2BGR) + + # ---- edges for outline layer + edges = cv2.Canny(gray2, edge_low, edge_high) + if edge_dilate > 0: + k = np.ones((3, 3), np.uint8) + edges = cv2.dilate(edges, k, iterations=edge_dilate) + edges = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8), iterations=1) + + contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + edge_contours = [c for c in contours if cv2.contourArea(c) >= edge_min_area] + + # ---- hatch segments per dark band + # We hatch from darker to lighter, excluding the lightest band by default + # For each band threshold, build a mask for pixels at/above that darkness. + hatch_layers: List[Tuple[str, List[Tuple[Tuple[float, float], Tuple[float, float]]]]] = [] + h, w = gray2.shape + + for k in range(1, tone_bands): # skip lightest band 0 + # mask where band >= k (i.e., at least this dark) + mask = (band_idx >= k).astype(np.uint8) * 255 + # clean specks + mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((3,3), np.uint8), iterations=1) + # spacing: darker => tighter + spacing = hatch_base_spacing * (hatch_spacing_factor ** (k - 1)) + + segs = _hatch_segments_from_mask( + mask=mask, + angle_deg=hatch_angle_deg, + spacing=float(spacing), + min_seg_len=float(hatch_min_seg_len), + ) + hatch_layers.append((f"hatch_band_{k}", segs)) + + svg = svg_from_woodcut( + width_px=bgr.shape[1], + height_px=bgr.shape[0], + edge_contours=edge_contours, + edge_simplify_eps=edge_simplify, + edge_stroke_width=edge_stroke_width, + hatch_layers=hatch_layers, + hatch_stroke_width=hatch_stroke_width, + scale=scale, + ) + + meta = { + "tone_bands": tone_bands, + "edge_contours": len(edge_contours), + "hatch_layers": [(name, len(segs)) for name, segs in hatch_layers], + } + return StyleResult(preview_bgr=preview_bgr, svg=svg, meta=meta) + +def _hatch_segments_from_mask( + mask: np.ndarray, + angle_deg: float, + spacing: float, + min_seg_len: float, +) -> List[Tuple[Tuple[float, float], Tuple[float, float]]]: + """ + Generate hatch line segments clipped to the white area of a mask (uint8 0/255). + No polygon clipping library required: we raster-sample along each hatch line + and emit contiguous 'inside' runs as segments. + """ + h, w = mask.shape + theta = np.deg2rad(angle_deg) + v = np.array([np.cos(theta), np.sin(theta)], dtype=np.float32) # along-line + n = np.array([-v[1], v[0]], dtype=np.float32) # normal + + # Determine range of offsets along normal that cover the image + corners = np.array([[0,0], [w-1,0], [0,h-1], [w-1,h-1]], dtype=np.float32) + projs = corners @ n + o_min, o_max = float(projs.min()), float(projs.max()) + + if spacing <= 0.5: + spacing = 0.5 + + offsets = np.arange(o_min - spacing, o_max + spacing, spacing, dtype=np.float32) + + segs: List[Tuple[Tuple[float, float], Tuple[float, float]]] = [] + + # sample step along line (pixel-ish). Lower is more accurate but slower. + step = 1.0 + + for o in offsets: + # A point on the line: p0 = o*n + p0 = n * o + + # Intersect infinite line with image bounding box by marching along v over a conservative length + # Choose t range big enough to cross whole image. + L = float(np.hypot(w, h)) + 5.0 + t0, t1 = -L, L + + # Sample points along the line + ts = np.arange(t0, t1 + step, step, dtype=np.float32) + xs = p0[0] + ts * v[0] + ys = p0[1] + ts * v[1] + + # Convert to integer pixel coords and filter within bounds + xi = np.round(xs).astype(np.int32) + yi = np.round(ys).astype(np.int32) + inside = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h) + if not np.any(inside): + continue + + xi2 = xi[inside] + yi2 = yi[inside] + + # inside-mask boolean along this hatch line + on = mask[yi2, xi2] > 0 + if on.size == 0: + continue + + # find contiguous runs of True + # run starts where on[i] and (i==0 or not on[i-1]); run ends where on[i] and (i==last or not on[i+1]) + idx = np.arange(on.size) + starts = idx[on & np.r_[True, ~on[:-1]]] + ends = idx[on & np.r_[~on[1:], True]] + + for s, e in zip(starts, ends): + x0, y0 = float(xi2[s]), float(yi2[s]) + x1, y1 = float(xi2[e]), float(yi2[e]) + if (x1 - x0)**2 + (y1 - y0)**2 < (min_seg_len ** 2): + continue + segs.append(((x0, y0), (x1, y1))) + + return segs diff --git a/src/r2s/svg_render.py b/src/r2s/svg_render.py new file mode 100644 index 0000000..7526fbb --- /dev/null +++ b/src/r2s/svg_render.py @@ -0,0 +1,144 @@ +from __future__ import annotations +from typing import List, Tuple +import svgwrite +import cv2 +import numpy as np + +def svg_from_pontillist( + width_px: int, + height_px: int, + dots: List[Tuple[float, float, float, Tuple[int, int, int]]], # (x,y,r,(r,g,b)) + boundaries: List[np.ndarray], + boundary_width: float, + scale: float = 1.0, +) -> str: + w = int(round(width_px * scale)) + h = int(round(height_px * scale)) + dwg = svgwrite.Drawing(size=(f"{w}px", f"{h}px"), profile="tiny") + dwg.viewbox(0, 0, width_px, height_px) + + # Dots (filled circles) + g_dots = dwg.g(id="dots") + for x, y, r, (rr, gg, bb) in dots: + g_dots.add(dwg.circle(center=(x, y), r=r, fill=svgwrite.rgb(rr, gg, bb), stroke="none")) + dwg.add(g_dots) + + # Boundaries on top (optional) + if boundaries: + g_b = dwg.g(id="boundaries", fill="none", stroke=svgwrite.rgb(0, 0, 0)) + g_b.attribs["stroke-width"] = boundary_width + g_b.attribs["stroke-linejoin"] = "round" + g_b.attribs["stroke-linecap"] = "round" + for cnt in boundaries: + pts = [(float(p[0][0]), float(p[0][1])) for p in cnt] + if len(pts) >= 2: + g_b.add(dwg.polyline(points=pts)) + dwg.add(g_b) + + return dwg.tostring() + +def _approx_contour(cnt: np.ndarray, eps: float) -> np.ndarray: + if eps <= 0: + return cnt + return cv2.approxPolyDP(cnt, eps, True) + +def svg_from_filled_regions( + width_px: int, + height_px: int, + regions: List[Tuple[np.ndarray, Tuple[int, int, int]]], + simplify_eps: float, + add_stroke: bool, + stroke_width: float, + scale: float = 1.0, +) -> str: + w = int(round(width_px * scale)) + h = int(round(height_px * scale)) + dwg = svgwrite.Drawing(size=(f"{w}px", f"{h}px"), profile="tiny") + dwg.viewbox(0, 0, width_px, height_px) + + grp = dwg.g(id="fills") + for cnt, color_bgr in regions: + poly = _approx_contour(cnt, simplify_eps) + pts = [(float(p[0][0]), float(p[0][1])) for p in poly] + b, g, r = color_bgr + fill = svgwrite.rgb(r, g, b) + stroke = svgwrite.rgb(0, 0, 0) if add_stroke else "none" + grp.add(dwg.polygon( + points=pts, + fill=fill, + stroke=stroke, + stroke_width=stroke_width if add_stroke else 0, + stroke_linejoin="round", + )) + dwg.add(grp) + return dwg.tostring() + +def svg_from_stroked_contours( + width_px: int, + height_px: int, + contours: List[np.ndarray], + simplify_eps: float, + stroke_width: float, + scale: float = 1.0, +) -> str: + w = int(round(width_px * scale)) + h = int(round(height_px * scale)) + dwg = svgwrite.Drawing(size=(f"{w}px", f"{h}px"), profile="tiny") + dwg.viewbox(0, 0, width_px, height_px) + + grp = dwg.g(id="strokes", fill="none", stroke=svgwrite.rgb(0, 0, 0)) + grp.attribs["stroke-width"] = stroke_width + grp.attribs["stroke-linejoin"] = "round" + grp.attribs["stroke-linecap"] = "round" + + for cnt in contours: + poly = _approx_contour(cnt, simplify_eps) + pts = [(float(p[0][0]), float(p[0][1])) for p in poly] + if len(pts) < 2: + continue + # Use a polyline; if you want closed loops, use polygon with fill="none" + grp.add(dwg.polyline(points=pts)) + dwg.add(grp) + return dwg.tostring() + + + +def svg_from_woodcut( + width_px: int, + height_px: int, + edge_contours: List[np.ndarray], + edge_simplify_eps: float, + edge_stroke_width: float, + hatch_layers: List[Tuple[str, List[Tuple[Tuple[float, float], Tuple[float, float]]]]], + hatch_stroke_width: float, + scale: float = 1.0, +) -> str: + w = int(round(width_px * scale)) + h = int(round(height_px * scale)) + dwg = svgwrite.Drawing(size=(f"{w}px", f"{h}px"), profile="tiny") + dwg.viewbox(0, 0, width_px, height_px) + + # Hatch layers first (so edges sit on top) + for layer_name, segs in hatch_layers: + grp = dwg.g(id=layer_name, fill="none", stroke=svgwrite.rgb(0, 0, 0)) + grp.attribs["stroke-width"] = hatch_stroke_width + grp.attribs["stroke-linecap"] = "round" + for (x0, y0), (x1, y1) in segs: + grp.add(dwg.line(start=(x0, y0), end=(x1, y1))) + dwg.add(grp) + + # Edge contours on top + edge_grp = dwg.g(id="edges", fill="none", stroke=svgwrite.rgb(0, 0, 0)) + edge_grp.attribs["stroke-width"] = edge_stroke_width + edge_grp.attribs["stroke-linejoin"] = "round" + edge_grp.attribs["stroke-linecap"] = "round" + + for cnt in edge_contours: + poly = cv2.approxPolyDP(cnt, float(edge_simplify_eps), True) if edge_simplify_eps > 0 else cnt + pts = [(float(p[0][0]), float(p[0][1])) for p in poly] + if len(pts) < 2: + continue + edge_grp.add(dwg.polyline(points=pts)) + + dwg.add(edge_grp) + return dwg.tostring() diff --git a/src/r2s/ui/app.py b/src/r2s/ui/app.py new file mode 100644 index 0000000..88271f5 --- /dev/null +++ b/src/r2s/ui/app.py @@ -0,0 +1,10 @@ +from __future__ import annotations +import sys +from PySide6.QtWidgets import QApplication +from .main_window import MainWindow + +def run_app() -> None: + app = QApplication(sys.argv) + w = MainWindow() + w.show() + sys.exit(app.exec()) diff --git a/src/r2s/ui/auto_explore.py b/src/r2s/ui/auto_explore.py new file mode 100644 index 0000000..3541a64 --- /dev/null +++ b/src/r2s/ui/auto_explore.py @@ -0,0 +1,121 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +from PySide6.QtCore import Qt +from PySide6.QtGui import QPixmap +from PySide6.QtWidgets import ( + QDialog, QGridLayout, QHBoxLayout, QLabel, QPushButton, QVBoxLayout +) + +@dataclass +class Candidate: + label: str + params: Dict[str, Any] + preview_pix: QPixmap + +class ClickableLabel(QLabel): + def __init__(self, idx: int, on_click): + super().__init__() + self._idx = idx + self._on_click = on_click + self.setCursor(Qt.CursorShape.PointingHandCursor) + + def mousePressEvent(self, event): + self._on_click(self._idx) + +class AutoExploreDialog(QDialog): + """ + 3x3 candidate chooser. + Center cell is the current/selected params. + Surrounding cells are candidates. + """ + def __init__(self, parent, candidates8: List[Candidate], current: Candidate): + super().__init__(parent) + self.setWindowTitle("Auto-explore") + self.setModal(True) + + if len(candidates8) != 8: + raise ValueError("AutoExploreDialog requires exactly 8 candidates") + + self._orig_current = current + self._selected = current + + self._cells: List[ClickableLabel] = [] + self._cell_candidates: List[Optional[Candidate]] = [None] * 9 + + # Map candidates into 3x3 positions with center=4. + positions = [0, 1, 2, 3, 5, 6, 7, 8] + for pos, cand in zip(positions, candidates8): + self._cell_candidates[pos] = cand + self._cell_candidates[4] = current + + grid = QGridLayout() + grid.setSpacing(8) + + for i in range(9): + lbl = ClickableLabel(i, self._on_pick) + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + lbl.setMinimumSize(220, 160) + lbl.setStyleSheet("border: 1px solid #999;") + cand = self._cell_candidates[i] + if cand is not None: + lbl.setPixmap(cand.preview_pix.scaled( + lbl.size(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation + )) + lbl.setToolTip(cand.label) + self._cells.append(lbl) + grid.addWidget(lbl, i // 3, i % 3) + + self._info = QLabel(self._selected.label) + self._info.setWordWrap(True) + + btn_ok = QPushButton("OK") + btn_cancel = QPushButton("Cancel") + btn_ok.clicked.connect(self.accept) + btn_cancel.clicked.connect(self.reject) + + btn_row = QHBoxLayout() + btn_row.addStretch(1) + btn_row.addWidget(btn_cancel) + btn_row.addWidget(btn_ok) + + root = QVBoxLayout() + root.addLayout(grid) + root.addWidget(self._info) + root.addLayout(btn_row) + self.setLayout(root) + + self._highlight_selected() + + def _on_pick(self, idx: int) -> None: + cand = self._cell_candidates[idx] + if cand is None: + return + # Selecting a candidate updates center cell and selection + self._selected = cand + self._cell_candidates[4] = cand + self._cells[4].setPixmap(cand.preview_pix.scaled( + self._cells[4].size(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation + )) + self._info.setText(cand.label) + self._highlight_selected() + + def _highlight_selected(self) -> None: + # Simple highlight: center border thicker + for i, lbl in enumerate(self._cells): + if i == 4: + lbl.setStyleSheet("border: 3px solid #2a7;") + else: + lbl.setStyleSheet("border: 1px solid #999;") + + def selected_params(self) -> Dict[str, Any]: + return dict(self._selected.params) + + def selected_label(self) -> str: + return self._selected.label diff --git a/src/r2s/ui/main_window.py b/src/r2s/ui/main_window.py new file mode 100644 index 0000000..0350c12 --- /dev/null +++ b/src/r2s/ui/main_window.py @@ -0,0 +1,774 @@ +from __future__ import annotations +import json +from pathlib import Path +from typing import Any, Dict, Optional + +import numpy as np +import cv2 + +from PySide6.QtCore import Qt, QTimer +from PySide6.QtGui import QImage, QPixmap +from PySide6.QtWidgets import ( + QFileDialog, QHBoxLayout, QLabel, QMainWindow, QPushButton, + QComboBox, QFormLayout, QLineEdit, QSpinBox, QDoubleSpinBox, + QCheckBox, QWidget, QVBoxLayout, QMessageBox +) +from PySide6.QtWidgets import ( + QPlainTextEdit, QSplitter, QGroupBox, QScrollArea, QMenuBar +) +from PySide6.QtGui import QAction + +from PySide6.QtWidgets import QColorDialog +from PySide6.QtGui import QColor +from PySide6.QtWidgets import QDialog + +from ..pipeline import run_style, available_styles +from ..preview import svg_to_bgr +from .auto_explore import AutoExploreDialog, Candidate +from ..styles import ALL_STYLES # ensure this exists + +def bgr_to_qpixmap(bgr: np.ndarray) -> QPixmap: + rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) + h, w, _ = rgb.shape + qimg = QImage(rgb.data, w, h, 3 * w, QImage.Format.Format_RGB888) + return QPixmap.fromImage(qimg) + +class MainWindow(QMainWindow): + def __init__(self) -> None: + super().__init__() + self.setWindowTitle("Raster2SVG Studio") + + self._styles = available_styles() + self._bgr: Optional[np.ndarray] = None + self._last_svg: Optional[str] = None + self._last_preview_bgr: Optional[np.ndarray] = None + + self._input_path: Optional[str] = None + self._param_widgets: Dict[str, QWidget] = {} + self._dirty_params: bool = False + + + # Debounced update timer + self._timer = QTimer(self) + self._timer.setSingleShot(True) + self._timer.timeout.connect(self._recompute) + + # --- UI widgets + + menubar = QMenuBar(self) + help_menu = menubar.addMenu("Help") + act_theory = QAction("Theory of Operation", self) + act_theory.triggered.connect(self._show_theory) + help_menu.addAction(act_theory) + self.setMenuBar(menubar) + + + self.btn_open = QPushButton("Open Image…") + self.btn_open.clicked.connect(self._open_image) + + self.input_path_edit = QLineEdit() + self.input_path_edit.setReadOnly(True) + + self.bg_color = QColor(255, 255, 255) # default white + + self.btn_bg_color = QPushButton("Background…") + self.btn_bg_color.clicked.connect(self._pick_bg_color) + self.btn_bg_color.setEnabled(False) + + self.lbl_bg_swatch = QLabel() + self.lbl_bg_swatch.setFixedSize(24, 24) + self._update_bg_swatch() + + + self.output_name_edit = QLineEdit() + self.output_name_edit.setReadOnly(False) + + self.auto_update_chk = QCheckBox("Auto-update") + self.auto_update_chk.setChecked(True) + self.auto_update_chk.stateChanged.connect(self._on_auto_update_changed) + + self.btn_apply = QPushButton("Apply") + self.btn_apply.clicked.connect(self._recompute) + self.btn_apply.setEnabled(False) # enabled when auto-update off and dirty + + + self.btn_export = QPushButton("Export SVG…") + self.btn_export.clicked.connect(self._export_svg) + self.btn_export.setEnabled(False) + + + self.btn_auto_explore = QPushButton("Auto-explore…") + self.btn_auto_explore.clicked.connect(self._auto_explore) + self.btn_auto_explore.setEnabled(False) + + self.style_combo = QComboBox() + for k in self._styles.keys(): + self.style_combo.addItem(k) + self.style_combo.currentTextChanged.connect(self._on_style_changed) + + # Param widgets (we use a simple JSON param editor + a few common controls) + ''' + self.params_json = QLineEdit() + self.params_json.setPlaceholderText('{"n_colors":6,"simplify":1.2,"min_area":80}') + self.params_json.editingFinished.connect(self._schedule_update) + ''' + self.params_json_view = QPlainTextEdit() + self.params_json_view.setReadOnly(True) + self.params_json_view.setMinimumHeight(140) + + self.param_form_box = QGroupBox("Parameters") + self.param_form_layout = QFormLayout() + self.param_form_box.setLayout(self.param_form_layout) + + + + scroll = QScrollArea() + scroll.setWidgetResizable(True) + scroll.setWidget(self.param_form_box) + + + + self.scale_box = QDoubleSpinBox() + self.scale_box.setRange(0.05, 10.0) + self.scale_box.setSingleStep(0.05) + self.scale_box.setValue(1.0) + self.scale_box.valueChanged.connect(self._schedule_update) + + # Previews + self.lbl_orig = QLabel("Original") + self.lbl_proc = QLabel("Processed") + self.lbl_svg = QLabel("SVG Preview") + for lbl in (self.lbl_orig, self.lbl_proc, self.lbl_svg): + lbl.setAlignment(Qt.AlignmentFlag.AlignCenter) + lbl.setMinimumSize(320, 240) + lbl.setStyleSheet("border: 1px solid #999;") + + # Layout + top = QHBoxLayout() + + #top.addWidget(self.btn_open) + #self.btn_open.setLayout(top) + + top.addWidget(self.input_path_edit) + top.addWidget(self.output_name_edit) + top.addWidget(self.auto_update_chk) + + top.addWidget(QLabel("Style:")) + top.addWidget(self.style_combo) + top.addWidget(QLabel("Scale:")) + top.addWidget(self.scale_box) + top.addWidget(self.btn_export) + top.addStretch(1) + + form = QFormLayout() + form.addRow("Params (JSON overrides):", self.params_json_view) + + previews = QHBoxLayout() + previews.addWidget(self.lbl_orig, 1) + previews.addWidget(self.lbl_proc, 1) + previews.addWidget(self.lbl_svg, 1) + + root = QVBoxLayout() + root.addLayout(top) + root.addLayout(form) + root.addLayout(previews) + + left = QWidget() + left_layout = QVBoxLayout() + + # top file group + file_form = QFormLayout() + file_form.addRow("File:", self.btn_open) + + file_form.addRow("Input:", self.input_path_edit) + file_form.addRow("Suggested output:", self.output_name_edit) + + file_form.addRow("Export SVG:", self.btn_export) + file_form.addRow("Auto-explore:", self.btn_auto_explore) + + bg_row = QHBoxLayout() + bg_row.addWidget(QLabel("Transparent background:")) + bg_row.addWidget(self.lbl_bg_swatch) + bg_row.addWidget(self.btn_bg_color) + bg_row.addStretch(1) + left_layout.addLayout(bg_row) + + style_row = QHBoxLayout() + style_row.addWidget(QLabel("Style:")) + style_row.addWidget(self.style_combo) + style_row.addStretch(1) + style_row.addWidget(self.auto_update_chk) + style_row.addWidget(self.btn_apply) + + #left_layout.addLayout(top) + left_layout.addLayout(file_form) + left_layout.addLayout(style_row) + left_layout.addWidget(scroll, 1) + left_layout.addWidget(QLabel("Parameters (JSON)")) + left_layout.addWidget(self.params_json_view, 0) + + left.setLayout(left_layout) + + right = QWidget() + right_layout = QVBoxLayout() + row = QHBoxLayout() + row.addWidget(self.lbl_proc, 1) + row.addWidget(self.lbl_svg, 1) + right_layout.addLayout(row, 1) + + # smaller original at bottom + right_layout.addWidget(self.lbl_orig, 0) + right.setLayout(right_layout) + + split = QSplitter() + split.addWidget(left) + split.addWidget(right) + split.setStretchFactor(0, 0) + split.setStretchFactor(1, 1) + + + + #container = QWidget() + #container.setLayout(root) + #self.setCentralWidget(container) + + self.setCentralWidget(split) + + # Initialize param JSON to defaults for initial style + self._apply_style_defaults_to_editor() + + def _pick_bg_color(self) -> None: + col = QColorDialog.getColor(self.bg_color, self, "Select background color") + if not col.isValid(): + return + self.bg_color = col + self._update_bg_swatch() + self._apply_background_and_refresh() + + def _update_bg_swatch(self) -> None: + self.lbl_bg_swatch.setStyleSheet( + f"background-color: {self.bg_color.name()}; border: 1px solid #666;" + ) + + def _apply_background(self, img: np.ndarray) -> np.ndarray: + """ + Convert an RGBA or RGB image to opaque BGR by compositing + over the selected background color. + """ + if img.ndim != 3 or img.shape[2] != 4: + # Already opaque BGR + if img.shape[2] == 3: + return img + raise ValueError("Unexpected image format") + + # Split channels + b, g, r, a = cv2.split(img) + alpha = a.astype(np.float32) / 255.0 + alpha = alpha[..., None] # shape (H, W, 1) + + bg = np.array( + [self.bg_color.blue(), self.bg_color.green(), self.bg_color.red()], + dtype=np.float32 + ) + + fg = np.dstack([b, g, r]).astype(np.float32) + + out = fg * alpha + bg * (1.0 - alpha) + return out.astype(np.uint8) + + def _apply_background_and_refresh(self) -> None: + if self._orig_img is None: + return + + try: + self._bgr = self._apply_background(self._orig_img) + except Exception as e: + QMessageBox.critical(self, "Image error", str(e)) + return + + # Update original preview + self.lbl_orig.setPixmap( + bgr_to_qpixmap(self._bgr).scaled( + self.lbl_orig.size(), + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation + ) + ) + + # Recompute derivatives + if self.auto_update_chk.isChecked(): + self._schedule_update() + else: + self._dirty_params = True + self.btn_apply.setEnabled(True) + + def _show_theory(self) -> None: + txt = ( + "Theory of Operation\n\n" + "This tool converts a raster image (pixels) into an SVG (vector paths).\n\n" + "Views:\n" + "1) Original: the unmodified input raster.\n" + "2) Processed: the style-specific intermediate representation (e.g., quantized colors, threshold mask, tone bands).\n" + "3) SVG Preview: the vector output rendered back to a raster for display.\n" + " If CairoSVG is installed, this is a true SVG render; otherwise it falls back to the Processed view.\n\n" + "Workflow:\n" + "Open an image → choose a style → adjust parameters → Apply/Auto-update → Export SVG.\n" + ) + QMessageBox.information(self, "Theory of Operation", txt) + + + def _apply_style_defaults_to_editor(self) -> None: + style = self.style_combo.currentText() + defaults = dict(self._styles.get(style, {})) + # Don’t duplicate scale; we manage it separately + defaults.pop("scale", None) + #self.params_json_view.setText(json.dumps(defaults)) + self._update_json_view() + + def _on_style_changed_1(self) -> None: + self._apply_style_defaults_to_editor() + self._schedule_update() + + def _on_style_changed(self) -> None: + style = self.style_combo.currentText() + self._build_param_form(style) + self._update_output_suggestion() + if self._bgr is not None and self.auto_update_chk.isChecked(): + self._schedule_update() + + + def _schedule_update(self) -> None: + # debounce to avoid thrashing while sliders/editing + self._timer.start(200) + + def _open_image(self) -> None: + path, _ = QFileDialog.getOpenFileName(self, "Open Image", "", "Images (*.png *.jpg *.jpeg *.bmp *.tif *.tiff)") + if not path: + return + + img = cv2.imread(path, cv2.IMREAD_UNCHANGED) + if img is None: + has_alpha = False + pass + else: + has_alpha = (img.ndim == 3 and img.shape[2] == 4) + self._has_alpha = has_alpha + self._orig_img = img # keep original, possibly RGBA + self.btn_bg_color.setEnabled(self._has_alpha) + + self._apply_background_and_refresh() + + bgr = cv2.imread(path, cv2.IMREAD_COLOR) + if bgr is None: + QMessageBox.critical(self, "Error", f"Could not read image:\n{path}") + return + self._bgr = bgr + self.lbl_orig.setPixmap(bgr_to_qpixmap(bgr).scaled( + self.lbl_orig.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation + )) + self.btn_export.setEnabled(True) + self.btn_auto_explore.setEnabled(True) + + self._input_path = path + self.input_path_edit.setText(path) + self._update_output_suggestion() + + self._schedule_update() + + def resizeEvent(self, event) -> None: + super().resizeEvent(event) + # re-render pixmaps to fit + if self._bgr is not None: + self.lbl_orig.setPixmap(bgr_to_qpixmap(self._bgr).scaled( + self.lbl_orig.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation + )) + if self._last_preview_bgr is not None: + self.lbl_proc.setPixmap(bgr_to_qpixmap(self._last_preview_bgr).scaled( + self.lbl_proc.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation + )) + if self._last_svg is not None and self._bgr is not None: + svg_bgr = svg_to_bgr(self._last_svg, self._bgr.shape[1], self._bgr.shape[0]) + if svg_bgr is not None: + self.lbl_svg.setPixmap(bgr_to_qpixmap(svg_bgr).scaled( + self.lbl_svg.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation + )) + + def _parse_params(self) -> Dict[str, Any]: + txt = self.params_json.text().strip() + if not txt: + return {} + try: + return json.loads(txt) + except Exception as e: + QMessageBox.warning(self, "Params JSON error", f"Could not parse params JSON:\n{e}") + return {} + + def _make_candidates8(self, style: str, base: Dict[str, Any]) -> list[tuple[str, Dict[str, Any]]]: + b = dict(base) + + def cand(label: str, updates: Dict[str, Any]) -> tuple[str, Dict[str, Any]]: + p = dict(b) + p.update(updates) + return (label, p) + + if style == "posterized": + n = int(b.get("n_colors", 6)) + simp = float(b.get("simplify", 1.2)) + area = int(b.get("min_area", 80)) + blur = int(b.get("blur", 1)) + sw = float(b.get("stroke_width", 0.8)) + stroke = bool(b.get("add_stroke", False)) + + return [ + cand("Fewer colors", {"n_colors": max(2, n - 2)}), + cand("More colors", {"n_colors": min(20, n + 3)}), + cand("Simplify more", {"simplify": min(20.0, simp * 1.6)}), + cand("Simplify less", {"simplify": max(0.0, simp * 0.7)}), + cand("Drop specks more", {"min_area": min(1_000_000, int(area * 2))}), + cand("Keep more detail", {"min_area": max(0, int(area * 0.5))}), + cand("More blur", {"blur": min(10, blur + 2)}), + cand("Outline regions", {"add_stroke": True, "stroke_width": max(0.1, sw)} if not stroke else {"add_stroke": False}), + ] + + if style == "lineart": + mode = str(b.get("mode", "adaptive")) + thr = int(b.get("threshold", 128)) + bs = int(b.get("block_size", 31)) + c = int(b.get("c", 7)) + simp = float(b.get("simplify", 1.0)) + sw = float(b.get("stroke_width", 1.2)) + inv = bool(b.get("invert", True)) + + # Ensure odd block_size + def odd(x: int) -> int: + x = max(3, min(201, x)) + return x if (x % 2 == 1) else x + 1 + + return [ + cand("Adaptive (smoother)", {"mode": "adaptive", "block_size": odd(bs + 20), "c": c}), + cand("Adaptive (sharper)", {"mode": "adaptive", "block_size": odd(bs - 14), "c": c}), + cand("Adaptive (higher C)", {"mode": "adaptive", "block_size": odd(bs), "c": min(50, c + 6)}), + cand("Adaptive (lower C)", {"mode": "adaptive", "block_size": odd(bs), "c": max(-50, c - 6)}), + cand("Fixed (darker)", {"mode": "fixed", "threshold": max(0, thr - 25)}), + cand("Fixed (lighter)", {"mode": "fixed", "threshold": min(255, thr + 25)}), + cand("Thicker stroke", {"stroke_width": min(50.0, sw * 1.6)}), + cand("Invert toggled", {"invert": (not inv)}), + ] + + if style == "woodcut": + bands = int(b.get("tone_bands", 5)) + base_sp = float(b.get("hatch_base_spacing", 18.0)) + fac = float(b.get("hatch_spacing_factor", 0.70)) + ang = float(b.get("hatch_angle_deg", -25.0)) + e_low = int(b.get("edge_low", 40)) + e_high = int(b.get("edge_high", 120)) + e_sw = float(b.get("edge_stroke_width", 1.4)) + h_sw = float(b.get("hatch_stroke_width", 1.0)) + + return [ + cand("More tone bands", {"tone_bands": min(12, bands + 2)}), + cand("Fewer tone bands", {"tone_bands": max(2, bands - 2)}), + cand("Tighter hatching", {"hatch_base_spacing": max(2.0, base_sp * 0.75)}), + cand("Looser hatching", {"hatch_base_spacing": min(200.0, base_sp * 1.35)}), + cand("Rotate hatch +30°", {"hatch_angle_deg": max(-89.0, min(89.0, ang + 30.0))}), + cand("Rotate hatch -30°", {"hatch_angle_deg": max(-89.0, min(89.0, ang - 30.0))}), + cand("Stronger edges", {"edge_low": min(255, e_low + 15), "edge_high": min(255, e_high + 25), "edge_stroke_width": min(20.0, e_sw * 1.2)}), + cand("More hatch ink", {"hatch_stroke_width": min(20.0, h_sw * 1.3)}), + ] + + if style == "pontillist": + n = int(b.get("n_colors", 6)) + step = float(b.get("grid_step", 8.0)) + gamma = float(b.get("tone_gamma", 1.6)) + acc = float(b.get("accept_scale", 1.0)) + rmax = float(b.get("dot_radius_max", 2.2)) + boundaries = bool(b.get("draw_boundaries", True)) + return [ + cand("Fewer colors", {"n_colors": max(2, n - 2)}), + cand("More colors", {"n_colors": min(20, n + 3)}), + cand("Denser (smaller step)", {"grid_step": max(2.0, step * 0.75)}), + cand("Sparser (larger step)", {"grid_step": min(50.0, step * 1.35)}), + cand("More dark emphasis", {"tone_gamma": min(4.0, gamma * 1.25)}), + cand("Less dark emphasis", {"tone_gamma": max(0.3, gamma * 0.8)}), + cand("More dots overall", {"accept_scale": min(3.0, acc * 1.25)}), + cand("Toggle boundaries", {"draw_boundaries": (not boundaries)}), + ] + + # Fallback: no candidates + return [] + + + + def _recompute(self) -> None: + if self._bgr is None: + return + + style = self.style_combo.currentText() + ''' + params = self._parse_params() + params["scale"] = float(self.scale_box.value()) + ''' + params = self._read_params_from_form() + params["scale"] = float(self.scale_box.value()) if hasattr(self, "scale_box") else 1.0 + + + try: + res = run_style(self._bgr, style, params) + except Exception as e: + QMessageBox.critical(self, "Conversion error", str(e)) + return + + self._last_svg = res.svg + self._last_preview_bgr = res.preview_bgr + + # Processed pane + self.lbl_proc.setPixmap(bgr_to_qpixmap(res.preview_bgr).scaled( + self.lbl_proc.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation + )) + + # SVG pane (render via cairosvg if available) + svg_bgr = svg_to_bgr(res.svg, self._bgr.shape[1], self._bgr.shape[0]) + if svg_bgr is None: + # fallback: show processed raster stage + self.lbl_svg.setPixmap(bgr_to_qpixmap(res.preview_bgr).scaled( + self.lbl_svg.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation + )) + self.lbl_svg.setToolTip("Install optional dependency 'cairosvg' for true SVG preview: pip install 'r2s[preview]'") + else: + self.lbl_svg.setPixmap(bgr_to_qpixmap(svg_bgr).scaled( + self.lbl_svg.size(), Qt.AspectRatioMode.KeepAspectRatio, Qt.TransformationMode.SmoothTransformation + )) + self.lbl_svg.setToolTip("") + self._dirty_params = False + self.btn_apply.setEnabled(False) + + def _export_svg(self) -> None: + + if not self._last_svg: + QMessageBox.information(self, "Nothing to export", "Run a conversion first.") + return + # out, _ = QFileDialog.getSaveFileName(self, "Export SVG", "output.svg", "SVG (*.svg)") + suggest = self.output_name_edit.text().strip() or "output.svg" + out, _ = QFileDialog.getSaveFileName(self, "Export SVG", suggest, "SVG (*.svg)") + if not out: + return + Path(out).write_text(self._last_svg, encoding="utf-8") + + + + def _auto_explore(self) -> None: + if self._bgr is None: + QMessageBox.information(self, "Auto-explore", "Open an image first.") + return + + style = self.style_combo.currentText() + + # Current params from form (not from JSON box) + base_params = self._read_params_from_form() + # If you keep scale global, include it for rendering consistency + if hasattr(self, "scale_box"): + base_params["scale"] = float(self.scale_box.value()) + + # Reduced scale for candidate rendering speed + explore_scale = 0.5 + # Preserve user scale but override for exploration rendering + base_for_render = dict(base_params) + base_for_render["scale"] = explore_scale + + # Build 8 candidates + cand_defs = self._make_candidates8(style, base_params) + if len(cand_defs) != 8: + QMessageBox.warning(self, "Auto-explore", f"No candidate generator for style '{style}'.") + return + + # Render current (center) + try: + cur_pix = self._render_candidate_pixmap(style, base_params, explore_scale) + except Exception as e: + QMessageBox.critical(self, "Auto-explore error", str(e)) + return + + current = Candidate(label="Current settings", params=dict(base_params), preview_pix=cur_pix) + + # Render 8 candidates + candidates8: list[Candidate] = [] + try: + for label, p in cand_defs: + pix = self._render_candidate_pixmap(style, p, explore_scale) + candidates8.append(Candidate(label=label, params=p, preview_pix=pix)) + except Exception as e: + QMessageBox.critical(self, "Auto-explore error", str(e)) + return + + dlg = AutoExploreDialog(self, candidates8=candidates8, current=current) + #if dlg.exec() == dlg.Accepted: + if dlg.exec() == QDialog.DialogCode.Accepted: + + chosen = dlg.selected_params() + # Apply chosen params to the form widgets + self._set_form_from_params(chosen) + self._dirty_params = True + self._update_json_view() + if self.auto_update_chk.isChecked(): + self._schedule_update() + else: + self.btn_apply.setEnabled(True) + + + def _set_form_from_params(self, params: Dict[str, Any]) -> None: + for k, v in params.items(): + if k not in self._param_widgets: + continue + w = self._param_widgets[k] + try: + if isinstance(w, QCheckBox): + w.setChecked(bool(v)) + elif isinstance(w, QSpinBox): + w.setValue(int(v)) + elif isinstance(w, QDoubleSpinBox): + w.setValue(float(v)) + elif isinstance(w, QComboBox): + w.setCurrentText(str(v)) + elif isinstance(w, QLineEdit): + w.setText(str(v)) + except Exception: + # Ignore ill-typed values rather than crashing + pass + + # If you keep scale global, update it too + if "scale" in params and hasattr(self, "scale_box"): + try: + self.scale_box.setValue(float(params["scale"])) + except Exception: + pass + + + + def _build_param_form(self, style_name: str) -> None: + # clear old widgets + while self.param_form_layout.rowCount(): + self.param_form_layout.removeRow(0) + self._param_widgets.clear() + + style = ALL_STYLES[style_name] + specs = style.param_specs() + defaults = style.default_params() + defaults.pop("scale", None) # if scale handled elsewhere + + for spec in specs: + w = None + + if spec.ptype == "bool": + cb = QCheckBox() + cb.setChecked(bool(spec.default)) + cb.stateChanged.connect(self._on_param_changed) + w = cb + + elif spec.ptype == "int": + sb = QSpinBox() + if spec.min is not None: sb.setMinimum(int(spec.min)) + if spec.max is not None: sb.setMaximum(int(spec.max)) + sb.setSingleStep(int(spec.step or 1)) + sb.setValue(int(spec.default)) + sb.valueChanged.connect(self._on_param_changed) + w = sb + + elif spec.ptype == "float": + dsb = QDoubleSpinBox() + if spec.min is not None: dsb.setMinimum(float(spec.min)) + if spec.max is not None: dsb.setMaximum(float(spec.max)) + dsb.setSingleStep(float(spec.step or 0.1)) + dsb.setDecimals(3) + dsb.setValue(float(spec.default)) + dsb.valueChanged.connect(self._on_param_changed) + w = dsb + + elif spec.ptype == "choice": + combo = QComboBox() + combo.addItems(spec.choices or []) + combo.setCurrentText(str(spec.default)) + combo.currentTextChanged.connect(self._on_param_changed) + w = combo + + else: # "str" + le = QLineEdit() + le.setText(str(spec.default)) + le.editingFinished.connect(self._on_param_changed) + w = le + + if spec.help: + w.setToolTip(spec.help) + + self._param_widgets[spec.key] = w + self.param_form_layout.addRow(spec.label + ":", w) + + self._dirty_params = True + self._update_json_view() + + + + def _on_param_changed(self) -> None: + self._dirty_params = True + self._update_json_view() + if self.auto_update_chk.isChecked(): + self._schedule_update() + else: + self.btn_apply.setEnabled(True) + + def _read_params_from_form(self) -> Dict[str, Any]: + params: Dict[str, Any] = {} + for k, w in self._param_widgets.items(): + if isinstance(w, QCheckBox): + params[k] = bool(w.isChecked()) + elif isinstance(w, QSpinBox): + params[k] = int(w.value()) + elif isinstance(w, QDoubleSpinBox): + params[k] = float(w.value()) + elif isinstance(w, QComboBox): + params[k] = str(w.currentText()) + elif isinstance(w, QLineEdit): + params[k] = str(w.text()) + else: + pass + return params + + def _update_json_view(self) -> None: + params = self._read_params_from_form() + # If you still manage scale separately, include it here so JSON reflects reality: + params["scale"] = float(self.scale_box.value()) if hasattr(self, "scale_box") else 1.0 + self.params_json_view.setPlainText(json.dumps(params, indent=2, sort_keys=True)) + + + + def _on_auto_update_changed(self) -> None: + if self.auto_update_chk.isChecked(): + self.btn_apply.setEnabled(False) + if self._dirty_params: + self._schedule_update() + else: + self.btn_apply.setEnabled(self._dirty_params) + + + + def _update_output_suggestion(self) -> None: + if not self._input_path: + return + p = Path(self._input_path) + style = self.style_combo.currentText() + self.output_name_edit.setText(f"{p.stem}_{style}.svg") + + + + def _render_candidate_pixmap(self, style: str, params: Dict[str, Any], explore_scale: float) -> QPixmap: + if self._bgr is None: + raise RuntimeError("No image loaded") + + # Force scaled rendering for speed + p = dict(params) + p["scale"] = explore_scale + + res = run_style(self._bgr, style, p) + + svg_bgr = svg_to_bgr(res.svg, self._bgr.shape[1], self._bgr.shape[0]) + use_bgr = svg_bgr if svg_bgr is not None else res.preview_bgr + return bgr_to_qpixmap(use_bgr) diff --git a/src/tests/test_pipeline_smoke.py b/src/tests/test_pipeline_smoke.py new file mode 100644 index 0000000..ad69330 --- /dev/null +++ b/src/tests/test_pipeline_smoke.py @@ -0,0 +1,14 @@ +import numpy as np +from r2s.pipeline import run_style + +def test_smoke_posterized(): + img = np.zeros((120, 160, 3), dtype=np.uint8) + img[:] = (30, 60, 90) + res = run_style(img, "posterized", {"n_colors": 3}) + assert "