"""Distribution model for normalized curve points inside a rectangular area."""
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Sequence
@dataclass(frozen=True)
class CurvePoint:
    """One normalized point for a curve.

    Instances are immutable (``frozen=True``), so they compare by value,
    are hashable, and can be shared freely between curves.
    """

    x: float  # horizontal position, nominally in [0, 1]
    y: float  # vertical position, nominally in [0, 1]
@dataclass(frozen=True)
class CurveSeries:
    """One curve made of normalized points in [0, 1].

    Immutable (``frozen=True``); ``points`` is a tuple so the whole series,
    including its point sequence, cannot be mutated after construction.
    """

    label: str  # display name of the curve
    points: tuple[CurvePoint, ...]  # ordered points making up the curve
class CurveDistributionModel:
    """Store and validate multiple normalized curves.

    Each input curve is a sequence of ``(x, y)`` pairs. On load every
    coordinate is converted with ``float()``, optionally clamped into
    ``[0, 1]``, and each curve is optionally sorted by ``x``. The stored
    curves are immutable :class:`CurveSeries` objects.
    """

    def __init__(
        self,
        curves: Sequence[Sequence[tuple[float, float]]],
        labels: Sequence[str] | None = None,
        *,
        clamp_points: bool = True,
        sort_by_x: bool = True,
    ) -> None:
        """Build the model and load *curves* through :meth:`set_curves`.

        All arguments are forwarded unchanged to :meth:`set_curves`; see
        that method for their meaning and for the ``ValueError`` cases.
        """
        self._curves: list[CurveSeries] = []
        self.set_curves(
            curves,
            labels=labels,
            clamp_points=clamp_points,
            sort_by_x=sort_by_x,
        )

    @property
    def curves(self) -> list[CurveSeries]:
        """Return the stored curves as a new list.

        The list is a shallow copy, so callers cannot reorder or replace the
        model's curves through it; the elements themselves are frozen.
        """
        return list(self._curves)

    def set_curves(
        self,
        curves: Sequence[Sequence[tuple[float, float]]],
        labels: Sequence[str] | None = None,
        *,
        clamp_points: bool = True,
        sort_by_x: bool = True,
    ) -> None:
        """Validate, normalize, and replace all stored curves.

        Args:
            curves: One sequence of ``(x, y)`` pairs per curve; each curve
                needs at least two points.
            labels: Optional label per curve (each converted with ``str``).
                Defaults to ``"Curve 1"``, ``"Curve 2"``, ...
            clamp_points: Clamp each coordinate into ``[0, 1]``. When
                ``False``, finite out-of-range values are stored as given.
            sort_by_x: Sort each curve's points by ascending ``x``. The sort
                is stable, so points sharing an ``x`` keep their input order.

        Raises:
            ValueError: If *curves* is empty, if *labels* does not have one
                entry per curve, if a curve has fewer than two points, or if
                any coordinate is NaN.

        The new curves are built completely before being assigned, so a
        ``ValueError`` leaves the previously stored curves untouched.
        """
        # len() instead of truthiness: some array-like containers raise on
        # bool() instead of reporting emptiness.
        if len(curves) == 0:
            raise ValueError("At least one curve is required")
        if labels is None:
            label_list = [f"Curve {i + 1}" for i in range(len(curves))]
        else:
            label_list = [str(lbl) for lbl in labels]
            if len(label_list) != len(curves):
                raise ValueError("labels length must match curves length")

        out = [
            CurveSeries(
                label=label,
                points=self._normalize_points(
                    pts, clamp_points=clamp_points, sort_by_x=sort_by_x
                ),
            )
            for label, pts in zip(label_list, curves)
        ]
        self._curves = out

    @staticmethod
    def _normalize_points(
        pts: Sequence[tuple[float, float]],
        *,
        clamp_points: bool,
        sort_by_x: bool,
    ) -> tuple[CurvePoint, ...]:
        """Convert one curve's raw ``(x, y)`` pairs into ``CurvePoint``s.

        Raises:
            ValueError: If the curve has fewer than two points or contains a
                NaN coordinate.
        """
        if len(pts) < 2:
            raise ValueError("Each curve must contain at least 2 points")
        normalized: list[CurvePoint] = []
        for raw_x, raw_y in pts:
            x = float(raw_x)
            y = float(raw_y)
            # Every comparison with NaN is False. The clamp below would
            # silently turn it into 1.0, and unclamped it breaks the x-sort.
            if math.isnan(x) or math.isnan(y):
                raise ValueError("Curve points must not be NaN")
            if clamp_points:
                x = max(0.0, min(1.0, x))
                y = max(0.0, min(1.0, y))
            normalized.append(CurvePoint(x=x, y=y))
        if sort_by_x:
            normalized.sort(key=lambda p: p.x)
        return tuple(normalized)