"""Graph-based intermediate representation for model fitting.
The GraphIR is a DAG (directed acyclic graph) of typed nodes connected
by explicit dependency edges. It captures the full semantics of a
model without prescribing an evaluation strategy.
Three-layer design:
1. **OOP tree** -- the user-facing Model/Component/Par objects.
Handles parsing, validation, user interaction. Unchanged.
2. **GraphIR** -- a directed acyclic graph of typed nodes with explicit
data-dependency edges. Axis-agnostic: works for 1D and 2D models.
3. **ScheduledPlan2D** -- a flat, packed-array execution schedule
compiled from the graph by the 2D backend. No Python objects, no
strings, no dicts in the hot path.
"""
from __future__ import annotations
import re
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntEnum
from typing import TYPE_CHECKING, Any
import numpy as np
from trspecfit.functions import energy as fcts_energy
from trspecfit.functions import profile as fcts_profile
if TYPE_CHECKING:
from trspecfit.mcp import Component, Model, Par
#
#
[docs]
class NodeKind(IntEnum):
"""Node types in the model graph."""
# --- Parameter nodes (leaves) ---
STATIC_PARAM = 0
OPT_PARAM = 1
# --- Computed parameter nodes ---
DYNAMICS_TRACE = 2
PARAM_PLUS_TRACE = 3
EXPRESSION = 4
# --- Component evaluation nodes ---
COMPONENT_EVAL = 5
# --- Reduction / combination nodes ---
SUM = 6
SPECTRUM_FED_OP = 7
# --- Convolution and profile nodes (representable, not v1-compilable) ---
CONVOLUTION = 100
PROFILE_SAMPLE = 101
PROFILE_AVERAGE = 102
SUBCYCLE_MASK = 103
SUBCYCLE_REMAP = 104
#
#
[docs]
class EdgeKind(IntEnum):
"""Edge types in the model graph."""
PARAM_INPUT = 0
TRACE_INPUT = 1
BASE_INPUT = 2
ADDEND = 3
SPECTRUM_INPUT = 4
EXPR_REF = 5
#
#
[docs]
class DomainKind(IntEnum):
"""Model domain classification.
Determined by which axes the model operates on:
- ``ENERGY_1D``: model has energy axis only.
- ``TIME_1D``: model has time axis only.
- ``ENERGY_TIME_2D``: model has both axes.
"""
ENERGY_1D = 0
TIME_1D = 1
ENERGY_TIME_2D = 2
#
#
[docs]
class OpKind(IntEnum):
"""2D backend component function op codes.
Backend-specific lowered enum. ``schedule_2d`` maps
``GraphNode.function_name`` to ``OpKind`` during compilation.
"""
GAUSS = 0
GAUSS_ASYM = 1
LORENTZ = 2
VOIGT = 3
GLS = 4
GLP = 5
DS = 6
OFFSET = 10
LINBACK = 11
SHIRLEY = 12
#
#
[docs]
class DynFuncKind(IntEnum):
"""Dynamics function op codes for the 2D backend.
Backend-specific lowered enum. ``schedule_2d`` maps
``GraphNode.function_name`` (for ``DYNAMICS_TRACE`` nodes) to
``DynFuncKind`` during compilation.
"""
EXPFUN = 0
SINFUN = 1
LINFUN = 2
SINDIVX = 3
ERFFUN = 4
SQRTFUN = 5
#
#
[docs]
class ProfileFuncKind(IntEnum):
"""Profile-function op codes for lowered 1D profile evaluation."""
PEXPDECAY = 0
PLINEAR = 1
PGAUSS = 2
#
#
[docs]
class ParamSourceKind(IntEnum):
"""Parameter source kinds for profiled component evaluation (1D and 2D)."""
SCALAR = 0
PROFILE_SAMPLE = 1
PROFILE_EXPR = 2
_FUNCTION_NAME_TO_DYN_FUNC: dict[str, DynFuncKind] = {
"expFun": DynFuncKind.EXPFUN,
"sinFun": DynFuncKind.SINFUN,
"linFun": DynFuncKind.LINFUN,
"sinDivX": DynFuncKind.SINDIVX,
"erfFun": DynFuncKind.ERFFUN,
"sqrtFun": DynFuncKind.SQRTFUN,
}
_FUNCTION_NAME_TO_PROFILE_FUNC: dict[str, ProfileFuncKind] = {
"pExpDecay": ProfileFuncKind.PEXPDECAY,
"pLinear": ProfileFuncKind.PLINEAR,
"pGauss": ProfileFuncKind.PGAUSS,
}
#
#
[docs]
class ConvKernelKind(IntEnum):
"""Registry IDs for convolution kernel functions.
Only kernel functions listed here can be lowered; other kernels
(or non-time packages) fall back to MCP.
"""
GAUSSCONV = 0
LORENTZCONV = 1
VOIGTCONV = 2
EXPSYMCONV = 3
EXPDECAYCONV = 4
EXPRISECONV = 5
BOXCONV = 6
_FUNCTION_NAME_TO_CONV_KERNEL: dict[str, ConvKernelKind] = {
"gaussCONV": ConvKernelKind.GAUSSCONV,
"lorentzCONV": ConvKernelKind.LORENTZCONV,
"voigtCONV": ConvKernelKind.VOIGTCONV,
"expSymCONV": ConvKernelKind.EXPSYMCONV,
"expDecayCONV": ConvKernelKind.EXPDECAYCONV,
"expRiseCONV": ConvKernelKind.EXPRISECONV,
"boxCONV": ConvKernelKind.BOXCONV,
}
_FUNCTION_NAME_TO_OP: dict[str, OpKind] = {
"Gauss": OpKind.GAUSS,
"GaussAsym": OpKind.GAUSS_ASYM,
"Lorentz": OpKind.LORENTZ,
"Voigt": OpKind.VOIGT,
"GLS": OpKind.GLS,
"GLP": OpKind.GLP,
"DS": OpKind.DS,
"Offset": OpKind.OFFSET,
"LinBack": OpKind.LINBACK,
"Shirley": OpKind.SHIRLEY,
}
#: Maps ``OpKind`` → ``(eval_function, needs_spectrum)``.
#: Single source of truth for component dispatch -- used by both the
#: evaluator hot path and constant-component precomputation.
OP_DISPATCH: dict[int, tuple[Callable[..., Any], bool]] = {
int(OpKind.GAUSS): (fcts_energy.Gauss, False),
int(OpKind.GAUSS_ASYM): (fcts_energy.GaussAsym, False),
int(OpKind.LORENTZ): (fcts_energy.Lorentz, False),
int(OpKind.VOIGT): (fcts_energy.Voigt, False),
int(OpKind.GLS): (fcts_energy.GLS, False),
int(OpKind.GLP): (fcts_energy.GLP, False),
int(OpKind.DS): (fcts_energy.DS, False),
int(OpKind.OFFSET): (fcts_energy.Offset, False),
int(OpKind.LINBACK): (fcts_energy.LinBack, False),
int(OpKind.SHIRLEY): (fcts_energy.Shirley, True),
}
PROFILE_DISPATCH: dict[int, Callable[..., Any]] = {
int(ProfileFuncKind.PEXPDECAY): fcts_profile.pExpDecay,
int(ProfileFuncKind.PLINEAR): fcts_profile.pLinear,
int(ProfileFuncKind.PGAUSS): fcts_profile.pGauss,
}
#
#
[docs]
class ExprNodeKind(IntEnum):
"""RPN instruction types for compiled expressions."""
CONST = 0
PARAM_REF = 1
ADD = 2
SUB = 3
MUL = 4
DIV = 5
NEG = 6
POW = 7
#
#
[docs]
@dataclass
class GraphNode:
"""One node in the model graph."""
id: int
kind: NodeKind
name: str
source_order: int
# Payload (interpretation depends on kind):
value: float | None = None
function_name: str | None = None
package: str | None = None
expr_string: str | None = None
vary: bool = False
bounds: tuple[float, float] | None = None
arrays: dict[str, np.ndarray] = field(default_factory=dict)
#
#
[docs]
@dataclass
class GraphEdge:
"""One edge in the model graph."""
source: int
target: int
kind: EdgeKind
position: int | None = None
#
#
[docs]
@dataclass
class GraphIR:
"""Directed acyclic graph representing a model.
Axis-agnostic: works for 1D and 2D models. A 1D energy model has
``time=None``; adding dynamics populates ``time`` and promotes
``domain`` to ``ENERGY_TIME_2D``.
"""
nodes: list[GraphNode]
edges: list[GraphEdge]
domain: DomainKind
energy: np.ndarray | None = None
time: np.ndarray | None = None
node_by_name: dict[str, int] = field(default_factory=dict)
#
[docs]
def to_dot(self, *, collapse_profiles: bool = True) -> str:
"""Return a Graphviz DOT string for this graph.
Node shapes and colours encode ``NodeKind``; edge labels encode
``EdgeKind``. The output can be rendered with ``dot -Tpng`` or
any Graphviz viewer.
Parameters
----------
collapse_profiles : bool, default=True
When True, per-sample profile nodes (``PROFILE_SAMPLE``,
per-sample ``COMPONENT_EVAL``, per-sample ``EXPRESSION``)
are collapsed into single representative nodes showing the
sample count. This keeps profile models readable.
"""
_NODE_STYLE: dict[NodeKind, dict[str, str]] = {
NodeKind.STATIC_PARAM: dict(
shape="ellipse", style="filled", fillcolor="#d3d3d3"
),
NodeKind.OPT_PARAM: dict(
shape="ellipse", style="filled", fillcolor="#87ceeb"
),
NodeKind.DYNAMICS_TRACE: dict(
shape="box", style="filled", fillcolor="#ffa07a"
),
NodeKind.PARAM_PLUS_TRACE: dict(
shape="box", style="filled", fillcolor="#ffcc80"
),
NodeKind.EXPRESSION: dict(
shape="hexagon", style="filled", fillcolor="#dda0dd"
),
NodeKind.COMPONENT_EVAL: dict(
shape="box", style="filled,bold", fillcolor="#90ee90"
),
NodeKind.SUM: dict(shape="diamond", style="filled", fillcolor="#fffacd"),
NodeKind.SPECTRUM_FED_OP: dict(
shape="box", style="filled,bold", fillcolor="#f08080"
),
NodeKind.CONVOLUTION: dict(
shape="octagon", style="filled", fillcolor="#e0e0ff"
),
NodeKind.PROFILE_SAMPLE: dict(
shape="parallelogram", style="filled", fillcolor="#c8e6c9"
),
NodeKind.PROFILE_AVERAGE: dict(
shape="parallelogram", style="filled", fillcolor="#a5d6a7"
),
NodeKind.SUBCYCLE_MASK: dict(
shape="trapezium", style="filled", fillcolor="#ffe0b2"
),
NodeKind.SUBCYCLE_REMAP: dict(
shape="trapezium", style="filled", fillcolor="#ffcc80"
),
}
_EDGE_STYLE: dict[EdgeKind, dict[str, str]] = {
EdgeKind.PARAM_INPUT: dict(color="#333333"),
EdgeKind.TRACE_INPUT: dict(color="#ff6600", style="dashed"),
EdgeKind.BASE_INPUT: dict(color="#0066cc", style="dashed"),
EdgeKind.ADDEND: dict(color="#009933", style="bold"),
EdgeKind.SPECTRUM_INPUT: dict(color="#cc0000", style="bold"),
EdgeKind.EXPR_REF: dict(color="#9933cc", style="dotted"),
}
# --- Profile collapsing ---
# Maps each per-sample node id to the representative node id for
# its group. Nodes not in this dict are emitted as-is.
collapsed: dict[int, int] = {} # sample_nid -> representative_nid
# Groups: representative_nid -> (base_name, kind, count, first_node)
_profile_groups: dict[int, tuple[str, NodeKind, int, GraphNode]] = {}
if collapse_profiles:
_sample_re = re.compile(
r"^(.+?)_(profile_sample|profile_expr|sample)_(\d+)$"
)
# Group nodes by (base_name, kind)
groups: dict[tuple[str, NodeKind], list[GraphNode]] = {}
for node in self.nodes:
m = _sample_re.match(node.name)
if m is not None:
base = m.group(1)
groups.setdefault((base, node.kind), []).append(node)
for (base, kind), members in groups.items():
if len(members) < 2:
continue
rep = members[0]
for member in members:
collapsed[member.id] = rep.id
_profile_groups[rep.id] = (base, kind, len(members), rep)
hidden_nodes = {nid for nid in collapsed if collapsed[nid] != nid}
lines: list[str] = [
"digraph ModelGraph {",
" rankdir=BT;",
' node [fontname="Helvetica", fontsize=10];',
' edge [fontname="Helvetica", fontsize=8];',
]
for node in self.nodes:
if node.id in hidden_nodes:
continue
attrs = dict(_NODE_STYLE.get(node.kind, {}))
if node.id in _profile_groups:
base, kind, count, _ = _profile_groups[node.id]
label_parts = [f"{base} (\u00d7{count})", kind.name]
else:
label_parts = [node.name, node.kind.name]
if node.function_name:
label_parts.append(f"fn={node.function_name}")
if node.value is not None:
label_parts.append(f"val={node.value:g}")
if node.expr_string:
label_parts.append(f"expr={node.expr_string}")
attrs["label"] = "\\n".join(label_parts)
attr_str = ", ".join(f'{k}="{v}"' for k, v in attrs.items())
lines.append(f" n{node.id} [{attr_str}];")
seen_edges: set[tuple[int, int, EdgeKind, int | None]] = set()
for edge in self.edges:
src = collapsed.get(edge.source, edge.source)
tgt = collapsed.get(edge.target, edge.target)
key = (src, tgt, edge.kind, edge.position)
if key in seen_edges:
continue
seen_edges.add(key)
attrs = dict(_EDGE_STYLE.get(edge.kind, {}))
label = edge.kind.name
if edge.position is not None:
label += f"[{edge.position}]"
attrs["label"] = label
attr_str = ", ".join(f'{k}="{v}"' for k, v in attrs.items())
lines.append(f" n{src} -> n{tgt} [{attr_str}];")
lines.append("}")
return "\n".join(lines)
#
#
[docs]
@dataclass(frozen=True)
class ExprProgram:
"""Compiled expression: flat int array encoding an RPN program.
Encoding: pairs of ``(node_kind, operand)``.
- ``CONST``: operand is float bits (``np.float64.view(np.int64)``)
- ``PARAM_REF``: operand is row index into trace matrix
- Operators: operand is 0 (unused)
"""
instructions: np.ndarray # (2 * n_instructions,) int64
#
#
[docs]
@dataclass(frozen=True)
class ScheduledPlan2D:
"""Compiled 2D execution schedule.
No Python objects in the hot path (except ``expr_programs``).
"""
energy: np.ndarray # (n_energy,)
time: np.ndarray # (n_time,)
n_params: int
n_time: int
# --- Parameter mapping ---
param_traces_init: np.ndarray # (n_params, n_time)
opt_indices: np.ndarray # (n_opt,) int
opt_param_names: list[str] # (n_opt,) canonical optimizer param names
# --- Dynamics subgraphs (grouped by target PARAM_PLUS_TRACE) ---
# Each "dynamics group" corresponds to one time-dependent parameter.
# Multiple dynamics components (e.g. bi-exponential) targeting the
# same PARAM_PLUS_TRACE are grouped together. Substeps within a
# group are indexed via the CSR-style dyn_group_indptr.
n_dyn_groups: int
dyn_group_target_row: np.ndarray # (n_dyn_groups,) int
dyn_group_base_row: np.ndarray # (n_dyn_groups,) int
dyn_group_indptr: np.ndarray # (n_dyn_groups + 1,) int -- CSR into substep arrays
dyn_sub_func_id: np.ndarray # (n_substeps,) int
dyn_sub_param_rows: np.ndarray # (n_substeps, max_dyn_params) int, -1 padded
dyn_sub_n_params: np.ndarray # (n_substeps,) int
# Per-substep subcycle schedule data. SUBCYCLE_REMAP /
# SUBCYCLE_MASK graph nodes are compiled away here: non-subcycle
# substeps get ``time`` and ``ones`` defaults, subcycle substeps get
# ``time_norm`` and ``time_n_sub`` copies. Evaluator applies them as
# ``func(dyn_sub_time_axes[s], ...) * dyn_sub_masks[s]``.
dyn_sub_time_axes: np.ndarray # (n_substeps, n_time) float64
dyn_sub_masks: np.ndarray # (n_substeps, n_time) float64
# --- Expression evaluation ---
n_expressions: int
expr_target_rows: np.ndarray # (n_expressions,) int
expr_programs: list[ExprProgram]
# --- Interleaved parameter resolution schedule ---
# Dynamics groups, expressions, and convolution steps may depend on
# each other (e.g. a dynamics param that is an expression of another
# dynamics param; a conv step that wraps a resolved param). The
# resolution_kinds / resolution_indices arrays encode the correct
# topological execution order:
# kind=0 -> dynamics group step, index into dyn_group_* arrays
# kind=1 -> expression step, index into expr_* arrays / expr_programs
# kind=2 -> convolution step, index into conv_* arrays
resolution_kinds: np.ndarray # (n_dyn_groups + n_expressions + n_conv_steps,) int8
resolution_indices: np.ndarray # (n_dyn_groups + n_expressions + n_conv_steps,) int
# --- Resolved-trace convolution program ---
# Each conv step rewrites a trace row in-place after its PARAM_PLUS_TRACE
# is fully resolved. Chained CONVOLUTION nodes emit multiple steps
# targeting the same row, executed in order. Kernel values are
# recomputed per theta from conv_param_rows; support values are frozen
# at plan-build time from ``node.arrays["kernel_time"]``. Only
# ``package == "time"`` kernels are lowered.
n_conv_steps: int
conv_target_rows: np.ndarray # (n_conv_steps,) int -- trace row rewritten
conv_func_ids: np.ndarray # (n_conv_steps,) int -- kernel function registry id
conv_param_indptr: np.ndarray # (n_conv_steps + 1,) int -- CSR row pointers
conv_param_rows: np.ndarray # (total_conv_params,) int -- kernel param trace rows
conv_support_indptr: (
np.ndarray
) # (n_conv_steps + 1,) int -- CSR into support values
conv_support_values: (
np.ndarray
) # (total_support,) float -- kernel time axis samples
# --- Profile-varying parameter groups (fixed aux_axis shape) ---
n_aux: int
aux_axis: np.ndarray # (n_aux,)
n_profile_samples: int
profile_sample_base_rows: np.ndarray # (n_profile_samples,) int
profile_sample_component_indptr: np.ndarray # (n_profile_samples + 1,) int
profile_component_func_ids: np.ndarray # (n_profile_components,) int
profile_component_param_indptr: np.ndarray # (n_profile_components + 1,) int
profile_component_param_rows: np.ndarray # (total_profile_component_params,) int
n_profile_exprs: int
profile_expr_programs: list[ExprProgram]
# --- Scheduled component ops ---
n_ops: int
op_schedule: np.ndarray # (n_ops,) int
op_kinds: np.ndarray # (n_ops,) OpKind int codes
op_param_indptr: np.ndarray # (n_ops + 1,) int -- CSR row pointers
op_param_source_kinds: np.ndarray # (total_op_params,) ParamSourceKind int codes
op_param_indices: np.ndarray # (total_op_params,) int -- row/group indices
op_needs_spectrum: np.ndarray # (n_ops,) bool
op_is_pre_spectrum: np.ndarray # (n_ops,) bool
op_is_profiled: np.ndarray # (n_ops,) bool
op_is_constant: np.ndarray # (n_ops,) bool
cached_result: np.ndarray # (n_time, n_energy)
cached_peak_sum: np.ndarray # (n_time, n_energy)
#
#
[docs]
@dataclass(frozen=True)
class ScheduledPlan1D:
"""Compiled 1D execution schedule for ENERGY_1D models.
Simpler than ``ScheduledPlan2D``: no time axis, no dynamics, no
trace matrix. Parameters are stored as a flat ``(n_params,)``
scalar vector.
"""
energy: np.ndarray # (n_energy,)
n_params: int
# --- Parameter mapping ---
param_values_init: np.ndarray # (n_params,) initial scalar values
opt_indices: np.ndarray # (n_opt,) int -- indices into param_values
opt_param_names: list[str] # (n_opt,) canonical optimizer param names
# --- Expression evaluation (topological order, no dynamics) ---
n_expressions: int
expr_target_indices: np.ndarray # (n_expressions,) int
expr_programs: list[ExprProgram]
# --- Profile-varying parameter groups (fixed aux_axis shape) ---
n_aux: int
aux_axis: np.ndarray # (n_aux,)
n_profile_samples: int
profile_sample_base_indices: np.ndarray # (n_profile_samples,) int
profile_sample_component_indptr: np.ndarray # (n_profile_samples + 1,) int
profile_component_func_ids: np.ndarray # (n_profile_components,) int
profile_component_param_indptr: np.ndarray # (n_profile_components + 1,) int
profile_component_param_indices: np.ndarray # (total_profile_component_params,) int
n_profile_exprs: int
profile_expr_programs: list[ExprProgram]
# --- Scheduled component ops ---
n_ops: int
op_kinds: np.ndarray # (n_ops,) OpKind int codes
op_param_indptr: np.ndarray # (n_ops + 1,) int -- CSR row pointers
op_param_source_kinds: np.ndarray # (total_op_params,) ParamSourceKind int codes
op_param_indices: np.ndarray # (total_op_params,) int -- source indices
op_needs_spectrum: np.ndarray # (n_ops,) bool
op_is_pre_spectrum: np.ndarray # (n_ops,) bool
op_is_profiled: np.ndarray # (n_ops,) bool
op_is_constant: np.ndarray # (n_ops,) bool
cached_result: np.ndarray # (n_energy,)
cached_peak_sum: np.ndarray # (n_energy,)
_PACKAGE_SHORT_NAMES: dict[str, str] = {
"fcts_energy": "energy",
"fcts_time": "time",
"fcts_profile": "profile",
}
#
def _package_short_name(comp: Component) -> str:
"""Return ``"energy"``, ``"time"``, or ``"profile"`` for a component."""
mod_name = comp.package_name.rsplit(".", maxsplit=1)[-1]
return _PACKAGE_SHORT_NAMES.get(mod_name, mod_name)
#
def _par_initial_value(par: Par) -> float:
"""Extract the current scalar value from a Par's lmfit_par."""
vals = list(par.lmfit_par.valuesdict().values())
return float(vals[0]) if vals else 0.0
#
def _par_bounds(par: Par) -> tuple[float, float] | None:
"""Extract bounds from a Par's lmfit_par, or None."""
for p in par.lmfit_par.values():
mn = p.min if p.min is not None else -np.inf
mx = p.max if p.max is not None else np.inf
return (float(mn), float(mx))
return None
#
def _is_expression_par(par: Par) -> bool:
"""True if this Par is defined by an expression string."""
return len(par.info) == 1 and isinstance(par.info[0], str)
#
def _par_expression_string(par: Par) -> str | None:
"""Return the expression string to use for graph wiring.
Prefer the lmfit expression when available because Dynamics models
auto-prefix references there (e.g. ``expFun_01_A`` ->
``GLP_01_A_expFun_01_A`` or ``parTEST_expFun_01_A``). Fall back to
the raw ``par.info[0]`` expression for plain energy-model expressions.
"""
if not _is_expression_par(par):
return None
for lmfit_par in par.lmfit_par.values():
if lmfit_par.expr:
return str(lmfit_par.expr)
return str(par.info[0])
#
def _extract_expression_references(expr_string: str) -> list[str]:
"""Extract identifier-like references from an expression string.
This is intentionally generic: it collects all Python-identifier-like
tokens in lexical order and leaves semantic filtering to the caller
(for example, by checking whether the token is present in
``resolved_param`` or ``node_by_name``).
"""
pattern = r"\b[A-Za-z_][A-Za-z0-9_]*\b"
refs: list[str] = []
seen: set[str] = set()
for token in re.findall(pattern, expr_string):
if token in seen:
continue
refs.append(token)
seen.add(token)
return refs
#
def _component_param_names(comp: Component) -> list[str]:
"""Return function parameter names, excluding axis arg and ``spectrum``."""
args = comp.fct_args[1:] # drop first (axis: x or t)
if args and args[-1] == "spectrum":
args = args[:-1]
return args
#
def _par_is_vary(par: Par) -> bool:
"""True if the Par has ``vary=True`` in its lmfit parameter."""
return bool(
par.lmfit_par.valuesdict() and any(p.vary for p in par.lmfit_par.values())
)
#
#
class _GraphBuilder:
"""Mutable state used while building a GraphIR from a Model."""
def __init__(self) -> None:
self.nodes: list[GraphNode] = []
self.edges: list[GraphEdge] = []
self.node_by_name: dict[str, int] = {}
self._next_id: int = 0
self._source_order: int = 0
self._removed: set[int] = set()
#
def add_node(self, kind: NodeKind, name: str, **kwargs) -> int:
"""Create a node, assign it an id and source_order, return id."""
nid = self._next_id
self._next_id += 1
order = self._source_order
self._source_order += 1
node = GraphNode(id=nid, kind=kind, name=name, source_order=order, **kwargs)
self.nodes.append(node)
self.node_by_name[name] = nid
return nid
#
def add_edge(
self, source: int, target: int, kind: EdgeKind, *, position: int | None = None
) -> None:
"""Append an edge connecting source -> target with the given kind."""
self.edges.append(
GraphEdge(source=source, target=target, kind=kind, position=position)
)
#
def mark_removed(self, nid: int) -> None:
"""Mark a node for removal during finalization."""
self._removed.add(nid)
#
def finalize(self) -> tuple[list[GraphNode], list[GraphEdge], dict[str, int]]:
"""Remove marked nodes/edges and re-index so node id == list position."""
if not self._removed:
return self.nodes, self.edges, dict(self.node_by_name)
kept_nodes = [n for n in self.nodes if n.id not in self._removed]
kept_edges = [
e
for e in self.edges
if e.source not in self._removed and e.target not in self._removed
]
# Re-index: old id -> new dense id
id_map = {old.id: new_id for new_id, old in enumerate(kept_nodes)}
for node in kept_nodes:
node.id = id_map[node.id]
for edge in kept_edges:
edge.source = id_map[edge.source]
edge.target = id_map[edge.target]
node_by_name = {n.name: n.id for n in kept_nodes}
return kept_nodes, kept_edges, node_by_name
#
[docs]
def build_graph(model: Model) -> GraphIR:
"""Walk the OOP Model tree and emit a GraphIR.
Parameters
----------
model : Model
A fully-constructed model (components, pars, and any dynamics /
profile models already attached).
Returns
-------
GraphIR
The semantic DAG representation.
"""
b = _GraphBuilder()
# ----- determine domain -----
has_energy = model.energy is not None
has_time = model.time is not None
if has_energy and has_time:
domain = DomainKind.ENERGY_TIME_2D
elif has_time:
domain = DomainKind.TIME_1D
else:
domain = DomainKind.ENERGY_1D
# ------------------------------------------------------------------ #
# 1. Create parameter nodes for every component #
# ------------------------------------------------------------------ #
# Maps par.name -> node id of the *resolved* value (which might be a
# PARAM_PLUS_TRACE node if time-dependent, or an EXPRESSION node).
resolved_param: dict[str, int] = {}
for comp in model.components:
for par in comp.pars:
_emit_par_nodes(b, par, resolved_param)
# ------------------------------------------------------------------ #
# 2. Create component nodes and wire PARAM_INPUT edges #
# ------------------------------------------------------------------ #
# Collect component node ids for combination wiring.
# comp_nodes: list of (comp, node_id, is_spectrum_fed)
comp_nodes: list[tuple[Component, int, bool]] = []
for comp in model.components:
if comp.comp_type == "none":
continue
pkg_name = _package_short_name(comp)
is_shirley = comp.fct_str == "Shirley"
# Convolution components
if comp.comp_type == "conv":
nid = b.add_node(
NodeKind.CONVOLUTION,
comp.comp_name,
function_name=comp.fct_str,
package=pkg_name,
)
# Store kernel-related arrays if available
if comp.time is not None:
b.nodes[nid].arrays["kernel_time"] = comp.time
elif is_shirley:
nid = b.add_node(
NodeKind.SPECTRUM_FED_OP,
comp.comp_name,
function_name=comp.fct_str,
package=pkg_name,
)
else:
nid = b.add_node(
NodeKind.COMPONENT_EVAL,
comp.comp_name,
function_name=comp.fct_str,
package=pkg_name,
)
# Wire PARAM_INPUT edges from resolved params to component
param_names = _component_param_names(comp)
for pos, _pname in enumerate(param_names):
par = comp.pars[pos]
src = resolved_param[par.name]
b.add_edge(src, nid, EdgeKind.PARAM_INPUT, position=pos)
comp_nodes.append((comp, nid, is_shirley))
# ------------------------------------------------------------------ #
# 3. Emit PROFILE_*, SUBCYCLE_* nodes for components that use them #
# ------------------------------------------------------------------ #
# profile_samples maps par_name -> [sample_nid_0, ..., sample_nid_n].
# Populated by p_vary components, consumed by expr_refs_profile_dep
# components. This preserves per-sample context across components so
# that expression params referencing a profiled param get per-sample
# EXPRESSION nodes (not a single averaged replacement).
profile_samples: dict[str, list[int]] = {}
for i, (comp, nid, is_shirley) in enumerate(comp_nodes):
new_nid = _emit_profile_nodes(b, comp, nid, resolved_param, profile_samples)
if new_nid != nid:
b.mark_removed(nid)
comp_nodes[i] = (comp, new_nid, is_shirley)
for comp in model.components:
if comp.comp_type == "none":
continue
_emit_subcycle_nodes(b, comp, resolved_param)
# ------------------------------------------------------------------ #
# 4. Create expression edges #
# ------------------------------------------------------------------ #
# Runs after profile nodes. Expression params that are
# expr_refs_profile_dep are already fully wired per-sample inside
# _emit_profile_nodes, so skip them here.
for comp in model.components:
for par in comp.pars:
if not _is_expression_par(par):
continue
if par.expr_refs_profile_dep:
continue # handled per-sample in _emit_profile_nodes
expr_nid = b.node_by_name[par.name]
expr_str = _par_expression_string(par)
if expr_str is None:
continue
refs = _extract_expression_references(expr_str)
for ref_name in refs:
if ref_name in resolved_param:
ref_nid = resolved_param[ref_name]
b.add_edge(ref_nid, expr_nid, EdgeKind.EXPR_REF)
# ------------------------------------------------------------------ #
# 5. Create SUM / combination nodes #
# ------------------------------------------------------------------ #
_emit_combination_nodes(b, comp_nodes)
nodes, edges, node_by_name = b.finalize()
return GraphIR(
nodes=nodes,
edges=edges,
domain=domain,
energy=model.energy,
time=model.time,
node_by_name=node_by_name,
)
#
def _emit_par_nodes(
b: _GraphBuilder,
par: Par,
resolved_param: dict[str, int],
) -> None:
"""Create graph nodes for one Par (and its dynamics/profile subgraph)."""
# Expression parameter
if _is_expression_par(par):
nid = b.add_node(
NodeKind.EXPRESSION,
par.name,
expr_string=par.info[0],
value=_par_initial_value(par),
)
resolved_param[par.name] = nid
return
# Base parameter node
if _par_is_vary(par):
base_nid = b.add_node(
NodeKind.OPT_PARAM,
par.name,
value=_par_initial_value(par),
vary=True,
bounds=_par_bounds(par),
)
else:
base_nid = b.add_node(
NodeKind.STATIC_PARAM,
par.name,
value=_par_initial_value(par),
)
# Default: resolved value is the base node itself
resolved_param[par.name] = base_nid
# Time-dependent parameter (Dynamics subgraph)
if par.t_vary and par.t_model is not None:
_emit_dynamics_subgraph(b, par, base_nid, resolved_param)
#
def _emit_dynamics_subgraph(
b: _GraphBuilder,
par: Par,
base_nid: int,
resolved_param: dict[str, int],
) -> None:
"""Emit DYNAMICS_TRACE + PARAM_PLUS_TRACE nodes for a time-dep par."""
t_model = par.t_model
assert t_model is not None
# Create parameter nodes for each dynamics component's parameters.
# Dynamics pars can be expressions (e.g. multi-cycle subcycle models
# where expFun_02_A = "-expFun_01_A").
dyn_param_nids: list[list[int]] = []
dyn_expr_pars: list[tuple[Par, int]] = [] # (par, node_id) for deferred wiring
for dyn_comp in t_model.components:
if dyn_comp.comp_type == "none":
dyn_param_nids.append([])
continue
comp_nids: list[int] = []
for dyn_par in dyn_comp.pars:
if _is_expression_par(dyn_par):
# Store the canonical (lmfit-prefixed) expression, not
# the raw YAML text, so expr_string matches the EXPR_REF
# edges wired below.
canonical_expr = _par_expression_string(dyn_par) or dyn_par.info[0]
dnid = b.add_node(
NodeKind.EXPRESSION,
dyn_par.name,
expr_string=canonical_expr,
value=_par_initial_value(dyn_par),
)
dyn_expr_pars.append((dyn_par, dnid))
elif _par_is_vary(dyn_par):
dnid = b.add_node(
NodeKind.OPT_PARAM,
dyn_par.name,
value=_par_initial_value(dyn_par),
vary=True,
bounds=_par_bounds(dyn_par),
)
else:
dnid = b.add_node(
NodeKind.STATIC_PARAM,
dyn_par.name,
value=_par_initial_value(dyn_par),
)
comp_nids.append(dnid)
dyn_param_nids.append(comp_nids)
# Wire EXPR_REF edges for dynamics expression pars.
# Dynamics expressions may be auto-prefixed by lmfit (e.g.
# "expFun_01_A" -> "GLP_01_A_expFun_01_A" or "parTEST_expFun_01_A").
# Use the canonical expression string and then match identifier refs
# against graph node names.
for dyn_par, expr_nid in dyn_expr_pars:
expr_str = _par_expression_string(dyn_par)
if expr_str is None:
continue
refs = _extract_expression_references(expr_str)
for ref_name in refs:
if ref_name in b.node_by_name:
b.add_edge(b.node_by_name[ref_name], expr_nid, EdgeKind.EXPR_REF)
# Create DYNAMICS_TRACE or CONVOLUTION nodes per dynamics component.
# Convolution components (gaussCONV, etc.) are CONVOLUTION nodes that
# wrap the resolved trace (conv(trace, kernel)), not addends.
trace_nids: list[int] = []
conv_nids: list[int] = []
for i, dyn_comp in enumerate(t_model.components):
if dyn_comp.comp_type == "none":
continue
node_name = f"{par.name}_dynamics"
if len(t_model.components) > 1:
node_name = f"{par.name}_{dyn_comp.comp_name}_dynamics"
if dyn_comp.comp_type == "conv":
nid = b.add_node(
NodeKind.CONVOLUTION,
node_name,
function_name=dyn_comp.fct_str,
package="time",
)
if dyn_comp.time is not None:
b.nodes[nid].arrays["kernel_time"] = dyn_comp.time
# Wire dynamics params -> conv node
for pos, dnid in enumerate(dyn_param_nids[i]):
b.add_edge(dnid, nid, EdgeKind.PARAM_INPUT, position=pos)
conv_nids.append(nid)
else:
nid = b.add_node(
NodeKind.DYNAMICS_TRACE,
node_name,
function_name=dyn_comp.fct_str,
package="time",
)
# Wire dynamics params -> trace node
for pos, dnid in enumerate(dyn_param_nids[i]):
b.add_edge(dnid, nid, EdgeKind.PARAM_INPUT, position=pos)
trace_nids.append(nid)
# Create PARAM_PLUS_TRACE node (base + sum of traces)
resolved_name = f"{par.name}_resolved"
resolved_nid = b.add_node(
NodeKind.PARAM_PLUS_TRACE,
resolved_name,
)
b.add_edge(base_nid, resolved_nid, EdgeKind.BASE_INPUT)
for trace_nid in trace_nids:
b.add_edge(trace_nid, resolved_nid, EdgeKind.TRACE_INPUT)
# Convolution nodes wrap the resolved trace: conv(resolved, kernel).
# Each CONVOLUTION takes TRACE_INPUT from the current resolved node
# and produces the new resolved value.
for conv_nid in conv_nids:
b.add_edge(resolved_nid, conv_nid, EdgeKind.TRACE_INPUT)
resolved_nid = conv_nid
resolved_param[par.name] = resolved_nid
#
def _rewire_param_input(
b: _GraphBuilder, target: int, *, old_source: int, new_source: int
) -> None:
"""Replace the source of a PARAM_INPUT edge targeting *target*."""
for edge in b.edges:
if (
edge.target == target
and edge.kind == EdgeKind.PARAM_INPUT
and edge.source == old_source
):
edge.source = new_source
return
#
def _rewire_trace_input(
b: _GraphBuilder, target: int, *, old_source: int, new_source: int
) -> None:
"""Replace the source of a TRACE_INPUT edge targeting *target*."""
for edge in b.edges:
if (
edge.target == target
and edge.kind == EdgeKind.TRACE_INPUT
and edge.source == old_source
):
edge.source = new_source
return
#
def _emit_profile_nodes(
b: _GraphBuilder,
comp: Component,
comp_nid: int,
resolved_param: dict[str, int],
profile_samples: dict[str, list[int]],
) -> int:
"""Emit per-sample evaluation subgraph for profiled components.
A component needs profile treatment when any of its params is
``p_vary`` (directly profiled) or ``expr_refs_profile_dep`` (an
expression referencing a profiled param on another component).
The interpreter evaluates the full component at each aux-axis point
and averages the resulting traces: ``mean_i(f(p_0, ..., p_i, ...))``.
Graph structure per aux point *i*:
- For ``p_vary`` params: a PROFILE_SAMPLE node computes the
parameter value at aux point *i* from the base value and the
profile function.
- For ``expr_refs_profile_dep`` params: a per-sample EXPRESSION
node evaluates the expression using the PROFILE_SAMPLE of the
referenced profiled param at the same aux point.
- A per-sample COMPONENT_EVAL evaluates the component function
with these per-sample inputs.
After all aux points:
- A component-level PROFILE_AVERAGE averages the per-sample
COMPONENT_EVAL traces and replaces the original component in
the combination graph.
``profile_samples`` is shared across components so that
``expr_refs_profile_dep`` params on one component can reference
the PROFILE_SAMPLE nodes created for a ``p_vary`` param on
a different component.
Returns
-------
int
Node id that replaces *comp_nid* in the combination graph,
or *comp_nid* unchanged if no profile treatment is needed.
"""
p_vary_pars = [p for p in comp.pars if p.p_vary and p.p_model is not None]
expr_dep_pars = [p for p in comp.pars if p.expr_refs_profile_dep]
if not p_vary_pars and not expr_dep_pars:
return comp_nid
# --- Determine aux_axis length ---
n_aux: int | None = None
if p_vary_pars:
aux_axis = p_vary_pars[0].p_model.aux_axis # type: ignore[union-attr]
if aux_axis is None:
return comp_nid
n_aux = len(aux_axis)
if n_aux is None:
# expr_dep_pars only — infer n_aux from an already-populated
# profile_samples entry.
for ref_name in _expr_dep_profile_refs(expr_dep_pars):
if ref_name in profile_samples:
n_aux = len(profile_samples[ref_name])
aux_axis = None # not needed for expr-dep-only components
break
if n_aux is None:
return comp_nid
pkg_name = _package_short_name(comp)
# --- Profile function parameter nodes for p_vary pars ---
prof_param_nids_by_par: dict[str, list[int]] = {}
for par in p_vary_pars:
p_model = par.p_model
assert p_model is not None
prof_param_nids: list[int] = []
prof_resolved: dict[str, int] = {}
for prof_comp in p_model.components:
if prof_comp.comp_type == "none":
continue
for prof_par in prof_comp.pars:
_emit_par_nodes(b, prof_par, prof_resolved)
prof_param_nids.append(prof_resolved[prof_par.name])
prof_param_nids_by_par[par.name] = prof_param_nids
# --- Per-sample nodes ---
sample_comp_nids: list[int] = []
for aux_i in range(n_aux):
# PROFILE_SAMPLE for each p_vary param
sample_nids: dict[str, int] = {} # par.name -> sample nid
for par in p_vary_pars:
base_nid = resolved_param[par.name]
p_aux = p_vary_pars[0].p_model.aux_axis # type: ignore[union-attr]
sample_nid = b.add_node(
NodeKind.PROFILE_SAMPLE,
f"{par.name}_profile_sample_{aux_i}",
arrays={"aux_axis": p_aux},
)
b.add_edge(base_nid, sample_nid, EdgeKind.PARAM_INPUT, position=0)
for pi, pnid in enumerate(prof_param_nids_by_par[par.name]):
b.add_edge(pnid, sample_nid, EdgeKind.PARAM_INPUT, position=pi + 1)
sample_nids[par.name] = sample_nid
profile_samples.setdefault(par.name, []).append(sample_nid)
# Per-sample EXPRESSION for each expr_refs_profile_dep param
expr_sample_nids: dict[str, int] = {} # par.name -> expr nid
for par in expr_dep_pars:
expr_str = _par_expression_string(par)
if expr_str is None:
continue
expr_nid = b.add_node(
NodeKind.EXPRESSION,
f"{par.name}_profile_expr_{aux_i}",
expr_string=expr_str,
value=_par_initial_value(par),
)
# Wire EXPR_REF to the per-sample PROFILE_SAMPLE of the
# referenced profiled param (not the averaged value).
refs = _extract_expression_references(expr_str)
for ref_name in refs:
if ref_name in profile_samples:
b.add_edge(
profile_samples[ref_name][aux_i],
expr_nid,
EdgeKind.EXPR_REF,
)
elif ref_name in resolved_param:
b.add_edge(
resolved_param[ref_name],
expr_nid,
EdgeKind.EXPR_REF,
)
expr_sample_nids[par.name] = expr_nid
# Per-sample COMPONENT_EVAL
sample_eval_nid = b.add_node(
NodeKind.COMPONENT_EVAL,
f"{comp.comp_name}_sample_{aux_i}",
function_name=comp.fct_str,
package=pkg_name,
)
# Wire params into the sample component eval
param_names = _component_param_names(comp)
for pos, _pname in enumerate(param_names):
par = comp.pars[pos]
if par.name in sample_nids:
src = sample_nids[par.name]
elif par.name in expr_sample_nids:
src = expr_sample_nids[par.name]
else:
src = resolved_param[par.name]
b.add_edge(src, sample_eval_nid, EdgeKind.PARAM_INPUT, position=pos)
sample_comp_nids.append(sample_eval_nid)
# --- Component-level PROFILE_AVERAGE over sample traces ---
comp_avg_nid = b.add_node(
NodeKind.PROFILE_AVERAGE,
f"{comp.comp_name}_profile_avg",
)
for sc_nid in sample_comp_nids:
b.add_edge(sc_nid, comp_avg_nid, EdgeKind.ADDEND)
return comp_avg_nid
#
def _expr_dep_profile_refs(expr_dep_pars: list[Par]) -> list[str]:
"""Collect profiled param names referenced by expr_refs_profile_dep pars."""
refs: list[str] = []
seen: set[str] = set()
for par in expr_dep_pars:
expr_str = _par_expression_string(par)
if expr_str is None:
continue
for ref in _extract_expression_references(expr_str):
if ref not in seen:
refs.append(ref)
seen.add(ref)
return refs
#
def _emit_subcycle_nodes(
b: _GraphBuilder,
comp: Component,
resolved_param: dict[str, int],
) -> None:
"""Emit SUBCYCLE_REMAP and SUBCYCLE_MASK nodes for subcycle dynamics.
SUBCYCLE_REMAP feeds into the DYNAMICS_TRACE (provides the remapped
time axis). SUBCYCLE_MASK consumes the DYNAMICS_TRACE output
(zeroes inactive regions).
"""
# Standalone TIME_1D dynamics model: subcycle info lives directly on the
# time component itself rather than on a parent Par.t_model.
comp_nid = b.node_by_name.get(comp.comp_name)
if (
comp_nid is not None
and comp.subcycle != 0
and comp.time_norm is not None
and comp.time_n_sub is not None
):
remap_nid = b.add_node(
NodeKind.SUBCYCLE_REMAP,
f"{comp.comp_name}_remap",
arrays={"time_norm": comp.time_norm},
)
b.add_edge(remap_nid, comp_nid, EdgeKind.TRACE_INPUT)
mask_nid = b.add_node(
NodeKind.SUBCYCLE_MASK,
f"{comp.comp_name}_mask",
arrays={"time_n_sub": comp.time_n_sub},
)
b.add_edge(comp_nid, mask_nid, EdgeKind.TRACE_INPUT)
for par in comp.pars:
if not par.t_vary or par.t_model is None:
continue
t_model = par.t_model
for dyn_comp in t_model.components:
if dyn_comp.subcycle == 0:
continue
if dyn_comp.time_norm is None or dyn_comp.time_n_sub is None:
continue
# Find the DYNAMICS_TRACE node for this dynamics component
trace_name = f"{par.name}_dynamics"
if len(t_model.components) > 1:
trace_name = f"{par.name}_{dyn_comp.comp_name}_dynamics"
trace_nid = b.node_by_name.get(trace_name)
if trace_nid is None:
continue
# SUBCYCLE_REMAP before the dynamics trace
remap_nid = b.add_node(
NodeKind.SUBCYCLE_REMAP,
f"{par.name}_{dyn_comp.comp_name}_remap",
arrays={"time_norm": dyn_comp.time_norm},
)
b.add_edge(remap_nid, trace_nid, EdgeKind.TRACE_INPUT)
# SUBCYCLE_MASK after the dynamics trace
mask_nid = b.add_node(
NodeKind.SUBCYCLE_MASK,
f"{par.name}_{dyn_comp.comp_name}_mask",
arrays={"time_n_sub": dyn_comp.time_n_sub},
)
b.add_edge(trace_nid, mask_nid, EdgeKind.TRACE_INPUT)
# Rewire PARAM_PLUS_TRACE to consume the masked trace
# instead of the raw DYNAMICS_TRACE.
resolved_nid = resolved_param.get(par.name)
if resolved_nid is not None:
_rewire_trace_input(
b, resolved_nid, old_source=trace_nid, new_source=mask_nid
)
#
def _emit_combination_nodes(
b: _GraphBuilder,
comp_nodes: list[tuple[Component, int, bool]],
) -> None:
"""Create SUM nodes that mirror the model's LIFO combine semantics.
Classification:
- **peaks**: ``comp_type == "add"`` — feed ``peak_sum``
- **backgrounds**: ``comp_type == "back"`` but *not* spectrum-fed
(Offset, LinBack) — feed ``total`` directly, *not* ``peak_sum``
- **spectrum-fed**: Shirley — receives ``SPECTRUM_INPUT`` from
``peak_sum``, feeds ``total``
- **convolution**: ``comp_type == "conv"`` — receives ``ADDEND``
from ``peak_sum``, feeds ``total``
This matches the spec example (lowered_evaluator.md lines 298-305):
only peaks contribute to ``peak_sum``; Offset/LinBack are added
at the ``total`` level.
"""
if not comp_nodes:
return
# Classify nodes
peaks: list[int] = [] # comp_type == "add"
backgrounds: list[int] = [] # Offset, LinBack (comp_type == "back", not Shirley)
spectrum_fed: list[int] = [] # Shirley
convolution: list[int] = [] # conv components
for comp, nid, is_shirley in comp_nodes:
if comp.comp_type == "conv":
convolution.append(nid)
elif is_shirley:
spectrum_fed.append(nid)
elif comp.comp_type == "add":
peaks.append(nid)
else:
# comp_type == "back" but not Shirley -> Offset, LinBack
backgrounds.append(nid)
# peak_sum: accumulates only peak (comp_type == "add") components
peak_sum_nid: int | None = None
if peaks:
peak_sum_nid = b.add_node(NodeKind.SUM, "peak_sum")
for nid in peaks:
b.add_edge(nid, peak_sum_nid, EdgeKind.ADDEND)
# Wire SPECTRUM_INPUT edges from peak_sum to spectrum-fed ops
if peak_sum_nid is not None:
for nid in spectrum_fed:
b.add_edge(peak_sum_nid, nid, EdgeKind.SPECTRUM_INPUT)
# Convolution nodes get ADDEND edge from peak_sum (signal to convolve)
if peak_sum_nid is not None:
for nid in convolution:
b.add_edge(peak_sum_nid, nid, EdgeKind.ADDEND)
# total: final sum of everything
# Collect all addends for total. Use peak_sum (not individual peaks)
# so the graph reflects the semantic grouping.
total_addends: list[int] = []
if peak_sum_nid is not None:
total_addends.append(peak_sum_nid)
total_addends.extend(backgrounds)
total_addends.extend(spectrum_fed)
total_addends.extend(convolution)
if len(total_addends) > 1:
total_nid = b.add_node(NodeKind.SUM, "total")
for nid in total_addends:
b.add_edge(nid, total_nid, EdgeKind.ADDEND)
# ---------------------------------------------------------------------------
# can_lower_2d
# ---------------------------------------------------------------------------
_LOWERABLE_2D_FUNCTIONS: frozenset[str] = frozenset(
{
"Gauss",
"GaussAsym",
"Lorentz",
"Voigt",
"GLS",
"GLP",
"DS",
"Offset",
"LinBack",
"Shirley",
}
)
# Base set of non-lowerable node kinds shared across backends.
# Backend-specific sets below carve out nodes as each backend gains
# support (e.g. profile nodes for 1D and 2D).
_NON_LOWERABLE_NODE_KINDS_BASE: frozenset[NodeKind] = frozenset(
{
NodeKind.CONVOLUTION,
NodeKind.PROFILE_SAMPLE,
NodeKind.PROFILE_AVERAGE,
NodeKind.SUBCYCLE_MASK,
NodeKind.SUBCYCLE_REMAP,
}
)
# 2D backend: profiles, resolved-trace time-domain convolution, and
# subcycle dynamics are compilable. Convolution lowerability is
# additionally gated per-node by ``_is_lowerable_convolution_2d``.
# Subcycle substeps are compiled away into dyn_sub_time_axes /
# dyn_sub_masks schedule arrays.
_NON_LOWERABLE_2D_NODE_KINDS: frozenset[NodeKind] = (
_NON_LOWERABLE_NODE_KINDS_BASE
- frozenset(
{
NodeKind.PROFILE_SAMPLE,
NodeKind.PROFILE_AVERAGE,
NodeKind.CONVOLUTION,
NodeKind.SUBCYCLE_MASK,
NodeKind.SUBCYCLE_REMAP,
}
)
)
#
def _is_lowerable_convolution_2d(node: GraphNode, graph: GraphIR) -> bool:
"""Return True if a CONVOLUTION node can be lowered by the 2D backend.
Lowering contract -- all of:
1. Time-domain kernel (``package == "time"``) with a registered
kernel function and a ``kernel_time`` array populated at
graph-build time.
2. Resolved-trace shape: the node has exactly one ``TRACE_INPUT``
ancestor, and walking that chain terminates at a
``PARAM_PLUS_TRACE``.
The structural check (item 2) is the same walk the scheduler
performs in :func:`_resolve_convolution_target_row`, kept in sync so
that unsupported conv flavours (spectrum-level ``comp_type="conv"``
components that receive ``ADDEND`` from ``peak_sum``, externally
produced graphs with unusual conv topology, etc.) fall back to MCP
cleanly instead of raising inside ``schedule_2d``.
"""
if (
node.kind != NodeKind.CONVOLUTION
or node.package != "time"
or node.function_name not in _FUNCTION_NAME_TO_CONV_KERNEL
or "kernel_time" not in node.arrays
):
return False
id_to_node = {n.id: n for n in graph.nodes}
current = node
# Bounded walk: guards against malformed chains and cycles.
for _ in range(len(id_to_node)):
trace_parents = [
e
for e in graph.edges
if e.target == current.id and e.kind == EdgeKind.TRACE_INPUT
]
if len(trace_parents) != 1:
return False
parent = id_to_node[trace_parents[0].source]
if parent.kind == NodeKind.PARAM_PLUS_TRACE:
return True
if parent.kind != NodeKind.CONVOLUTION:
return False
current = parent
return False
#
def _is_lowerable_subcycle_2d(node: GraphNode, graph: GraphIR) -> bool:
"""Return True if a SUBCYCLE_REMAP / MASK node meets the lowering contract.
The scheduler replaces per-substep ``graph.time`` / all-ones defaults
with the payload carried by these nodes; once ``can_lower_2d``
vouches, ``schedule_2d`` reads the arrays unconditionally. All of:
1. The required payload array is present on the node and is a 1D
``float64``-compatible ndarray of length ``len(graph.time)``.
2. Canonical topology:
SUBCYCLE_REMAP -> DYNAMICS_TRACE (TRACE_INPUT)
DYNAMICS_TRACE -> SUBCYCLE_MASK (TRACE_INPUT, if present)
SUBCYCLE_MASK -> PARAM_PLUS_TRACE (TRACE_INPUT)
Anything off-pattern (multiple consumers, chain doesn't end at a
PPT, etc.) is rejected so malformed graphs fall back to MCP
cleanly instead of silently producing wrong output.
"""
if node.kind not in (NodeKind.SUBCYCLE_REMAP, NodeKind.SUBCYCLE_MASK):
return False
if graph.time is None:
return False
n_time = len(graph.time)
key = "time_norm" if node.kind == NodeKind.SUBCYCLE_REMAP else "time_n_sub"
arr = node.arrays.get(key)
if arr is None or not isinstance(arr, np.ndarray):
return False
if arr.ndim != 1 or arr.shape[0] != n_time:
return False
id_to_node = {n.id: n for n in graph.nodes}
if node.kind == NodeKind.SUBCYCLE_REMAP:
outs = [
e
for e in graph.edges
if e.source == node.id and e.kind == EdgeKind.TRACE_INPUT
]
if len(outs) != 1:
return False
return id_to_node[outs[0].target].kind == NodeKind.DYNAMICS_TRACE
# SUBCYCLE_MASK: one TRACE_INPUT in from DYNAMICS_TRACE, one out to PPT.
ins = [
e for e in graph.edges if e.target == node.id and e.kind == EdgeKind.TRACE_INPUT
]
outs = [
e for e in graph.edges if e.source == node.id and e.kind == EdgeKind.TRACE_INPUT
]
if len(ins) != 1 or len(outs) != 1:
return False
return (
id_to_node[ins[0].source].kind == NodeKind.DYNAMICS_TRACE
and id_to_node[outs[0].target].kind == NodeKind.PARAM_PLUS_TRACE
)
#
[docs]
def can_lower_2d(graph: GraphIR) -> bool:
"""Check whether the 2D NumPy backend can compile this graph.
Parameters
----------
graph : GraphIR
The model graph to check.
Returns
-------
bool
True if ``schedule_2d`` can compile this graph.
"""
if graph.domain != DomainKind.ENERGY_TIME_2D:
return False
for node in graph.nodes:
# Reject future node types not yet compilable
if node.kind in _NON_LOWERABLE_2D_NODE_KINDS:
return False
# Convolution nodes are per-node gated: only time-domain kernels
# with a registered kernel function, frozen support, and the
# resolved-trace topology the scheduler expects are lowered.
if node.kind == NodeKind.CONVOLUTION:
if not _is_lowerable_convolution_2d(node, graph):
return False
# Subcycle nodes are per-node gated: required payload arrays and
# canonical REMAP -> DYNAMICS_TRACE -> MASK? -> PPT topology.
if node.kind in (NodeKind.SUBCYCLE_REMAP, NodeKind.SUBCYCLE_MASK):
if not _is_lowerable_subcycle_2d(node, graph):
return False
# Check component functions are supported
if node.kind in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP):
if node.function_name not in _LOWERABLE_2D_FUNCTIONS:
return False
# Check dynamics functions are supported
if node.kind == NodeKind.DYNAMICS_TRACE:
if node.function_name not in _FUNCTION_NAME_TO_DYN_FUNC:
return False
# Check profile nodes have required aux_axis array
if node.kind == NodeKind.PROFILE_SAMPLE:
if "aux_axis" not in node.arrays:
return False
# Check expressions are arithmetic-only (defer full AST check
# to the expression compiler; here just reject obvious non-starters)
if node.kind == NodeKind.EXPRESSION and node.expr_string is not None:
if not _is_arithmetic_expression(node.expr_string):
return False
return True
# Node kinds that are never valid in 1D energy models. 1D models have
# no time axis, so DYNAMICS_TRACE / PARAM_PLUS_TRACE should not appear.
# Start from the 2D blocklist so future unsupported node kinds propagate
# automatically, then carve out the profile nodes that 1D lowers
# explicitly.
_NON_LOWERABLE_1D_NODE_KINDS: frozenset[NodeKind] = (
_NON_LOWERABLE_NODE_KINDS_BASE
- frozenset({NodeKind.PROFILE_SAMPLE, NodeKind.PROFILE_AVERAGE})
| frozenset(
{
NodeKind.DYNAMICS_TRACE,
NodeKind.PARAM_PLUS_TRACE,
}
)
)
#
[docs]
def can_lower_1d(graph: GraphIR) -> bool:
"""Check whether the 1D NumPy backend can compile this graph.
Parameters
----------
graph : GraphIR
The model graph to check.
Returns
-------
bool
True if ``schedule_1d`` can compile this graph.
"""
if graph.domain != DomainKind.ENERGY_1D:
return False
for node in graph.nodes:
if node.kind in _NON_LOWERABLE_1D_NODE_KINDS:
return False
if node.kind in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP):
if node.function_name not in _LOWERABLE_2D_FUNCTIONS:
return False
if node.kind == NodeKind.PROFILE_SAMPLE:
if "aux_axis" not in node.arrays:
return False
if node.kind == NodeKind.EXPRESSION and node.expr_string is not None:
if not _is_arithmetic_expression(node.expr_string):
return False
return True
#
def _is_arithmetic_expression(expr_string: str) -> bool:
"""Return True if the expression uses only arithmetic ops and param refs.
Does a lightweight check via the ``ast`` module. Rejects function
calls, attribute access, subscripts, etc.
"""
import ast
try:
tree = ast.parse(expr_string, mode="eval")
except SyntaxError:
return False
for node in ast.walk(tree):
if isinstance(
node,
(
ast.Expression,
ast.BinOp,
ast.UnaryOp,
ast.Constant,
ast.Name,
ast.Add,
ast.Sub,
ast.Mult,
ast.Div,
ast.Pow,
ast.USub,
ast.UAdd,
ast.Load,
),
):
continue
return False
return True
# ---------------------------------------------------------------------------
# Symbolic expression compiler
# ---------------------------------------------------------------------------
#
#
[docs]
@dataclass(frozen=True)
class SymbolicRPN:
"""Symbolic RPN program with parameter references by name.
This is the *frontend* output of the expression compiler.
``schedule_2d`` binds names to trace-matrix row indices and
produces the final ``ExprProgram``.
Each instruction is a ``(ExprNodeKind, operand)`` pair:
- ``CONST``: operand is the float value itself
- ``PARAM_REF``: operand is the parameter name (str)
- Operators: operand is ``None``
"""
instructions: list[tuple[ExprNodeKind, float | str | None]]
referenced_names: list[str] # unique param names in order of first appearance
#
[docs]
def compile_expr_symbolic(expr_string: str) -> SymbolicRPN:
"""Parse an arithmetic expression string into symbolic RPN.
Uses the Python ``ast`` module to walk the expression tree and
emit a postfix instruction sequence. Parameter references are
kept as name strings; the scheduler resolves them to row indices.
Parameters
----------
expr_string : str
Arithmetic expression (e.g. ``"3/4*GLP_01_A"``).
Returns
-------
SymbolicRPN
The symbolic RPN program.
Raises
------
ValueError
If the expression contains unsupported AST nodes.
"""
import ast
tree = ast.parse(expr_string, mode="eval")
instructions: list[tuple[ExprNodeKind, float | str | None]] = []
names_seen: dict[str, None] = {} # ordered set via dict
_OP_MAP: dict[type, ExprNodeKind] = {
ast.Add: ExprNodeKind.ADD,
ast.Sub: ExprNodeKind.SUB,
ast.Mult: ExprNodeKind.MUL,
ast.Div: ExprNodeKind.DIV,
ast.Pow: ExprNodeKind.POW,
}
#
def _walk(node: ast.AST) -> None:
if isinstance(node, ast.Expression):
_walk(node.body)
elif isinstance(node, ast.Constant):
if not isinstance(node.value, (int, float)):
raise ValueError(
f"Unsupported constant {node.value!r}"
f" in expression: {expr_string!r}"
)
instructions.append((ExprNodeKind.CONST, float(node.value)))
elif isinstance(node, ast.Name):
instructions.append((ExprNodeKind.PARAM_REF, node.id))
names_seen.setdefault(node.id, None)
elif isinstance(node, ast.BinOp):
_walk(node.left)
_walk(node.right)
op_kind = _OP_MAP.get(type(node.op))
if op_kind is None:
raise ValueError(
f"Unsupported binary op {type(node.op).__name__!r}"
f" in expression: {expr_string!r}"
)
instructions.append((op_kind, None))
elif isinstance(node, ast.UnaryOp):
if isinstance(node.op, ast.USub):
_walk(node.operand)
instructions.append((ExprNodeKind.NEG, None))
elif isinstance(node.op, ast.UAdd):
_walk(node.operand)
# UAdd is a no-op
else:
raise ValueError(
f"Unsupported unary op {type(node.op).__name__!r}"
f" in expression: {expr_string!r}"
)
else:
raise ValueError(
f"Unsupported AST node {type(node).__name__!r}"
f" in expression: {expr_string!r}"
)
_walk(tree)
return SymbolicRPN(
instructions=instructions,
referenced_names=list(names_seen),
)
#
def _bind_expr_to_rows(
symbolic: SymbolicRPN,
name_to_row: dict[str, int],
) -> ExprProgram:
"""Convert a symbolic RPN program to a row-bound ExprProgram.
Parameters
----------
symbolic : SymbolicRPN
The symbolic RPN from ``compile_expr_symbolic``.
name_to_row : dict[str, int]
Maps parameter names to trace-matrix row indices.
Returns
-------
ExprProgram
Row-bound RPN program ready for the evaluator.
"""
flat: list[int] = []
for kind, operand in symbolic.instructions:
flat.append(int(kind))
if kind == ExprNodeKind.CONST:
assert isinstance(operand, (int, float))
flat.append(int(np.float64(operand).view(np.int64)))
elif kind == ExprNodeKind.PARAM_REF:
assert isinstance(operand, str)
flat.append(name_to_row[operand])
else:
flat.append(0)
return ExprProgram(instructions=np.array(flat, dtype=np.int64))
# ---------------------------------------------------------------------------
# schedule_2d
# ---------------------------------------------------------------------------
#
def _walk_convolution_to_param_plus_trace(
conv_node: GraphNode,
edges: list[GraphEdge],
id_to_node: dict[int, GraphNode],
) -> GraphNode:
"""Walk a conv chain back to its underlying PARAM_PLUS_TRACE node."""
current = conv_node
# Bounded walk: guards against pathological cycles in malformed graphs.
for _ in range(len(id_to_node)):
trace_parents = [
e
for e in edges
if e.target == current.id and e.kind == EdgeKind.TRACE_INPUT
]
if len(trace_parents) != 1:
raise ValueError(
f"CONVOLUTION node {conv_node.name!r} does not have a"
f" single TRACE_INPUT ancestor (found {len(trace_parents)})"
)
parent = id_to_node[trace_parents[0].source]
if parent.kind == NodeKind.PARAM_PLUS_TRACE:
return parent
if parent.kind != NodeKind.CONVOLUTION:
raise ValueError(
f"CONVOLUTION chain for {conv_node.name!r} walks through"
f" unsupported node kind {parent.kind.name!r}"
)
current = parent
raise ValueError(f"CONVOLUTION chain for {conv_node.name!r} exceeds graph size")
#
def _resolve_convolution_target_row(
conv_node: GraphNode,
edges: list[GraphEdge],
id_to_node: dict[int, GraphNode],
name_to_row: dict[str, int],
) -> int:
"""Walk a conv chain back to the underlying PARAM_PLUS_TRACE row.
CONVOLUTION nodes wrap a resolved trace: their TRACE_INPUT source is
either the PARAM_PLUS_TRACE (single conv) or an earlier CONVOLUTION
(chained conv). The lowered plan rewrites the PPT row in place, so
every conv in a chain shares the same target row.
"""
ppt_node = _walk_convolution_to_param_plus_trace(conv_node, edges, id_to_node)
return name_to_row[ppt_node.name]
#
def _topological_sort(graph: GraphIR) -> list[int]:
"""Topological sort of graph node IDs.
Tie-breaker: when two nodes have no dependency ordering between
them, sort by ``node.source_order`` (lower first). This makes the
schedule deterministic.
Does NOT assume ``node.id == list index``.
"""
import heapq
id_to_node: dict[int, GraphNode] = {n.id: n for n in graph.nodes}
in_degree: dict[int, int] = {n.id: 0 for n in graph.nodes}
children: dict[int, list[int]] = {n.id: [] for n in graph.nodes}
for edge in graph.edges:
children[edge.source].append(edge.target)
in_degree[edge.target] += 1
# Priority queue: (source_order, node_id) — lower source_order first
heap: list[tuple[int, int]] = []
for node in graph.nodes:
if in_degree[node.id] == 0:
heapq.heappush(heap, (node.source_order, node.id))
result: list[int] = []
while heap:
_order, nid = heapq.heappop(heap)
result.append(nid)
for child in children[nid]:
in_degree[child] -= 1
if in_degree[child] == 0:
heapq.heappush(
heap,
(id_to_node[child].source_order, child),
)
n = len(graph.nodes)
if len(result) != n:
raise ValueError(f"Graph has a cycle: sorted {len(result)} of {n} nodes")
return result
#
[docs]
def schedule_2d(graph: GraphIR) -> ScheduledPlan2D:
"""Compile a GraphIR into a flat 2D execution schedule.
Parameters
----------
graph : GraphIR
Must pass ``can_lower_2d(graph)``.
Returns
-------
ScheduledPlan2D
Packed-array execution schedule for ``evaluate_2d``.
Raises
------
ValueError
If the graph cannot be lowered (domain, unsupported nodes, etc.).
"""
if not can_lower_2d(graph):
raise ValueError("Graph cannot be lowered to 2D backend")
assert graph.energy is not None
assert graph.time is not None
n_time = len(graph.time)
# ------------------------------------------------------------------ #
# 1. Topological sort #
# ------------------------------------------------------------------ #
topo_order = _topological_sort(graph)
# Build id -> node lookup. Do NOT assume node.id == list index;
# external graph producers may use arbitrary ids.
id_to_node: dict[int, GraphNode] = {n.id: n for n in graph.nodes}
# ------------------------------------------------------------------ #
# 2. Assign trace-matrix rows #
# ------------------------------------------------------------------ #
# Nodes that occupy a row in the trace matrix: all parameter-like
# nodes (STATIC_PARAM, OPT_PARAM, PARAM_PLUS_TRACE, EXPRESSION).
_ROW_KINDS = frozenset(
{
NodeKind.STATIC_PARAM,
NodeKind.OPT_PARAM,
NodeKind.PARAM_PLUS_TRACE,
NodeKind.EXPRESSION,
}
)
# Collect nodes that need rows, bucketed for ordering:
# opt params first, then static, then computed (PPT + EXPR)
opt_nodes: list[GraphNode] = []
static_nodes: list[GraphNode] = []
computed_nodes: list[GraphNode] = []
for nid in topo_order:
node = id_to_node[nid]
if node.kind not in _ROW_KINDS or _is_profile_expr_node(node):
continue
if node.kind == NodeKind.OPT_PARAM and node.vary:
opt_nodes.append(node)
elif node.kind in (NodeKind.STATIC_PARAM, NodeKind.OPT_PARAM):
# OPT_PARAM with vary=False is treated like static
static_nodes.append(node)
else:
computed_nodes.append(node)
all_row_nodes = opt_nodes + static_nodes + computed_nodes
n_params = len(all_row_nodes)
name_to_row: dict[str, int] = {}
for row, node in enumerate(all_row_nodes):
name_to_row[node.name] = row
row_is_constant = np.zeros(n_params, dtype=np.bool_)
for node in static_nodes:
row_is_constant[name_to_row[node.name]] = True
# CONVOLUTION nodes rewrite a PARAM_PLUS_TRACE row in place, so they
# alias their underlying PPT's row rather than occupying a new one.
# Downstream edges (component PARAM_INPUT) reference the last node in
# the conv chain; this aliasing lets name_to_row resolve them to the
# PPT row. row_is_constant becomes False for conv-touched rows because
# kernel params may vary per call.
for node in graph.nodes:
if node.kind != NodeKind.CONVOLUTION:
continue
ppt_row = _resolve_convolution_target_row(
node, graph.edges, id_to_node, name_to_row
)
name_to_row[node.name] = ppt_row
row_is_constant[ppt_row] = False
# opt_indices and opt_param_names
n_opt = len(opt_nodes)
opt_indices = np.arange(n_opt, dtype=np.intp)
opt_param_names = [n.name for n in opt_nodes]
# ------------------------------------------------------------------ #
# 3. Compile dynamics subgraphs (grouped by PARAM_PLUS_TRACE) #
# ------------------------------------------------------------------ #
# First pass: extract per-DYNAMICS_TRACE info and find which PPT
# each trace feeds.
dyn_trace_nodes = [
id_to_node[nid]
for nid in topo_order
if id_to_node[nid].kind == NodeKind.DYNAMICS_TRACE
]
# Per-trace (substep) info, indexed by position in dyn_trace_nodes.
sub_func_ids: list[int] = []
sub_param_row_lists: list[list[int]] = []
sub_ppt_row: list[int] = [] # which PPT row this trace targets
sub_base_row: list[int] = []
# Per-substep subcycle arrays. ``None`` marks "no subcycle"
# (use defaults: ``graph.time`` axis, all-ones mask).
sub_time_axis: list[np.ndarray | None] = []
sub_mask: list[np.ndarray | None] = []
for dyn_node in dyn_trace_nodes:
assert dyn_node.function_name is not None
func_kind = _FUNCTION_NAME_TO_DYN_FUNC.get(dyn_node.function_name)
if func_kind is None:
raise ValueError(f"Unknown dynamics function: {dyn_node.function_name!r}")
sub_func_ids.append(int(func_kind))
# Param rows (from PARAM_INPUT edges, ordered by position)
param_edges = sorted(
(
e
for e in graph.edges
if e.target == dyn_node.id and e.kind == EdgeKind.PARAM_INPUT
),
key=lambda e: e.position or 0,
)
sub_param_row_lists.append(
[name_to_row[id_to_node[e.source].name] for e in param_edges]
)
# Subcycle REMAP upstream: incoming TRACE_INPUT from SUBCYCLE_REMAP
# carries the ``time_norm`` axis that replaces ``graph.time`` for
# this substep. Invariants (array present, (n_time,) shape) are
# enforced by ``_is_lowerable_subcycle_2d`` before we get here.
remap_time_axis: np.ndarray | None = None
for e in graph.edges:
if e.target == dyn_node.id and e.kind == EdgeKind.TRACE_INPUT:
src = id_to_node[e.source]
if src.kind == NodeKind.SUBCYCLE_REMAP:
remap_time_axis = src.arrays["time_norm"]
break
sub_time_axis.append(remap_time_axis)
# Target: the PARAM_PLUS_TRACE node this trace feeds. For
# subcycle-bearing traces, a SUBCYCLE_MASK sits between the
# DYNAMICS_TRACE and its PARAM_PLUS_TRACE consumer; walk past it
# and record the mask array for the evaluator.
ppt_edges = [
e
for e in graph.edges
if e.source == dyn_node.id and e.kind == EdgeKind.TRACE_INPUT
]
assert len(ppt_edges) == 1, (
f"DYNAMICS_TRACE '{dyn_node.name}' must feed exactly one"
f" downstream node, found {len(ppt_edges)}"
)
downstream = id_to_node[ppt_edges[0].target]
mask_array: np.ndarray | None = None
if downstream.kind == NodeKind.SUBCYCLE_MASK:
mask_array = downstream.arrays["time_n_sub"]
mask_out = [
e
for e in graph.edges
if e.source == downstream.id and e.kind == EdgeKind.TRACE_INPUT
]
assert len(mask_out) == 1, (
f"SUBCYCLE_MASK '{downstream.name}' must feed exactly one"
f" PARAM_PLUS_TRACE, found {len(mask_out)}"
)
ppt_node = id_to_node[mask_out[0].target]
else:
ppt_node = downstream
sub_mask.append(mask_array)
assert ppt_node.kind == NodeKind.PARAM_PLUS_TRACE, (
f"DYNAMICS_TRACE '{dyn_node.name}' chain does not terminate at"
f" PARAM_PLUS_TRACE (got {ppt_node.kind.name})"
)
sub_ppt_row.append(name_to_row[ppt_node.name])
# Base row: the BASE_INPUT to the PARAM_PLUS_TRACE
base_edges = [
e
for e in graph.edges
if e.target == ppt_node.id and e.kind == EdgeKind.BASE_INPUT
]
assert len(base_edges) == 1
sub_base_row.append(name_to_row[id_to_node[base_edges[0].source].name])
# Second pass: group substeps by target PPT row, preserving topo order.
# Each unique PPT row becomes one dynamics group.
seen_ppt: dict[int, int] = {} # ppt_row -> group index
group_target_rows: list[int] = []
group_base_rows: list[int] = []
group_substeps: list[list[int]] = [] # group_idx -> [substep indices]
for sub_idx, ppt_row in enumerate(sub_ppt_row):
if ppt_row not in seen_ppt:
gid = len(group_target_rows)
seen_ppt[ppt_row] = gid
group_target_rows.append(ppt_row)
group_base_rows.append(sub_base_row[sub_idx])
group_substeps.append([])
group_substeps[seen_ppt[ppt_row]].append(sub_idx)
n_dyn_groups = len(group_target_rows)
n_substeps = len(dyn_trace_nodes)
# Pack substep arrays (flat, ordered by group then topo within group)
flat_sub_indices: list[int] = []
indptr: list[int] = [0]
for subs in group_substeps:
flat_sub_indices.extend(subs)
indptr.append(indptr[-1] + len(subs))
max_dyn_params = (
max(len(r) for r in sub_param_row_lists) if sub_param_row_lists else 0
)
dyn_sub_param_rows = np.full((n_substeps, max_dyn_params), -1, dtype=np.intp)
dyn_sub_n_params_list: list[int] = []
dyn_sub_func_id_list: list[int] = []
# Per-substep subcycle schedule arrays. Default: ``graph.time`` and
# an all-ones mask. Overridden from upstream SUBCYCLE_REMAP / MASK
# nodes. Float64 kept uniform so the evaluator hot loop has a single
# dtype contract.
dyn_sub_time_axes = np.empty((n_substeps, n_time), dtype=np.float64)
dyn_sub_time_axes[:, :] = graph.time
dyn_sub_masks = np.ones((n_substeps, n_time), dtype=np.float64)
for flat_i, orig_i in enumerate(flat_sub_indices):
rows = sub_param_row_lists[orig_i]
dyn_sub_param_rows[flat_i, : len(rows)] = rows
dyn_sub_n_params_list.append(len(rows))
dyn_sub_func_id_list.append(sub_func_ids[orig_i])
if sub_time_axis[orig_i] is not None:
dyn_sub_time_axes[flat_i, :] = sub_time_axis[orig_i]
if sub_mask[orig_i] is not None:
dyn_sub_masks[flat_i, :] = sub_mask[orig_i]
dyn_group_target_row = np.array(group_target_rows, dtype=np.intp)
dyn_group_base_row = np.array(group_base_rows, dtype=np.intp)
dyn_group_indptr = np.array(indptr, dtype=np.intp)
dyn_sub_func_id = np.array(dyn_sub_func_id_list, dtype=np.intp)
dyn_sub_n_params = np.array(dyn_sub_n_params_list, dtype=np.intp)
# ------------------------------------------------------------------ #
# 4. Compile expressions (topological order) #
# ------------------------------------------------------------------ #
expr_nodes_topo = [
id_to_node[nid]
for nid in topo_order
if id_to_node[nid].kind == NodeKind.EXPRESSION
and not _is_profile_expr_node(id_to_node[nid])
]
n_expressions = len(expr_nodes_topo)
# Build per-expression name→row maps from EXPR_REF edges.
# The graph's EXPR_REF edges are the single source of truth for
# which node each name in the expression resolves to. We walk
# those edges to build a per-node override map, then compile the
# symbolic RPN (for operator structure) and bind names to rows
# using the edge-derived map.
#
# Index: expr_node.id -> {identifier_in_expr_string -> row}
expr_ref_maps: dict[int, dict[str, int]] = {}
for expr_node in expr_nodes_topo:
ref_map: dict[str, int] = {}
for edge in graph.edges:
if edge.target != expr_node.id or edge.kind != EdgeKind.EXPR_REF:
continue
src_node = id_to_node[edge.source]
src_row = name_to_row[src_node.name]
# The expr_string references names that appear as identifiers.
# The EXPR_REF source node's name is the canonical form (may
# include "_resolved" suffix or prefixed dynamics names).
# Walk the expression's identifier tokens to find which one
# this edge satisfies.
assert expr_node.expr_string is not None
for token in _extract_expression_references(expr_node.expr_string):
if token == src_node.name:
ref_map[token] = src_row
break
else:
# The identifier in the expression doesn't match the
# source node name literally. This shouldn't happen if
# build_graph stored the canonical expr_string, but
# fall back: try matching any identifier that names a
# parameter whose resolved row IS this source row.
for token in _extract_expression_references(expr_node.expr_string):
if token in name_to_row and name_to_row[token] == src_row:
ref_map[token] = src_row
break
# Also try: source is a resolved node, token is the
# base param name
if token not in ref_map and src_node.name.endswith("_resolved"):
base_name = src_node.name.removesuffix("_resolved")
if token == base_name:
ref_map[token] = src_row
break
expr_ref_maps[expr_node.id] = ref_map
# Compile and bind expressions
expr_programs: list[ExprProgram] = []
expr_target_rows_list: list[int] = []
for expr_node in expr_nodes_topo:
assert expr_node.expr_string is not None
symbolic = compile_expr_symbolic(expr_node.expr_string)
# Build the binding map: start with the base name_to_row, then
# overlay the edge-derived overrides for this expression.
binding = dict(name_to_row)
binding.update(expr_ref_maps.get(expr_node.id, {}))
program = _bind_expr_to_rows(symbolic, binding)
expr_programs.append(program)
target_row = name_to_row[expr_node.name]
expr_target_rows_list.append(target_row)
row_is_constant[target_row] = all(
row_is_constant[binding[name]] for name in symbolic.referenced_names
)
expr_target_rows = np.array(expr_target_rows_list, dtype=np.intp)
# Build resolution schedule: map DYNAMICS_TRACE node ids to their
# group index, and emit each group exactly once (on the *last* trace
# in that group, so all expression deps are resolved first).
_dyn_id_to_ppt = {n.id: sub_ppt_row[i] for i, n in enumerate(dyn_trace_nodes)}
_ppt_to_group = seen_ppt # ppt_row -> group index
_expr_id_to_idx = {n.id: i for i, n in enumerate(expr_nodes_topo)}
# Find which DYNAMICS_TRACE is the last in each group (in topo order).
# We emit the group step at that point so all substep expressions are
# resolved before the group evaluates.
_last_dyn_in_group: dict[int, int] = {} # group_idx -> last dyn node id
for nid in topo_order:
if nid in _dyn_id_to_ppt:
gid = _ppt_to_group[_dyn_id_to_ppt[nid]]
_last_dyn_in_group[gid] = nid
# Compile CONVOLUTION nodes into a conv program. Walk topo order
# so chained conv steps stay in the right sequence and the
# kind=2 resolution entries are emitted after the dyn/expr step that
# populates the target PARAM_PLUS_TRACE row.
conv_nodes_topo: list[GraphNode] = [
id_to_node[nid]
for nid in topo_order
if id_to_node[nid].kind == NodeKind.CONVOLUTION
]
_conv_id_to_idx: dict[int, int] = {n.id: i for i, n in enumerate(conv_nodes_topo)}
n_conv_steps = len(conv_nodes_topo)
conv_target_rows_list: list[int] = []
conv_func_ids_list: list[int] = []
conv_param_indptr_list: list[int] = [0]
conv_param_rows_list: list[int] = []
conv_support_indptr_list: list[int] = [0]
conv_support_values_list: list[float] = []
for conv_node in conv_nodes_topo:
assert conv_node.function_name is not None
kernel_kind = _FUNCTION_NAME_TO_CONV_KERNEL.get(conv_node.function_name)
if kernel_kind is None:
raise ValueError(f"Unknown convolution kernel: {conv_node.function_name!r}")
conv_func_ids_list.append(int(kernel_kind))
target_row = _resolve_convolution_target_row(
conv_node, graph.edges, id_to_node, name_to_row
)
conv_target_rows_list.append(target_row)
kernel_param_edges = sorted(
(
e
for e in graph.edges
if e.target == conv_node.id and e.kind == EdgeKind.PARAM_INPUT
),
key=lambda e: e.position or 0,
)
for edge in kernel_param_edges:
src_node = id_to_node[edge.source]
conv_param_rows_list.append(name_to_row[src_node.name])
conv_param_indptr_list.append(len(conv_param_rows_list))
support = conv_node.arrays.get("kernel_time")
if support is None:
raise ValueError(
f"CONVOLUTION node {conv_node.name!r} is missing"
" kernel_time array (required for lowered convolution)"
)
conv_support_values_list.extend(float(v) for v in np.asarray(support))
conv_support_indptr_list.append(len(conv_support_values_list))
conv_target_rows = np.array(conv_target_rows_list, dtype=np.intp)
conv_func_ids = np.array(conv_func_ids_list, dtype=np.intp)
conv_param_indptr = np.array(conv_param_indptr_list, dtype=np.intp)
conv_param_rows = np.array(conv_param_rows_list, dtype=np.intp)
conv_support_indptr = np.array(conv_support_indptr_list, dtype=np.intp)
conv_support_values = np.array(conv_support_values_list, dtype=np.float64)
resolution_kinds_list: list[int] = []
resolution_indices_list: list[int] = []
emitted_groups: set[int] = set()
for nid in topo_order:
if nid in _dyn_id_to_ppt:
gid = _ppt_to_group[_dyn_id_to_ppt[nid]]
if _last_dyn_in_group[gid] == nid and gid not in emitted_groups:
resolution_kinds_list.append(0) # dynamics group
resolution_indices_list.append(gid)
emitted_groups.add(gid)
elif nid in _expr_id_to_idx:
resolution_kinds_list.append(1) # expression
resolution_indices_list.append(_expr_id_to_idx[nid])
elif nid in _conv_id_to_idx:
resolution_kinds_list.append(2) # convolution step
resolution_indices_list.append(_conv_id_to_idx[nid])
resolution_kinds = np.array(resolution_kinds_list, dtype=np.int8)
resolution_indices = np.array(resolution_indices_list, dtype=np.intp)
# ------------------------------------------------------------------ #
# 4b. Compile PROFILE_SAMPLE groups #
# ------------------------------------------------------------------ #
# Build edge indexes for profile compilation (mirrors schedule_1d).
param_edges_by_target: dict[int, list[GraphEdge]] = {}
expr_ref_edges_by_target: dict[int, list[GraphEdge]] = {}
addend_edges_by_target: dict[int, list[GraphEdge]] = {}
spectrum_input_targets: set[int] = set()
for edge in graph.edges:
if edge.kind == EdgeKind.PARAM_INPUT:
param_edges_by_target.setdefault(edge.target, []).append(edge)
elif edge.kind == EdgeKind.EXPR_REF:
expr_ref_edges_by_target.setdefault(edge.target, []).append(edge)
elif edge.kind == EdgeKind.ADDEND:
addend_edges_by_target.setdefault(edge.target, []).append(edge)
elif edge.kind == EdgeKind.SPECTRUM_INPUT:
spectrum_input_targets.add(edge.target)
profile_sample_groups: dict[str, list[GraphNode]] = {}
for nid in topo_order:
node = id_to_node[nid]
if node.kind != NodeKind.PROFILE_SAMPLE:
continue
group_name = _profile_group_name(node.name, "profile_sample")
profile_sample_groups.setdefault(group_name, []).append(node)
plan_aux_axis = np.zeros(0, dtype=np.float64)
n_aux = 0
profile_sample_base_rows_list: list[int] = []
profile_sample_component_indptr_list: list[int] = [0]
profile_component_func_ids_list: list[int] = []
profile_component_param_indptr_list: list[int] = [0]
profile_component_param_rows_list: list[int] = []
profile_sample_is_constant_list: list[bool] = []
profile_sample_group_idx: dict[str, int] = {}
for group_name, sample_nodes in profile_sample_groups.items():
sample_nodes_sorted = sorted(
sample_nodes,
key=lambda node: _profile_group_index(node.name, "profile_sample"),
)
aux_indices = [
_profile_group_index(node.name, "profile_sample")
for node in sample_nodes_sorted
]
if aux_indices != list(range(len(sample_nodes_sorted))):
raise ValueError(
f"PROFILE_SAMPLE nodes for {group_name!r} do not cover "
"a contiguous aux-axis range"
)
aux_axis = sample_nodes_sorted[0].arrays.get("aux_axis")
if aux_axis is None:
raise ValueError(f"PROFILE_SAMPLE {group_name!r} is missing aux_axis")
aux_axis = np.asarray(aux_axis, dtype=np.float64)
if n_aux == 0:
n_aux = len(aux_axis)
plan_aux_axis = aux_axis.copy()
elif len(aux_axis) != n_aux or not np.array_equal(aux_axis, plan_aux_axis):
raise ValueError("All lowered profile groups must share one fixed aux_axis")
if len(sample_nodes_sorted) != n_aux:
raise ValueError(
f"PROFILE_SAMPLE group {group_name!r} has "
f"{len(sample_nodes_sorted)} samples but aux_axis length {n_aux}"
)
rep_node = sample_nodes_sorted[0]
rep_param_edges = sorted(
param_edges_by_target.get(rep_node.id, []),
key=lambda edge: edge.position or 0,
)
if not rep_param_edges:
raise ValueError(f"PROFILE_SAMPLE {group_name!r} has no PARAM_INPUT edges")
base_node = id_to_node[rep_param_edges[0].source]
base_row = name_to_row[base_node.name]
is_constant = bool(row_is_constant[base_row])
profile_sample_base_rows_list.append(base_row)
component_func_by_name: dict[str, int] = {}
component_param_rows_by_name: dict[str, list[int]] = {}
component_order: list[str] = []
for edge in rep_param_edges[1:]:
src_node = id_to_node[edge.source]
src_row = name_to_row[src_node.name]
# PARAM_PLUS_TRACE nodes have a "_resolved" suffix; CONVOLUTION
# nodes wrap a PPT for profile-time-dynamics and carry their own
# "<par>_<kernel>_dynamics" name. In both cases, walk back to
# the underlying PPT so profile component parsing sees the
# original profile parameter name.
parse_name = src_node.name
if src_node.kind == NodeKind.PARAM_PLUS_TRACE:
parse_name = parse_name.removesuffix("_resolved")
elif src_node.kind == NodeKind.CONVOLUTION:
ppt_node = _walk_convolution_to_param_plus_trace(
src_node, graph.edges, id_to_node
)
parse_name = ppt_node.name.removesuffix("_resolved")
comp_name, func_name = _parse_profile_component_param_name(
group_name,
parse_name,
)
prof_func_kind = _FUNCTION_NAME_TO_PROFILE_FUNC.get(func_name)
if prof_func_kind is None:
raise ValueError(f"Unknown profile function: {func_name!r}")
if comp_name not in component_func_by_name:
component_order.append(comp_name)
component_func_by_name[comp_name] = int(prof_func_kind)
component_param_rows_by_name[comp_name] = []
component_param_rows_by_name[comp_name].append(src_row)
is_constant = is_constant and bool(row_is_constant[src_row])
for comp_name in component_order:
profile_component_func_ids_list.append(component_func_by_name[comp_name])
profile_component_param_rows_list.extend(
component_param_rows_by_name[comp_name]
)
profile_component_param_indptr_list.append(
len(profile_component_param_rows_list)
)
profile_sample_component_indptr_list.append(
len(profile_component_func_ids_list)
)
group_idx = len(profile_sample_base_rows_list) - 1
profile_sample_group_idx[group_name] = group_idx
profile_sample_is_constant_list.append(is_constant)
n_profile_samples = len(profile_sample_base_rows_list)
profile_sample_base_rows = np.array(profile_sample_base_rows_list, dtype=np.intp)
profile_sample_component_indptr = np.array(
profile_sample_component_indptr_list, dtype=np.intp
)
profile_component_func_ids = np.array(
profile_component_func_ids_list, dtype=np.intp
)
profile_component_param_indptr = np.array(
profile_component_param_indptr_list, dtype=np.intp
)
profile_component_param_rows = np.array(
profile_component_param_rows_list, dtype=np.intp
)
profile_sample_is_constant = np.array(
profile_sample_is_constant_list, dtype=np.bool_
)
# ------------------------------------------------------------------ #
# 4c. Compile per-sample profile expressions #
# ------------------------------------------------------------------ #
profile_expr_groups: dict[str, list[GraphNode]] = {}
for nid in topo_order:
node = id_to_node[nid]
if _is_profile_expr_node(node):
group_name = _profile_group_name(node.name, "profile_expr")
profile_expr_groups.setdefault(group_name, []).append(node)
profile_expr_programs_2d: list[ExprProgram] = []
profile_expr_is_constant_list: list[bool] = []
profile_expr_group_idx: dict[str, int] = {}
for group_name, p_expr_nodes in profile_expr_groups.items():
p_expr_nodes_sorted = sorted(
p_expr_nodes,
key=lambda node: _profile_group_index(node.name, "profile_expr"),
)
aux_indices = [
_profile_group_index(node.name, "profile_expr")
for node in p_expr_nodes_sorted
]
if aux_indices != list(range(len(p_expr_nodes_sorted))):
raise ValueError(
f"Profile expression nodes for {group_name!r} do not cover "
"a contiguous aux-axis range"
)
if len(p_expr_nodes_sorted) != n_aux:
raise ValueError(
f"Profile expression group {group_name!r} has "
f"{len(p_expr_nodes_sorted)} samples but aux_axis length {n_aux}"
)
rep_node = p_expr_nodes_sorted[0]
if rep_node.expr_string is None:
raise ValueError(
f"Profile expression {group_name!r} is missing expr_string"
)
expr_refs = set(_extract_expression_references(rep_node.expr_string))
prof_ref_map: dict[str, int] = {}
for edge in expr_ref_edges_by_target.get(rep_node.id, []):
src_node = id_to_node[edge.source]
if src_node.kind == NodeKind.PROFILE_SAMPLE:
sample_name = _profile_group_name(src_node.name, "profile_sample")
src_idx = n_params + profile_sample_group_idx[sample_name]
match_name = sample_name
else:
src_idx = name_to_row[src_node.name]
match_name = src_node.name
# PARAM_PLUS_TRACE nodes carry a "_resolved" suffix;
# the expression string uses the bare param name.
if match_name not in expr_refs and src_node.name.endswith("_resolved"):
match_name = src_node.name.removesuffix("_resolved")
if match_name in expr_refs:
prof_ref_map[match_name] = src_idx
symbolic = compile_expr_symbolic(rep_node.expr_string)
prof_binding: dict[str, int] = dict(name_to_row)
prof_binding.update(prof_ref_map)
program = _bind_expr_to_rows(symbolic, prof_binding)
profile_expr_programs_2d.append(program)
is_constant = True
for name in symbolic.referenced_names:
bound_idx = int(prof_binding[name])
if bound_idx < n_params:
is_constant = is_constant and bool(row_is_constant[bound_idx])
else:
is_constant = is_constant and bool(
profile_sample_is_constant[bound_idx - n_params]
)
profile_expr_is_constant_list.append(is_constant)
profile_expr_group_idx[group_name] = len(profile_expr_programs_2d) - 1
n_profile_exprs = len(profile_expr_programs_2d)
profile_expr_is_constant = np.array(profile_expr_is_constant_list, dtype=np.bool_)
# ------------------------------------------------------------------ #
# 5. Schedule component ops #
# ------------------------------------------------------------------ #
# Identify peak_sum contributors: nodes with ADDEND edges into the
# "peak_sum" SUM node (if it exists).
peak_sum_sources: set[int] = set()
peak_sum_nid = graph.node_by_name.get("peak_sum")
if peak_sum_nid is not None:
for e in addend_edges_by_target.get(peak_sum_nid, []):
peak_sum_sources.add(e.source)
# Collect per-sample component inputs for PROFILE_AVERAGE nodes.
profile_avg_sample_inputs: dict[int, list[GraphNode]] = {}
sample_component_ids: set[int] = set()
for nid in topo_order:
node = id_to_node[nid]
if node.kind != NodeKind.PROFILE_AVERAGE:
continue
sample_nodes = [
id_to_node[edge.source]
for edge in addend_edges_by_target.get(node.id, [])
if id_to_node[edge.source].kind
in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP)
]
profile_avg_sample_inputs[node.id] = sample_nodes
sample_component_ids.update(sample.id for sample in sample_nodes)
comp_op_kinds = {NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP}
comp_nodes_topo = [
id_to_node[nid]
for nid in topo_order
if (
(id_to_node[nid].kind in comp_op_kinds and nid not in sample_component_ids)
or id_to_node[nid].kind == NodeKind.PROFILE_AVERAGE
)
]
n_ops = len(comp_nodes_topo)
op_schedule = np.arange(n_ops, dtype=np.intp)
op_kinds_list: list[int] = []
op_param_indptr_list: list[int] = [0]
op_param_source_kinds_list: list[int] = []
op_param_indices_list: list[int] = []
op_needs_spectrum_list: list[bool] = []
op_is_pre_spectrum_list: list[bool] = []
op_is_profiled_list: list[bool] = []
op_is_constant_list: list[bool] = []
for comp_node in comp_nodes_topo:
if comp_node.kind == NodeKind.PROFILE_AVERAGE:
# Profiled component: gather from per-sample COMPONENT_EVAL inputs.
sample_nodes = sorted(
profile_avg_sample_inputs.get(comp_node.id, []),
key=lambda node: _profile_component_sample_index(node.name),
)
if not sample_nodes:
raise ValueError(
f"PROFILE_AVERAGE {comp_node.name!r} has no sample component inputs"
)
if len(sample_nodes) != n_aux:
raise ValueError(
f"PROFILE_AVERAGE {comp_node.name!r} has "
f"{len(sample_nodes)} samples "
f"but aux_axis length {n_aux}"
)
rep_node = sample_nodes[0]
assert rep_node.function_name is not None
op = _FUNCTION_NAME_TO_OP.get(rep_node.function_name)
if op is None:
raise ValueError(
f"Unknown component function: {rep_node.function_name!r}"
)
op_kinds_list.append(int(op))
op_is_profiled_list.append(True)
rep_param_edges = sorted(
param_edges_by_target.get(rep_node.id, []),
key=lambda edge: edge.position or 0,
)
sample_param_edges = [
sorted(
param_edges_by_target.get(sample_node.id, []),
key=lambda edge: edge.position or 0,
)
for sample_node in sample_nodes
]
is_constant = True
for pos, rep_edge in enumerate(rep_param_edges):
src_node = id_to_node[rep_edge.source]
if src_node.kind == NodeKind.PROFILE_SAMPLE:
group_name = _profile_group_name(src_node.name, "profile_sample")
source_kind = int(ParamSourceKind.PROFILE_SAMPLE)
source_idx = profile_sample_group_idx[group_name]
is_constant = is_constant and bool(
profile_sample_is_constant[source_idx]
)
for aux_i, edges in enumerate(sample_param_edges):
sample_src = id_to_node[edges[pos].source]
if sample_src.kind != NodeKind.PROFILE_SAMPLE:
raise ValueError(
"Mixed parameter source kinds "
f"in profiled op {comp_node.name!r}"
)
if (
_profile_group_name(sample_src.name, "profile_sample")
!= group_name
or _profile_group_index(sample_src.name, "profile_sample")
!= aux_i
):
raise ValueError(
"Inconsistent PROFILE_SAMPLE wiring "
f"in {comp_node.name!r}"
)
elif _is_profile_expr_node(src_node):
group_name = _profile_group_name(src_node.name, "profile_expr")
source_kind = int(ParamSourceKind.PROFILE_EXPR)
source_idx = profile_expr_group_idx[group_name]
is_constant = is_constant and bool(
profile_expr_is_constant[source_idx]
)
for aux_i, edges in enumerate(sample_param_edges):
sample_src = id_to_node[edges[pos].source]
if not _is_profile_expr_node(sample_src):
raise ValueError(
"Mixed expression source kinds "
f"in profiled op {comp_node.name!r}"
)
if (
_profile_group_name(sample_src.name, "profile_expr")
!= group_name
or _profile_group_index(sample_src.name, "profile_expr")
!= aux_i
):
raise ValueError(
"Inconsistent profile-expression "
f"wiring in {comp_node.name!r}"
)
else:
source_kind = int(ParamSourceKind.SCALAR)
source_idx = name_to_row[src_node.name]
is_constant = is_constant and bool(row_is_constant[source_idx])
for edges in sample_param_edges[1:]:
if id_to_node[edges[pos].source].id != src_node.id:
raise ValueError(
"Scalar parameter source changed "
f"across samples in {comp_node.name!r}"
)
op_param_source_kinds_list.append(source_kind)
op_param_indices_list.append(source_idx)
op_param_indptr_list.append(len(op_param_indices_list))
else:
# Non-profiled component: standard param row wiring.
assert comp_node.function_name is not None
op = _FUNCTION_NAME_TO_OP.get(comp_node.function_name)
if op is None:
raise ValueError(
f"Unknown component function: {comp_node.function_name!r}"
)
op_kinds_list.append(int(op))
op_is_profiled_list.append(False)
param_edges = sorted(
param_edges_by_target.get(comp_node.id, []),
key=lambda edge: edge.position or 0,
)
is_constant = True
for pe in param_edges:
src_node = id_to_node[pe.source]
src_row = name_to_row[src_node.name]
op_param_source_kinds_list.append(int(ParamSourceKind.SCALAR))
op_param_indices_list.append(src_row)
is_constant = is_constant and bool(row_is_constant[src_row])
op_param_indptr_list.append(len(op_param_indices_list))
has_spec_input = comp_node.id in spectrum_input_targets
op_needs_spectrum_list.append(has_spec_input)
op_is_pre_spectrum_list.append(comp_node.id in peak_sum_sources)
op_is_constant_list.append((not has_spec_input) and is_constant)
op_kinds = np.array(op_kinds_list, dtype=np.intp)
op_param_indptr = np.array(op_param_indptr_list, dtype=np.intp)
op_param_source_kinds = np.array(op_param_source_kinds_list, dtype=np.int8)
op_param_indices = np.array(op_param_indices_list, dtype=np.intp)
op_needs_spectrum = np.array(op_needs_spectrum_list, dtype=np.bool_)
op_is_pre_spectrum = np.array(op_is_pre_spectrum_list, dtype=np.bool_)
op_is_profiled = np.array(op_is_profiled_list, dtype=np.bool_)
op_is_constant = np.array(op_is_constant_list, dtype=np.bool_)
# ------------------------------------------------------------------ #
# 6. Initialize trace matrix #
# ------------------------------------------------------------------ #
param_traces_init = np.zeros((n_params, n_time), dtype=np.float64)
# Static and opt params: broadcast initial value
for node in opt_nodes + static_nodes:
row = name_to_row[node.name]
param_traces_init[row, :] = node.value if node.value is not None else 0.0
# PARAM_PLUS_TRACE: base + dynamics trace at initial values.
# We need to evaluate dynamics functions at initial parameter values
# to populate these rows.
from trspecfit.functions import time as fcts_time
_DYN_DISPATCH: dict[int, Callable[..., Any]] = {
int(DynFuncKind.EXPFUN): fcts_time.expFun,
int(DynFuncKind.SINFUN): fcts_time.sinFun,
int(DynFuncKind.LINFUN): fcts_time.linFun,
int(DynFuncKind.SINDIVX): fcts_time.sinDivX,
int(DynFuncKind.ERFFUN): fcts_time.erfFun,
int(DynFuncKind.SQRTFUN): fcts_time.sqrtFun,
}
_CONV_KERNEL_DISPATCH: dict[int, Callable[..., Any]] = {
int(ConvKernelKind.GAUSSCONV): fcts_time.gaussCONV,
int(ConvKernelKind.LORENTZCONV): fcts_time.lorentzCONV,
int(ConvKernelKind.VOIGTCONV): fcts_time.voigtCONV,
int(ConvKernelKind.EXPSYMCONV): fcts_time.expSymCONV,
int(ConvKernelKind.EXPDECAYCONV): fcts_time.expDecayCONV,
int(ConvKernelKind.EXPRISECONV): fcts_time.expRiseCONV,
int(ConvKernelKind.BOXCONV): fcts_time.boxCONV,
}
from trspecfit.eval_2d import eval_expr_program
from trspecfit.utils.arrays import my_conv
# Dynamics groups, expressions, and resolved-trace convolutions are
# interleaved in topological order so that downstream consumers see
# the fully resolved trace (base + dynamics + expressions + IRF).
for step in range(len(resolution_kinds)):
kind = int(resolution_kinds[step])
idx = int(resolution_indices[step])
if kind == 0: # dynamics group
target = int(dyn_group_target_row[idx])
base = int(dyn_group_base_row[idx])
param_traces_init[target, :] = param_traces_init[base, :]
for s in range(int(dyn_group_indptr[idx]), int(dyn_group_indptr[idx + 1])):
n_dp = int(dyn_sub_n_params[s])
p_vals = [
float(param_traces_init[dyn_sub_param_rows[s, j], 0])
for j in range(n_dp)
]
func = _DYN_DISPATCH[int(dyn_sub_func_id[s])]
param_traces_init[target, :] += (
func(dyn_sub_time_axes[s], *p_vals) * dyn_sub_masks[s]
)
elif kind == 1: # expression
target_row = int(expr_target_rows[idx])
program = expr_programs[idx]
param_traces_init[target_row, :] = eval_expr_program(
program, param_traces_init
)
else: # kind == 2: resolved-trace convolution
target_row = int(conv_target_rows[idx])
kernel_func = _CONV_KERNEL_DISPATCH[int(conv_func_ids[idx])]
p_start = int(conv_param_indptr[idx])
p_end = int(conv_param_indptr[idx + 1])
kernel_params = [
float(param_traces_init[int(conv_param_rows[j]), 0])
for j in range(p_start, p_end)
]
s_start = int(conv_support_indptr[idx])
s_end = int(conv_support_indptr[idx + 1])
support = conv_support_values[s_start:s_end]
kernel = kernel_func(support, *kernel_params)
param_traces_init[target_row, :] = my_conv(
graph.time, param_traces_init[target_row, :], kernel
)
# ------------------------------------------------------------------ #
# 6b. Precompute constant component contributions #
# ------------------------------------------------------------------ #
# Evaluate profile sample/expr values at initial traces for caching.
from trspecfit.eval_2d import (
_evaluate_profile_expr_values_2d,
_evaluate_profile_sample_values_2d,
)
profile_sample_values_init = _evaluate_profile_sample_values_2d(
plan_aux_axis,
param_traces_init,
profile_sample_base_rows,
profile_sample_component_indptr,
profile_component_func_ids,
profile_component_param_indptr,
profile_component_param_rows,
)
profile_expr_values_init = _evaluate_profile_expr_values_2d(
param_traces_init,
profile_sample_values_init,
n_params,
profile_expr_programs_2d,
)
energy = graph.energy[np.newaxis, :]
cached_result = np.zeros((n_time, len(graph.energy)), dtype=np.float64)
cached_peak_sum = np.zeros_like(cached_result)
for op_idx in range(n_ops):
if not op_is_constant[op_idx]:
continue
start = int(op_param_indptr[op_idx])
end = int(op_param_indptr[op_idx + 1])
if op_is_profiled[op_idx]:
from trspecfit.eval_2d import _evaluate_profiled_op_2d
component = _evaluate_profiled_op_2d(
energy,
int(op_kinds[op_idx]),
op_param_source_kinds[start:end],
op_param_indices[start:end],
param_traces_init,
profile_sample_values_init,
profile_expr_values_init,
cached_peak_sum,
needs_spectrum=bool(op_needs_spectrum[op_idx]),
n_aux=n_aux,
)
else:
param_rows = op_param_indices[start:end]
params = [
param_traces_init[int(row), :][:, np.newaxis] for row in param_rows
]
func, _needs = OP_DISPATCH[int(op_kinds[op_idx])]
if op_needs_spectrum[op_idx]:
component = func(energy, *params, cached_peak_sum)
else:
component = func(energy, *params)
cached_result += component
if op_is_pre_spectrum[op_idx]:
cached_peak_sum += component
# ------------------------------------------------------------------ #
# 7. Pack into ScheduledPlan2D #
# ------------------------------------------------------------------ #
return ScheduledPlan2D(
energy=graph.energy,
time=graph.time,
n_params=n_params,
n_time=n_time,
param_traces_init=param_traces_init,
opt_indices=opt_indices,
opt_param_names=opt_param_names,
n_dyn_groups=n_dyn_groups,
dyn_group_target_row=dyn_group_target_row,
dyn_group_base_row=dyn_group_base_row,
dyn_group_indptr=dyn_group_indptr,
dyn_sub_func_id=dyn_sub_func_id,
dyn_sub_param_rows=dyn_sub_param_rows,
dyn_sub_n_params=dyn_sub_n_params,
dyn_sub_time_axes=dyn_sub_time_axes,
dyn_sub_masks=dyn_sub_masks,
n_expressions=n_expressions,
expr_target_rows=expr_target_rows,
expr_programs=expr_programs,
resolution_kinds=resolution_kinds,
resolution_indices=resolution_indices,
n_aux=n_aux,
aux_axis=plan_aux_axis,
n_profile_samples=n_profile_samples,
profile_sample_base_rows=profile_sample_base_rows,
profile_sample_component_indptr=profile_sample_component_indptr,
profile_component_func_ids=profile_component_func_ids,
profile_component_param_indptr=profile_component_param_indptr,
profile_component_param_rows=profile_component_param_rows,
n_profile_exprs=n_profile_exprs,
profile_expr_programs=profile_expr_programs_2d,
n_ops=n_ops,
op_schedule=op_schedule,
op_kinds=op_kinds,
op_param_indptr=op_param_indptr,
op_param_source_kinds=op_param_source_kinds,
op_param_indices=op_param_indices,
op_needs_spectrum=op_needs_spectrum,
op_is_pre_spectrum=op_is_pre_spectrum,
op_is_profiled=op_is_profiled,
op_is_constant=op_is_constant,
cached_result=cached_result,
cached_peak_sum=cached_peak_sum,
n_conv_steps=n_conv_steps,
conv_target_rows=conv_target_rows,
conv_func_ids=conv_func_ids,
conv_param_indptr=conv_param_indptr,
conv_param_rows=conv_param_rows,
conv_support_indptr=conv_support_indptr,
conv_support_values=conv_support_values,
)
# ---------------------------------------------------------------------------
# schedule_1d
# ---------------------------------------------------------------------------
#
def _eval_expr_scalar(program: ExprProgram, values: np.ndarray) -> float:
"""Evaluate an RPN ExprProgram against a scalar parameter vector.
Parameters
----------
program
Compiled RPN instruction array.
values
``(n_params,)`` scalar parameter vector.
Returns
-------
float
Scalar result.
"""
stack: list[float] = []
instr = program.instructions
n_instr = len(instr) // 2
for i in range(n_instr):
kind = ExprNodeKind(instr[2 * i])
operand = instr[2 * i + 1]
if kind == ExprNodeKind.CONST:
stack.append(float(np.int64(operand).view(np.float64)))
elif kind == ExprNodeKind.PARAM_REF:
stack.append(float(values[int(operand)]))
elif kind == ExprNodeKind.ADD:
b, a = stack.pop(), stack.pop()
stack.append(a + b)
elif kind == ExprNodeKind.SUB:
b, a = stack.pop(), stack.pop()
stack.append(a - b)
elif kind == ExprNodeKind.MUL:
b, a = stack.pop(), stack.pop()
stack.append(a * b)
elif kind == ExprNodeKind.DIV:
b, a = stack.pop(), stack.pop()
stack.append(a / b)
elif kind == ExprNodeKind.NEG:
stack.append(-stack.pop())
elif kind == ExprNodeKind.POW:
b, a = stack.pop(), stack.pop()
stack.append(a**b)
assert len(stack) == 1
return stack[0]
#
def _profile_group_name(node_name: str, label: str) -> str:
"""Return the shared profile group base name from a per-sample node name."""
match = re.fullmatch(rf"(.+)_{label}_(\d+)", node_name)
if match is None:
raise ValueError(f"Malformed profile node name: {node_name!r}")
return match.group(1)
#
def _profile_group_index(node_name: str, label: str) -> int:
"""Return the aux-axis sample index encoded in a per-sample node name."""
match = re.fullmatch(rf"(.+)_{label}_(\d+)", node_name)
if match is None:
raise ValueError(f"Malformed profile node name: {node_name!r}")
return int(match.group(2))
#
def _profile_component_sample_index(node_name: str) -> int:
"""Return the aux-axis sample index from ``<component>_sample_<i>``."""
return _profile_group_index(node_name, "sample")
#
def _is_profile_expr_node(node: GraphNode) -> bool:
"""Return True for per-sample profile-expression nodes."""
return (
node.kind == NodeKind.EXPRESSION
and re.fullmatch(r"(.+)_profile_expr_(\d+)", node.name) is not None
)
#
def _parse_profile_component_param_name(
target_param_name: str,
source_param_name: str,
) -> tuple[str, str]:
"""Parse ``<target>_<profile_comp>_<par>`` into component + function name."""
prefix = f"{target_param_name}_"
if not source_param_name.startswith(prefix):
raise ValueError(
f"Profile parameter {source_param_name!r} does not match "
f"target parameter {target_param_name!r}"
)
remainder = source_param_name[len(prefix) :]
comp_name, sep, _par_name = remainder.rpartition("_")
if not sep or not comp_name:
raise ValueError(f"Malformed profile parameter name: {source_param_name!r}")
func_name, sep2, comp_idx = comp_name.rpartition("_")
if not sep2 or not comp_idx.isdigit():
raise ValueError(f"Malformed profile component name: {comp_name!r}")
return comp_name, func_name
#
def _eval_expr_vector(program: ExprProgram, traces: np.ndarray) -> np.ndarray:
"""Evaluate an RPN ExprProgram against an aux-resolved trace matrix."""
from trspecfit.eval_2d import eval_expr_program
return eval_expr_program(program, traces)
#
def _evaluate_profile_sample_values(
aux_axis: np.ndarray,
scalar_values: np.ndarray,
profile_sample_base_indices: np.ndarray,
profile_sample_component_indptr: np.ndarray,
profile_component_func_ids: np.ndarray,
profile_component_param_indptr: np.ndarray,
profile_component_param_indices: np.ndarray,
) -> np.ndarray:
"""Evaluate lowered PROFILE_SAMPLE groups into ``(n_groups, n_aux)`` values."""
n_groups = len(profile_sample_base_indices)
n_aux = len(aux_axis)
if n_groups == 0:
return np.zeros((0, n_aux), dtype=np.float64)
sample_values = np.empty((n_groups, n_aux), dtype=np.float64)
for group_idx in range(n_groups):
base_idx = int(profile_sample_base_indices[group_idx])
values = np.full(n_aux, scalar_values[base_idx], dtype=np.float64)
comp_start = int(profile_sample_component_indptr[group_idx])
comp_end = int(profile_sample_component_indptr[group_idx + 1])
for comp_idx in range(comp_start, comp_end):
func = PROFILE_DISPATCH[int(profile_component_func_ids[comp_idx])]
param_start = int(profile_component_param_indptr[comp_idx])
param_end = int(profile_component_param_indptr[comp_idx + 1])
params = [
float(scalar_values[int(idx)])
for idx in profile_component_param_indices[param_start:param_end]
]
values += np.asarray(func(aux_axis, *params), dtype=np.float64)
sample_values[group_idx, :] = values
return sample_values
#
def _evaluate_profile_expr_values(
scalar_values: np.ndarray,
profile_sample_values: np.ndarray,
n_params: int,
profile_expr_programs: list[ExprProgram],
) -> np.ndarray:
"""Evaluate lowered per-sample profile expressions over the aux axis."""
n_exprs = len(profile_expr_programs)
if n_exprs == 0:
n_aux = profile_sample_values.shape[1] if profile_sample_values.size else 0
return np.zeros((0, n_aux), dtype=np.float64)
n_aux = profile_sample_values.shape[1]
traces = np.empty(
(n_params + profile_sample_values.shape[0], n_aux), dtype=np.float64
)
traces[:n_params, :] = scalar_values[:, np.newaxis]
if profile_sample_values.size:
traces[n_params:, :] = profile_sample_values
expr_values = np.empty((n_exprs, n_aux), dtype=np.float64)
for expr_idx, program in enumerate(profile_expr_programs):
expr_values[expr_idx, :] = _eval_expr_vector(program, traces)
return expr_values
#
def _evaluate_scheduled_op_1d(
energy: np.ndarray,
kind: int,
param_source_kinds: np.ndarray,
param_indices: np.ndarray,
scalar_values: np.ndarray,
profile_sample_values: np.ndarray,
profile_expr_values: np.ndarray,
peak_sum: np.ndarray,
*,
needs_spectrum: bool,
is_profiled: bool,
n_aux: int,
) -> np.ndarray:
"""Evaluate one scheduled 1D op in either scalar or profiled mode."""
func, _needs = OP_DISPATCH[kind]
if not is_profiled:
scalar_params = [float(scalar_values[int(row)]) for row in param_indices]
if needs_spectrum:
return np.asarray(func(energy, *scalar_params, peak_sum), dtype=np.float64)
return np.asarray(func(energy, *scalar_params), dtype=np.float64)
energy_2d = energy[np.newaxis, :]
params_2d: list[np.ndarray] = []
for source_kind, source_idx in zip(
param_source_kinds,
param_indices,
strict=True,
):
if int(source_kind) == int(ParamSourceKind.SCALAR):
param = np.full(
(n_aux, 1), scalar_values[int(source_idx)], dtype=np.float64
)
elif int(source_kind) == int(ParamSourceKind.PROFILE_SAMPLE):
param = profile_sample_values[int(source_idx), :][:, np.newaxis]
else:
param = profile_expr_values[int(source_idx), :][:, np.newaxis]
params_2d.append(param)
if needs_spectrum:
profiled = func(energy_2d, *params_2d, peak_sum[np.newaxis, :])
else:
profiled = func(energy_2d, *params_2d)
result: np.ndarray = np.asarray(profiled, dtype=np.float64).mean(axis=0)
return result
#
[docs]
def schedule_1d(graph: GraphIR) -> ScheduledPlan1D:
"""Compile a GraphIR into a flat 1D execution schedule.
Parameters
----------
graph : GraphIR
Must pass ``can_lower_1d(graph)``.
Returns
-------
ScheduledPlan1D
Packed-array execution schedule for ``evaluate_1d``.
Raises
------
ValueError
If the graph cannot be lowered (domain, unsupported nodes, etc.).
"""
if not can_lower_1d(graph):
raise ValueError("Graph cannot be lowered to 1D backend")
assert graph.energy is not None
# ------------------------------------------------------------------ #
# 1. Topological sort + helper lookups #
# ------------------------------------------------------------------ #
topo_order = _topological_sort(graph)
id_to_node: dict[int, GraphNode] = {n.id: n for n in graph.nodes}
param_edges_by_target: dict[int, list[GraphEdge]] = {}
expr_ref_edges_by_target: dict[int, list[GraphEdge]] = {}
addend_edges_by_target: dict[int, list[GraphEdge]] = {}
spectrum_input_targets: set[int] = set()
for edge in graph.edges:
if edge.kind == EdgeKind.PARAM_INPUT:
param_edges_by_target.setdefault(edge.target, []).append(edge)
elif edge.kind == EdgeKind.EXPR_REF:
expr_ref_edges_by_target.setdefault(edge.target, []).append(edge)
elif edge.kind == EdgeKind.ADDEND:
addend_edges_by_target.setdefault(edge.target, []).append(edge)
elif edge.kind == EdgeKind.SPECTRUM_INPUT:
spectrum_input_targets.add(edge.target)
# ------------------------------------------------------------------ #
# 2. Assign scalar parameter indices #
# ------------------------------------------------------------------ #
_ROW_KINDS = frozenset(
{
NodeKind.STATIC_PARAM,
NodeKind.OPT_PARAM,
NodeKind.EXPRESSION,
}
)
opt_nodes: list[GraphNode] = []
static_nodes: list[GraphNode] = []
computed_nodes: list[GraphNode] = []
for nid in topo_order:
node = id_to_node[nid]
if node.kind not in _ROW_KINDS or _is_profile_expr_node(node):
continue
if node.kind == NodeKind.OPT_PARAM and node.vary:
opt_nodes.append(node)
elif node.kind in (NodeKind.STATIC_PARAM, NodeKind.OPT_PARAM):
static_nodes.append(node)
else:
computed_nodes.append(node)
all_param_nodes = opt_nodes + static_nodes + computed_nodes
n_params = len(all_param_nodes)
name_to_idx = {node.name: idx for idx, node in enumerate(all_param_nodes)}
idx_is_constant = np.zeros(n_params, dtype=np.bool_)
for node in static_nodes:
idx_is_constant[name_to_idx[node.name]] = True
n_opt = len(opt_nodes)
opt_indices = np.arange(n_opt, dtype=np.intp)
opt_param_names = [n.name for n in opt_nodes]
# ------------------------------------------------------------------ #
# 3. Compile scalar expressions #
# ------------------------------------------------------------------ #
expr_nodes_topo = [
id_to_node[nid]
for nid in topo_order
if id_to_node[nid].kind == NodeKind.EXPRESSION
and not _is_profile_expr_node(id_to_node[nid])
]
expr_programs: list[ExprProgram] = []
expr_target_indices_list: list[int] = []
for expr_node in expr_nodes_topo:
assert expr_node.expr_string is not None
symbolic = compile_expr_symbolic(expr_node.expr_string)
expr_refs = set(_extract_expression_references(expr_node.expr_string))
ref_map: dict[str, int] = {}
for edge in expr_ref_edges_by_target.get(expr_node.id, []):
src_node = id_to_node[edge.source]
src_idx = name_to_idx[src_node.name]
if src_node.name in expr_refs:
ref_map[src_node.name] = src_idx
binding = dict(name_to_idx)
binding.update(ref_map)
program = _bind_expr_to_rows(symbolic, binding)
expr_programs.append(program)
target_idx = name_to_idx[expr_node.name]
expr_target_indices_list.append(target_idx)
idx_is_constant[target_idx] = all(
idx_is_constant[int(binding[name])] for name in symbolic.referenced_names
)
n_expressions = len(expr_programs)
expr_target_indices = np.array(expr_target_indices_list, dtype=np.intp)
# ------------------------------------------------------------------ #
# 4. Compile PROFILE_SAMPLE groups #
# ------------------------------------------------------------------ #
profile_sample_groups: dict[str, list[GraphNode]] = {}
for nid in topo_order:
node = id_to_node[nid]
if node.kind != NodeKind.PROFILE_SAMPLE:
continue
group_name = _profile_group_name(node.name, "profile_sample")
profile_sample_groups.setdefault(group_name, []).append(node)
plan_aux_axis = np.zeros(0, dtype=np.float64)
n_aux = 0
profile_sample_base_indices_list: list[int] = []
profile_sample_component_indptr_list: list[int] = [0]
profile_component_func_ids_list: list[int] = []
profile_component_param_indptr_list: list[int] = [0]
profile_component_param_indices_list: list[int] = []
profile_sample_is_constant_list: list[bool] = []
profile_sample_group_idx: dict[str, int] = {}
for group_name, sample_nodes in profile_sample_groups.items():
sample_nodes_sorted = sorted(
sample_nodes,
key=lambda node: _profile_group_index(node.name, "profile_sample"),
)
aux_indices = [
_profile_group_index(node.name, "profile_sample")
for node in sample_nodes_sorted
]
if aux_indices != list(range(len(sample_nodes_sorted))):
raise ValueError(
f"PROFILE_SAMPLE nodes for {group_name!r} do not cover "
"a contiguous aux-axis range"
)
aux_axis = sample_nodes_sorted[0].arrays.get("aux_axis")
if aux_axis is None:
raise ValueError(f"PROFILE_SAMPLE {group_name!r} is missing aux_axis")
aux_axis = np.asarray(aux_axis, dtype=np.float64)
if n_aux == 0:
n_aux = len(aux_axis)
plan_aux_axis = aux_axis.copy()
elif len(aux_axis) != n_aux or not np.array_equal(aux_axis, plan_aux_axis):
raise ValueError("All lowered profile groups must share one fixed aux_axis")
if len(sample_nodes_sorted) != n_aux:
raise ValueError(
f"PROFILE_SAMPLE group {group_name!r} has {len(sample_nodes_sorted)} "
f"samples but aux_axis length {n_aux}"
)
rep_node = sample_nodes_sorted[0]
param_edges = sorted(
param_edges_by_target.get(rep_node.id, []),
key=lambda edge: edge.position or 0,
)
if not param_edges:
raise ValueError(f"PROFILE_SAMPLE {group_name!r} has no PARAM_INPUT edges")
base_node = id_to_node[param_edges[0].source]
if base_node.name not in name_to_idx:
raise ValueError(
f"PROFILE_SAMPLE base source {base_node.name!r} is not scalar-lowerable"
)
base_idx = name_to_idx[base_node.name]
is_constant = bool(idx_is_constant[base_idx])
profile_sample_base_indices_list.append(base_idx)
component_func_by_name: dict[str, int] = {}
component_param_indices_by_name: dict[str, list[int]] = {}
component_order: list[str] = []
for edge in param_edges[1:]:
src_node = id_to_node[edge.source]
if src_node.name not in name_to_idx:
raise ValueError(
f"Profile parameter source {src_node.name!r} "
"is not scalar-lowerable"
)
comp_name, func_name = _parse_profile_component_param_name(
group_name,
src_node.name,
)
prof_func_kind = _FUNCTION_NAME_TO_PROFILE_FUNC.get(func_name)
if prof_func_kind is None:
raise ValueError(f"Unknown profile function: {func_name!r}")
if comp_name not in component_func_by_name:
component_order.append(comp_name)
component_func_by_name[comp_name] = int(prof_func_kind)
component_param_indices_by_name[comp_name] = []
src_idx = name_to_idx[src_node.name]
component_param_indices_by_name[comp_name].append(src_idx)
is_constant = is_constant and bool(idx_is_constant[src_idx])
for comp_name in component_order:
profile_component_func_ids_list.append(component_func_by_name[comp_name])
profile_component_param_indices_list.extend(
component_param_indices_by_name[comp_name]
)
profile_component_param_indptr_list.append(
len(profile_component_param_indices_list)
)
profile_sample_component_indptr_list.append(
len(profile_component_func_ids_list)
)
group_idx = len(profile_sample_base_indices_list) - 1
profile_sample_group_idx[group_name] = group_idx
profile_sample_is_constant_list.append(is_constant)
n_profile_samples = len(profile_sample_base_indices_list)
profile_sample_base_indices = np.array(
profile_sample_base_indices_list, dtype=np.intp
)
profile_sample_component_indptr = np.array(
profile_sample_component_indptr_list, dtype=np.intp
)
profile_component_func_ids = np.array(
profile_component_func_ids_list, dtype=np.intp
)
profile_component_param_indptr = np.array(
profile_component_param_indptr_list, dtype=np.intp
)
profile_component_param_indices = np.array(
profile_component_param_indices_list, dtype=np.intp
)
profile_sample_is_constant = np.array(
profile_sample_is_constant_list,
dtype=np.bool_,
)
# ------------------------------------------------------------------ #
# 5. Compile per-sample profile expressions #
# ------------------------------------------------------------------ #
profile_expr_groups: dict[str, list[GraphNode]] = {}
for nid in topo_order:
node = id_to_node[nid]
if _is_profile_expr_node(node):
group_name = _profile_group_name(node.name, "profile_expr")
profile_expr_groups.setdefault(group_name, []).append(node)
profile_expr_programs: list[ExprProgram] = []
profile_expr_is_constant_list: list[bool] = []
profile_expr_group_idx: dict[str, int] = {}
for group_name, expr_nodes in profile_expr_groups.items():
expr_nodes_sorted = sorted(
expr_nodes,
key=lambda node: _profile_group_index(node.name, "profile_expr"),
)
aux_indices = [
_profile_group_index(node.name, "profile_expr")
for node in expr_nodes_sorted
]
if aux_indices != list(range(len(expr_nodes_sorted))):
raise ValueError(
f"Profile expression nodes for {group_name!r} do not cover "
"a contiguous aux-axis range"
)
if len(expr_nodes_sorted) != n_aux:
raise ValueError(
f"Profile expression group {group_name!r} has {len(expr_nodes_sorted)} "
f"samples but aux_axis length {n_aux}"
)
rep_node = expr_nodes_sorted[0]
if rep_node.expr_string is None:
raise ValueError(
f"Profile expression {group_name!r} is missing expr_string"
)
expr_refs = set(_extract_expression_references(rep_node.expr_string))
prof_ref_map: dict[str, int] = {}
for edge in expr_ref_edges_by_target.get(rep_node.id, []):
src_node = id_to_node[edge.source]
if src_node.kind == NodeKind.PROFILE_SAMPLE:
sample_name = _profile_group_name(src_node.name, "profile_sample")
src_idx = n_params + profile_sample_group_idx[sample_name]
match_name = sample_name
else:
src_idx = name_to_idx[src_node.name]
match_name = src_node.name
if match_name in expr_refs:
prof_ref_map[match_name] = src_idx
symbolic = compile_expr_symbolic(rep_node.expr_string)
binding = dict(name_to_idx)
binding.update(prof_ref_map)
program = _bind_expr_to_rows(symbolic, binding)
profile_expr_programs.append(program)
is_constant = True
for name in symbolic.referenced_names:
bound_idx = int(binding[name])
if bound_idx < n_params:
is_constant = is_constant and bool(idx_is_constant[bound_idx])
else:
is_constant = is_constant and bool(
profile_sample_is_constant[bound_idx - n_params]
)
profile_expr_is_constant_list.append(is_constant)
profile_expr_group_idx[group_name] = len(profile_expr_programs) - 1
n_profile_exprs = len(profile_expr_programs)
profile_expr_is_constant = np.array(profile_expr_is_constant_list, dtype=np.bool_)
# ------------------------------------------------------------------ #
# 6. Schedule component ops #
# ------------------------------------------------------------------ #
peak_sum_sources: set[int] = set()
peak_sum_nid = graph.node_by_name.get("peak_sum")
if peak_sum_nid is not None:
for edge in addend_edges_by_target.get(peak_sum_nid, []):
peak_sum_sources.add(edge.source)
profile_avg_sample_inputs: dict[int, list[GraphNode]] = {}
sample_component_ids: set[int] = set()
for nid in topo_order:
node = id_to_node[nid]
if node.kind != NodeKind.PROFILE_AVERAGE:
continue
sample_nodes = [
id_to_node[edge.source]
for edge in addend_edges_by_target.get(node.id, [])
if id_to_node[edge.source].kind
in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP)
]
profile_avg_sample_inputs[node.id] = sample_nodes
sample_component_ids.update(sample.id for sample in sample_nodes)
comp_nodes_topo = [
id_to_node[nid]
for nid in topo_order
if (
(
id_to_node[nid].kind
in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP)
and nid not in sample_component_ids
)
or id_to_node[nid].kind == NodeKind.PROFILE_AVERAGE
)
]
op_kinds_list: list[int] = []
op_param_indptr_list: list[int] = [0]
op_param_source_kinds_list: list[int] = []
op_param_indices_list: list[int] = []
op_needs_spectrum_list: list[bool] = []
op_is_pre_spectrum_list: list[bool] = []
op_is_profiled_list: list[bool] = []
op_is_constant_list: list[bool] = []
for comp_node in comp_nodes_topo:
if comp_node.kind == NodeKind.PROFILE_AVERAGE:
sample_nodes = sorted(
profile_avg_sample_inputs.get(comp_node.id, []),
key=lambda node: _profile_component_sample_index(node.name),
)
if not sample_nodes:
raise ValueError(
f"PROFILE_AVERAGE {comp_node.name!r} has no sample component inputs"
)
if len(sample_nodes) != n_aux:
raise ValueError(
f"PROFILE_AVERAGE {comp_node.name!r} has "
f"{len(sample_nodes)} samples "
f"but aux_axis length {n_aux}"
)
rep_node = sample_nodes[0]
assert rep_node.function_name is not None
op = _FUNCTION_NAME_TO_OP.get(rep_node.function_name)
if op is None:
raise ValueError(
f"Unknown component function: {rep_node.function_name!r}"
)
op_kinds_list.append(int(op))
op_is_profiled_list.append(True)
rep_param_edges = sorted(
param_edges_by_target.get(rep_node.id, []),
key=lambda edge: edge.position or 0,
)
sample_param_edges = [
sorted(
param_edges_by_target.get(sample_node.id, []),
key=lambda edge: edge.position or 0,
)
for sample_node in sample_nodes
]
is_constant = True
for pos, rep_edge in enumerate(rep_param_edges):
src_node = id_to_node[rep_edge.source]
if src_node.kind == NodeKind.PROFILE_SAMPLE:
group_name = _profile_group_name(src_node.name, "profile_sample")
source_kind = int(ParamSourceKind.PROFILE_SAMPLE)
source_idx = profile_sample_group_idx[group_name]
is_constant = is_constant and bool(
profile_sample_is_constant[source_idx]
)
for aux_i, edges in enumerate(sample_param_edges):
sample_src = id_to_node[edges[pos].source]
if sample_src.kind != NodeKind.PROFILE_SAMPLE:
raise ValueError(
"Mixed parameter source kinds "
f"in profiled op {comp_node.name!r}"
)
if (
_profile_group_name(sample_src.name, "profile_sample")
!= group_name
or _profile_group_index(sample_src.name, "profile_sample")
!= aux_i
):
raise ValueError(
"Inconsistent PROFILE_SAMPLE wiring "
f"in {comp_node.name!r}"
)
elif _is_profile_expr_node(src_node):
group_name = _profile_group_name(src_node.name, "profile_expr")
source_kind = int(ParamSourceKind.PROFILE_EXPR)
source_idx = profile_expr_group_idx[group_name]
is_constant = is_constant and bool(
profile_expr_is_constant[source_idx]
)
for aux_i, edges in enumerate(sample_param_edges):
sample_src = id_to_node[edges[pos].source]
if not _is_profile_expr_node(sample_src):
raise ValueError(
"Mixed expression source kinds "
f"in profiled op {comp_node.name!r}"
)
if (
_profile_group_name(sample_src.name, "profile_expr")
!= group_name
or _profile_group_index(sample_src.name, "profile_expr")
!= aux_i
):
raise ValueError(
"Inconsistent profile-expression "
f"wiring in {comp_node.name!r}"
)
else:
if src_node.name not in name_to_idx:
raise ValueError(
f"Non-scalar parameter source {src_node.name!r} in 1D op"
)
source_kind = int(ParamSourceKind.SCALAR)
source_idx = name_to_idx[src_node.name]
is_constant = is_constant and bool(idx_is_constant[source_idx])
for edges in sample_param_edges[1:]:
if id_to_node[edges[pos].source].id != src_node.id:
raise ValueError(
"Scalar parameter source changed "
f"across samples in {comp_node.name!r}"
)
op_param_source_kinds_list.append(source_kind)
op_param_indices_list.append(source_idx)
op_param_indptr_list.append(len(op_param_indices_list))
else:
assert comp_node.function_name is not None
op = _FUNCTION_NAME_TO_OP.get(comp_node.function_name)
if op is None:
raise ValueError(
f"Unknown component function: {comp_node.function_name!r}"
)
op_kinds_list.append(int(op))
op_is_profiled_list.append(False)
param_edges = sorted(
param_edges_by_target.get(comp_node.id, []),
key=lambda edge: edge.position or 0,
)
is_constant = True
for edge in param_edges:
src_node = id_to_node[edge.source]
if src_node.name not in name_to_idx:
raise ValueError(
f"Non-scalar parameter source {src_node.name!r} in 1D op"
)
src_idx = name_to_idx[src_node.name]
op_param_source_kinds_list.append(int(ParamSourceKind.SCALAR))
op_param_indices_list.append(src_idx)
is_constant = is_constant and bool(idx_is_constant[src_idx])
op_param_indptr_list.append(len(op_param_indices_list))
has_spec_input = comp_node.id in spectrum_input_targets
op_needs_spectrum_list.append(has_spec_input)
op_is_pre_spectrum_list.append(comp_node.id in peak_sum_sources)
op_is_constant_list.append((not has_spec_input) and is_constant)
n_ops = len(comp_nodes_topo)
op_kinds = np.array(op_kinds_list, dtype=np.intp)
op_param_indptr = np.array(op_param_indptr_list, dtype=np.intp)
op_param_source_kinds = np.array(op_param_source_kinds_list, dtype=np.int8)
op_param_indices = np.array(op_param_indices_list, dtype=np.intp)
op_needs_spectrum = np.array(op_needs_spectrum_list, dtype=np.bool_)
op_is_pre_spectrum = np.array(op_is_pre_spectrum_list, dtype=np.bool_)
op_is_profiled = np.array(op_is_profiled_list, dtype=np.bool_)
op_is_constant = np.array(op_is_constant_list, dtype=np.bool_)
# ------------------------------------------------------------------ #
# 7. Initialize scalar + profile values #
# ------------------------------------------------------------------ #
param_values_init = np.zeros(n_params, dtype=np.float64)
for node in opt_nodes + static_nodes:
param_values_init[name_to_idx[node.name]] = (
node.value if node.value is not None else 0.0
)
for i in range(n_expressions):
target_idx = int(expr_target_indices[i])
param_values_init[target_idx] = _eval_expr_scalar(
expr_programs[i], param_values_init
)
profile_sample_values_init = _evaluate_profile_sample_values(
plan_aux_axis,
param_values_init,
profile_sample_base_indices,
profile_sample_component_indptr,
profile_component_func_ids,
profile_component_param_indptr,
profile_component_param_indices,
)
profile_expr_values_init = _evaluate_profile_expr_values(
param_values_init,
profile_sample_values_init,
n_params,
profile_expr_programs,
)
# ------------------------------------------------------------------ #
# 8. Precompute constant component contributions #
# ------------------------------------------------------------------ #
energy = graph.energy
cached_result = np.zeros(len(energy), dtype=np.float64)
cached_peak_sum = np.zeros_like(cached_result)
for op_idx in range(n_ops):
if not op_is_constant[op_idx]:
continue
start = int(op_param_indptr[op_idx])
end = int(op_param_indptr[op_idx + 1])
component = _evaluate_scheduled_op_1d(
energy,
int(op_kinds[op_idx]),
op_param_source_kinds[start:end],
op_param_indices[start:end],
param_values_init,
profile_sample_values_init,
profile_expr_values_init,
cached_peak_sum,
needs_spectrum=bool(op_needs_spectrum[op_idx]),
is_profiled=bool(op_is_profiled[op_idx]),
n_aux=n_aux,
)
cached_result += component
if op_is_pre_spectrum[op_idx]:
cached_peak_sum += component
# ------------------------------------------------------------------ #
# 9. Pack into ScheduledPlan1D #
# ------------------------------------------------------------------ #
return ScheduledPlan1D(
energy=energy,
n_params=n_params,
param_values_init=param_values_init,
opt_indices=opt_indices,
opt_param_names=opt_param_names,
n_expressions=n_expressions,
expr_target_indices=expr_target_indices,
expr_programs=expr_programs,
n_aux=n_aux,
aux_axis=plan_aux_axis,
n_profile_samples=n_profile_samples,
profile_sample_base_indices=profile_sample_base_indices,
profile_sample_component_indptr=profile_sample_component_indptr,
profile_component_func_ids=profile_component_func_ids,
profile_component_param_indptr=profile_component_param_indptr,
profile_component_param_indices=profile_component_param_indices,
n_profile_exprs=n_profile_exprs,
profile_expr_programs=profile_expr_programs,
n_ops=n_ops,
op_kinds=op_kinds,
op_param_indptr=op_param_indptr,
op_param_source_kinds=op_param_source_kinds,
op_param_indices=op_param_indices,
op_needs_spectrum=op_needs_spectrum,
op_is_pre_spectrum=op_is_pre_spectrum,
op_is_profiled=op_is_profiled,
op_is_constant=op_is_constant,
cached_result=cached_result,
cached_peak_sum=cached_peak_sum,
)