# Source code for trspecfit.graph_ir

"""Graph-based intermediate representation for model fitting.

The GraphIR is a DAG (directed acyclic graph) of typed nodes connected
by explicit dependency edges.  It captures the full semantics of a
model without prescribing an evaluation strategy.

Three-layer design:

1. **OOP tree** -- the user-facing Model/Component/Par objects.
   Handles parsing, validation, and user interaction; the graph
   builder only reads it.
2. **GraphIR** -- a directed acyclic graph of typed nodes with explicit
   data-dependency edges.  Axis-agnostic: works for 1D and 2D models.
3. **ScheduledPlan1D / ScheduledPlan2D** -- flat, packed-array execution
   schedules compiled from the graph (1D for ``ENERGY_1D`` models, 2D for
   energy-time models).  The hot path works on packed arrays rather than
   Python objects, strings, or dicts; compiled expression programs are
   the one exception.
"""

from __future__ import annotations

import re
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntEnum
from typing import TYPE_CHECKING, Any

import numpy as np

from trspecfit.functions import energy as fcts_energy
from trspecfit.functions import profile as fcts_profile

if TYPE_CHECKING:
    from trspecfit.mcp import Component, Model, Par


#
#
class NodeKind(IntEnum):
    """Kind of a node in the model graph.

    Codes 0-7 cover leaf parameters, computed parameters, component
    evaluation, and combination.  Codes from 100 upward (convolution,
    profile, and subcycle nodes) are representable in the graph but not
    compilable by the v1 backend.
    """

    # Leaf parameters
    STATIC_PARAM = 0
    OPT_PARAM = 1

    # Computed parameters
    DYNAMICS_TRACE = 2
    PARAM_PLUS_TRACE = 3
    EXPRESSION = 4

    # Component evaluation
    COMPONENT_EVAL = 5

    # Reduction / combination
    SUM = 6
    SPECTRUM_FED_OP = 7

    # Convolution and profile nodes (representable, not v1-compilable)
    CONVOLUTION = 100
    PROFILE_SAMPLE = 101
    PROFILE_AVERAGE = 102
    SUBCYCLE_MASK = 103
    SUBCYCLE_REMAP = 104
# #
class EdgeKind(IntEnum):
    """Kind of a dependency edge in the model graph."""

    # param -> consumer; GraphEdge.position gives the argument slot
    PARAM_INPUT = 0
    # trace -> PARAM_PLUS_TRACE, or resolved value -> CONVOLUTION
    TRACE_INPUT = 1
    # base parameter -> PARAM_PLUS_TRACE
    BASE_INPUT = 2
    # summand of an aggregating node (e.g. per-sample evals -> PROFILE_AVERAGE)
    ADDEND = 3
    # presumably spectrum -> spectrum-fed op; wired outside this chunk
    SPECTRUM_INPUT = 4
    # referenced parameter -> EXPRESSION
    EXPR_REF = 5
# #
class DomainKind(IntEnum):
    """Which axes a model lives on.

    - ``ENERGY_1D``: energy axis only.
    - ``TIME_1D``: time axis only.
    - ``ENERGY_TIME_2D``: both energy and time axes.
    """

    ENERGY_1D = 0
    TIME_1D = 1
    ENERGY_TIME_2D = 2
# #
class OpKind(IntEnum):
    """Lowered op codes for energy-axis component functions.

    Backend-specific enum: ``schedule_2d`` maps ``GraphNode.function_name``
    to ``OpKind`` during compilation, and the scheduled plans store these
    codes in their ``op_kinds`` arrays.  Codes 10+ are background-type
    functions.
    """

    # Peak shapes
    GAUSS = 0
    GAUSS_ASYM = 1
    LORENTZ = 2
    VOIGT = 3
    GLS = 4
    GLP = 5
    DS = 6

    # Backgrounds
    OFFSET = 10
    LINBACK = 11
    SHIRLEY = 12
# #
class DynFuncKind(IntEnum):
    """Lowered op codes for dynamics (time-axis) functions.

    Backend-specific enum: ``schedule_2d`` maps the ``function_name`` of
    ``DYNAMICS_TRACE`` nodes to these codes during compilation.
    """

    EXPFUN = 0
    SINFUN = 1
    LINFUN = 2
    SINDIVX = 3
    ERFFUN = 4
    SQRTFUN = 5
# #
class ProfileFuncKind(IntEnum):
    """Lowered op codes for profile functions evaluated along the aux axis."""

    PEXPDECAY = 0
    PLINEAR = 1
    PGAUSS = 2
# #
class ParamSourceKind(IntEnum):
    """Where a scheduled op reads one of its parameters from (1D and 2D plans)."""

    SCALAR = 0  # plain (possibly time-resolved) parameter row / index
    PROFILE_SAMPLE = 1  # per-aux-sample profile value
    PROFILE_EXPR = 2  # per-aux-sample expression value
#: Dynamics function name (``DYNAMICS_TRACE.function_name``) -> op code.
_FUNCTION_NAME_TO_DYN_FUNC: dict[str, DynFuncKind] = {
    "expFun": DynFuncKind.EXPFUN,
    "sinFun": DynFuncKind.SINFUN,
    "linFun": DynFuncKind.LINFUN,
    "sinDivX": DynFuncKind.SINDIVX,
    "erfFun": DynFuncKind.ERFFUN,
    "sqrtFun": DynFuncKind.SQRTFUN,
}

#: Profile function name -> op code.
_FUNCTION_NAME_TO_PROFILE_FUNC: dict[str, ProfileFuncKind] = {
    "pExpDecay": ProfileFuncKind.PEXPDECAY,
    "pLinear": ProfileFuncKind.PLINEAR,
    "pGauss": ProfileFuncKind.PGAUSS,
}
class ConvKernelKind(IntEnum):
    """Registry ids for convolution kernel functions.

    Only kernels listed here can be lowered; any other kernel (or a
    kernel from a non-``time`` package) falls back to MCP evaluation.
    """

    GAUSSCONV = 0
    LORENTZCONV = 1
    VOIGTCONV = 2
    EXPSYMCONV = 3
    EXPDECAYCONV = 4
    EXPRISECONV = 5
    BOXCONV = 6
#: Convolution kernel function name -> registry id.
_FUNCTION_NAME_TO_CONV_KERNEL: dict[str, ConvKernelKind] = {
    "gaussCONV": ConvKernelKind.GAUSSCONV,
    "lorentzCONV": ConvKernelKind.LORENTZCONV,
    "voigtCONV": ConvKernelKind.VOIGTCONV,
    "expSymCONV": ConvKernelKind.EXPSYMCONV,
    "expDecayCONV": ConvKernelKind.EXPDECAYCONV,
    "expRiseCONV": ConvKernelKind.EXPRISECONV,
    "boxCONV": ConvKernelKind.BOXCONV,
}

#: Energy component function name -> op code.
_FUNCTION_NAME_TO_OP: dict[str, OpKind] = {
    "Gauss": OpKind.GAUSS,
    "GaussAsym": OpKind.GAUSS_ASYM,
    "Lorentz": OpKind.LORENTZ,
    "Voigt": OpKind.VOIGT,
    "GLS": OpKind.GLS,
    "GLP": OpKind.GLP,
    "DS": OpKind.DS,
    "Offset": OpKind.OFFSET,
    "LinBack": OpKind.LINBACK,
    "Shirley": OpKind.SHIRLEY,
}

#: Maps ``OpKind`` → ``(eval_function, needs_spectrum)``.
#: Single source of truth for component dispatch -- used by both the
#: evaluator hot path and constant-component precomputation.
#: Only Shirley consumes the running spectrum (``needs_spectrum=True``).
OP_DISPATCH: dict[int, tuple[Callable[..., Any], bool]] = {
    int(OpKind.GAUSS): (fcts_energy.Gauss, False),
    int(OpKind.GAUSS_ASYM): (fcts_energy.GaussAsym, False),
    int(OpKind.LORENTZ): (fcts_energy.Lorentz, False),
    int(OpKind.VOIGT): (fcts_energy.Voigt, False),
    int(OpKind.GLS): (fcts_energy.GLS, False),
    int(OpKind.GLP): (fcts_energy.GLP, False),
    int(OpKind.DS): (fcts_energy.DS, False),
    int(OpKind.OFFSET): (fcts_energy.Offset, False),
    int(OpKind.LINBACK): (fcts_energy.LinBack, False),
    int(OpKind.SHIRLEY): (fcts_energy.Shirley, True),
}

#: Maps ``ProfileFuncKind`` → profile evaluation function.
PROFILE_DISPATCH: dict[int, Callable[..., Any]] = {
    int(ProfileFuncKind.PEXPDECAY): fcts_profile.pExpDecay,
    int(ProfileFuncKind.PLINEAR): fcts_profile.pLinear,
    int(ProfileFuncKind.PGAUSS): fcts_profile.pGauss,
}
class ExprNodeKind(IntEnum):
    """Instruction opcodes of a compiled RPN expression program."""

    # Operands
    CONST = 0
    PARAM_REF = 1
    # Binary operators
    ADD = 2
    SUB = 3
    MUL = 4
    DIV = 5
    # Unary negation
    NEG = 6
    # Binary power
    POW = 7
# #
@dataclass
class GraphNode:
    """A single typed node of a ``GraphIR``.

    Which payload fields are meaningful depends on ``kind``.
    """

    id: int  # dense id; equals the node's index in GraphIR.nodes
    kind: NodeKind
    name: str  # key into GraphIR.node_by_name
    source_order: int  # creation order while the graph was built
    # --- Payload (interpretation depends on kind) ---
    value: float | None = None  # initial scalar for parameter/expression nodes
    function_name: str | None = None  # e.g. "Gauss", "expFun", "gaussCONV"
    package: str | None = None  # short package label, e.g. "energy" or "time"
    expr_string: str | None = None  # expression text for EXPRESSION nodes
    vary: bool = False  # True for OPT_PARAM nodes
    bounds: tuple[float, float] | None = None  # (min, max) for OPT_PARAM nodes
    # named auxiliary arrays, e.g. "kernel_time" or "aux_axis"
    arrays: dict[str, np.ndarray] = field(default_factory=dict)
# #
@dataclass
class GraphEdge:
    """A directed dependency ``source -> target`` between two graph nodes."""

    source: int  # id of the producing node
    target: int  # id of the consuming node
    kind: EdgeKind
    position: int | None = None  # argument slot (set for PARAM_INPUT edges)
# #
@dataclass
class GraphIR:
    """Directed acyclic graph representing a model.

    Axis-agnostic: works for 1D and 2D models.  A 1D energy model has
    ``time=None``; adding dynamics populates ``time`` and promotes
    ``domain`` to ``ENERGY_TIME_2D``.
    """

    nodes: list[GraphNode]
    edges: list[GraphEdge]
    domain: DomainKind
    energy: np.ndarray | None = None
    time: np.ndarray | None = None
    node_by_name: dict[str, int] = field(default_factory=dict)

    def to_dot(self, *, collapse_profiles: bool = True) -> str:
        """Return a Graphviz DOT string for this graph.

        Node shapes and colours encode ``NodeKind``; edge labels encode
        ``EdgeKind``.  The output can be rendered with ``dot -Tpng`` or any
        Graphviz viewer.  Label text (names, function names, expression
        strings) is escaped for DOT double-quoted strings, so double quotes
        or backslashes in it cannot corrupt the output.

        Parameters
        ----------
        collapse_profiles : bool, default=True
            When True, per-sample profile nodes (``PROFILE_SAMPLE``,
            per-sample ``COMPONENT_EVAL``, per-sample ``EXPRESSION``) are
            collapsed into single representative nodes showing the sample
            count.  This keeps profile models readable.

        Returns
        -------
        str
            DOT source, one statement per line.
        """
        node_styles: dict[NodeKind, dict[str, str]] = {
            NodeKind.STATIC_PARAM: dict(
                shape="ellipse", style="filled", fillcolor="#d3d3d3"
            ),
            NodeKind.OPT_PARAM: dict(
                shape="ellipse", style="filled", fillcolor="#87ceeb"
            ),
            NodeKind.DYNAMICS_TRACE: dict(
                shape="box", style="filled", fillcolor="#ffa07a"
            ),
            NodeKind.PARAM_PLUS_TRACE: dict(
                shape="box", style="filled", fillcolor="#ffcc80"
            ),
            NodeKind.EXPRESSION: dict(
                shape="hexagon", style="filled", fillcolor="#dda0dd"
            ),
            NodeKind.COMPONENT_EVAL: dict(
                shape="box", style="filled,bold", fillcolor="#90ee90"
            ),
            NodeKind.SUM: dict(shape="diamond", style="filled", fillcolor="#fffacd"),
            NodeKind.SPECTRUM_FED_OP: dict(
                shape="box", style="filled,bold", fillcolor="#f08080"
            ),
            NodeKind.CONVOLUTION: dict(
                shape="octagon", style="filled", fillcolor="#e0e0ff"
            ),
            NodeKind.PROFILE_SAMPLE: dict(
                shape="parallelogram", style="filled", fillcolor="#c8e6c9"
            ),
            NodeKind.PROFILE_AVERAGE: dict(
                shape="parallelogram", style="filled", fillcolor="#a5d6a7"
            ),
            NodeKind.SUBCYCLE_MASK: dict(
                shape="trapezium", style="filled", fillcolor="#ffe0b2"
            ),
            NodeKind.SUBCYCLE_REMAP: dict(
                shape="trapezium", style="filled", fillcolor="#ffcc80"
            ),
        }
        edge_styles: dict[EdgeKind, dict[str, str]] = {
            EdgeKind.PARAM_INPUT: dict(color="#333333"),
            EdgeKind.TRACE_INPUT: dict(color="#ff6600", style="dashed"),
            EdgeKind.BASE_INPUT: dict(color="#0066cc", style="dashed"),
            EdgeKind.ADDEND: dict(color="#009933", style="bold"),
            EdgeKind.SPECTRUM_INPUT: dict(color="#cc0000", style="bold"),
            EdgeKind.EXPR_REF: dict(color="#9933cc", style="dotted"),
        }

        def quote(text: str) -> str:
            """Escape *text* for use inside a DOT double-quoted string."""
            # Backslash first, so the escapes added for quotes stay intact.
            return text.replace("\\", "\\\\").replace('"', '\\"')

        # --- Profile collapsing ---
        # Maps each per-sample node id to the representative node id for
        # its group.  Nodes not in this dict are emitted as-is.
        collapsed: dict[int, int] = {}  # sample_nid -> representative_nid
        # Groups: representative_nid -> (base_name, kind, count, first_node)
        profile_groups: dict[int, tuple[str, NodeKind, int, GraphNode]] = {}
        if collapse_profiles:
            sample_re = re.compile(
                r"^(.+?)_(profile_sample|profile_expr|sample)_(\d+)$"
            )
            # Group per-sample nodes by (base_name, kind).
            groups: dict[tuple[str, NodeKind], list[GraphNode]] = {}
            for node in self.nodes:
                m = sample_re.match(node.name)
                if m is not None:
                    groups.setdefault((m.group(1), node.kind), []).append(node)
            for (base, kind), members in groups.items():
                if len(members) < 2:
                    continue
                rep = members[0]
                for member in members:
                    collapsed[member.id] = rep.id
                profile_groups[rep.id] = (base, kind, len(members), rep)
        hidden_nodes = {nid for nid, rep_id in collapsed.items() if rep_id != nid}

        lines: list[str] = [
            "digraph ModelGraph {",
            "  rankdir=BT;",
            '  node [fontname="Helvetica", fontsize=10];',
            '  edge [fontname="Helvetica", fontsize=8];',
        ]

        for node in self.nodes:
            if node.id in hidden_nodes:
                continue
            attrs = dict(node_styles.get(node.kind, {}))
            if node.id in profile_groups:
                base, kind, count, _ = profile_groups[node.id]
                label_parts = [f"{base} (\u00d7{count})", kind.name]
            else:
                label_parts = [node.name, node.kind.name]
            if node.function_name:
                label_parts.append(f"fn={node.function_name}")
            if node.value is not None:
                label_parts.append(f"val={node.value:g}")
            if node.expr_string:
                label_parts.append(f"expr={node.expr_string}")
            # Escape each part, then join with the DOT "\n" line-break escape.
            attrs["label"] = "\\n".join(quote(part) for part in label_parts)
            attr_str = ", ".join(f'{k}="{v}"' for k, v in attrs.items())
            lines.append(f"  n{node.id} [{attr_str}];")

        # Collapsing maps many per-sample edges onto the same endpoints;
        # emit each (src, tgt, kind, position) combination only once.
        seen_edges: set[tuple[int, int, EdgeKind, int | None]] = set()
        for edge in self.edges:
            src = collapsed.get(edge.source, edge.source)
            tgt = collapsed.get(edge.target, edge.target)
            key = (src, tgt, edge.kind, edge.position)
            if key in seen_edges:
                continue
            seen_edges.add(key)
            attrs = dict(edge_styles.get(edge.kind, {}))
            label = edge.kind.name
            if edge.position is not None:
                label += f"[{edge.position}]"
            attrs["label"] = label
            attr_str = ", ".join(f'{k}="{v}"' for k, v in attrs.items())
            lines.append(f"  n{src} -> n{tgt} [{attr_str}];")

        lines.append("}")
        return "\n".join(lines)
# #
@dataclass(frozen=True)
class ExprProgram:
    """An expression compiled to a flat RPN program.

    ``instructions`` holds ``(node_kind, operand)`` pairs back to back:

    - ``CONST``: operand is the float's bit pattern
      (``np.float64.view(np.int64)``).
    - ``PARAM_REF``: operand is a row index into the trace matrix.
    - Operators: operand is unused (0).
    """

    instructions: np.ndarray  # (2 * n_instructions,) int64
# #
@dataclass(frozen=True)
class ScheduledPlan2D:
    """Compiled 2D (energy x time) execution schedule.

    Apart from the ``ExprProgram`` lists, every field is a scalar or a
    packed NumPy array.  Variable-length per-item data uses CSR-style
    ``*_indptr`` arrays of length ``n_items + 1``.
    """

    energy: np.ndarray  # (n_energy,)
    time: np.ndarray  # (n_time,)
    n_params: int
    n_time: int

    # --- Parameter mapping ---
    param_traces_init: np.ndarray  # (n_params, n_time)
    opt_indices: np.ndarray  # (n_opt,) int
    opt_param_names: list[str]  # (n_opt,) canonical optimizer param names

    # --- Dynamics subgraphs (grouped by target PARAM_PLUS_TRACE) ---
    # One dynamics group per time-dependent parameter.  Several dynamics
    # components aimed at the same PARAM_PLUS_TRACE (e.g. bi-exponential)
    # share a group; its substeps are addressed via dyn_group_indptr.
    n_dyn_groups: int
    dyn_group_target_row: np.ndarray  # (n_dyn_groups,) int
    dyn_group_base_row: np.ndarray  # (n_dyn_groups,) int
    dyn_group_indptr: np.ndarray  # (n_dyn_groups + 1,) int -- CSR into substep arrays
    dyn_sub_func_id: np.ndarray  # (n_substeps,) int
    dyn_sub_param_rows: np.ndarray  # (n_substeps, max_dyn_params) int, -1 padded
    dyn_sub_n_params: np.ndarray  # (n_substeps,) int
    # Per-substep subcycle data.  SUBCYCLE_REMAP / SUBCYCLE_MASK nodes are
    # compiled away: plain substeps get ``time`` and ``ones``, subcycle
    # substeps get copies of ``time_norm`` and ``time_n_sub``.  The
    # evaluator computes ``func(dyn_sub_time_axes[s], ...) * dyn_sub_masks[s]``.
    dyn_sub_time_axes: np.ndarray  # (n_substeps, n_time) float64
    dyn_sub_masks: np.ndarray  # (n_substeps, n_time) float64

    # --- Expression evaluation ---
    n_expressions: int
    expr_target_rows: np.ndarray  # (n_expressions,) int
    expr_programs: list[ExprProgram]

    # --- Interleaved parameter resolution schedule ---
    # Dynamics groups, expressions, and convolution steps can depend on one
    # another, so they run in one topologically ordered sequence:
    #   kind=0 -> dynamics group step (index into dyn_group_* arrays)
    #   kind=1 -> expression step (index into expr_* / expr_programs)
    #   kind=2 -> convolution step (index into conv_* arrays)
    resolution_kinds: np.ndarray  # (n_dyn_groups + n_expressions + n_conv_steps,) int8
    resolution_indices: np.ndarray  # (n_dyn_groups + n_expressions + n_conv_steps,) int

    # --- Resolved-trace convolution program ---
    # Each step rewrites one trace row in place once its PARAM_PLUS_TRACE
    # is fully resolved; chained CONVOLUTION nodes become consecutive steps
    # on the same row.  Kernel values are recomputed per theta from
    # conv_param_rows; support samples are frozen at plan-build time from
    # ``node.arrays["kernel_time"]``.  Only ``package == "time"`` kernels
    # are lowered.
    n_conv_steps: int
    conv_target_rows: np.ndarray  # (n_conv_steps,) int -- trace row rewritten
    conv_func_ids: np.ndarray  # (n_conv_steps,) int -- kernel function registry id
    conv_param_indptr: np.ndarray  # (n_conv_steps + 1,) int -- CSR row pointers
    conv_param_rows: np.ndarray  # (total_conv_params,) int -- kernel param trace rows
    conv_support_indptr: np.ndarray  # (n_conv_steps + 1,) int -- CSR into support values
    conv_support_values: np.ndarray  # (total_support,) float -- kernel time axis samples

    # --- Profile-varying parameter groups (fixed aux_axis shape) ---
    n_aux: int
    aux_axis: np.ndarray  # (n_aux,)
    n_profile_samples: int
    profile_sample_base_rows: np.ndarray  # (n_profile_samples,) int
    profile_sample_component_indptr: np.ndarray  # (n_profile_samples + 1,) int
    profile_component_func_ids: np.ndarray  # (n_profile_components,) int
    profile_component_param_indptr: np.ndarray  # (n_profile_components + 1,) int
    profile_component_param_rows: np.ndarray  # (total_profile_component_params,) int
    n_profile_exprs: int
    profile_expr_programs: list[ExprProgram]

    # --- Scheduled component ops ---
    n_ops: int
    op_schedule: np.ndarray  # (n_ops,) int
    op_kinds: np.ndarray  # (n_ops,) OpKind int codes
    op_param_indptr: np.ndarray  # (n_ops + 1,) int -- CSR row pointers
    op_param_source_kinds: np.ndarray  # (total_op_params,) ParamSourceKind int codes
    op_param_indices: np.ndarray  # (total_op_params,) int -- row/group indices
    op_needs_spectrum: np.ndarray  # (n_ops,) bool
    op_is_pre_spectrum: np.ndarray  # (n_ops,) bool
    op_is_profiled: np.ndarray  # (n_ops,) bool
    op_is_constant: np.ndarray  # (n_ops,) bool
    cached_result: np.ndarray  # (n_time, n_energy)
    cached_peak_sum: np.ndarray  # (n_time, n_energy)
# #
@dataclass(frozen=True)
class ScheduledPlan1D:
    """Compiled execution schedule for ``ENERGY_1D`` models.

    A reduced ``ScheduledPlan2D``: there is no time axis, no dynamics, and
    no trace matrix.  Parameters live in a flat ``(n_params,)`` vector of
    scalars.
    """

    energy: np.ndarray  # (n_energy,)
    n_params: int

    # --- Parameter mapping ---
    param_values_init: np.ndarray  # (n_params,) initial scalar values
    opt_indices: np.ndarray  # (n_opt,) int -- indices into param_values
    opt_param_names: list[str]  # (n_opt,) canonical optimizer param names

    # --- Expression evaluation (topological order, no dynamics) ---
    n_expressions: int
    expr_target_indices: np.ndarray  # (n_expressions,) int
    expr_programs: list[ExprProgram]

    # --- Profile-varying parameter groups (fixed aux_axis shape) ---
    n_aux: int
    aux_axis: np.ndarray  # (n_aux,)
    n_profile_samples: int
    profile_sample_base_indices: np.ndarray  # (n_profile_samples,) int
    profile_sample_component_indptr: np.ndarray  # (n_profile_samples + 1,) int
    profile_component_func_ids: np.ndarray  # (n_profile_components,) int
    profile_component_param_indptr: np.ndarray  # (n_profile_components + 1,) int
    profile_component_param_indices: np.ndarray  # (total_profile_component_params,) int
    n_profile_exprs: int
    profile_expr_programs: list[ExprProgram]

    # --- Scheduled component ops ---
    n_ops: int
    op_kinds: np.ndarray  # (n_ops,) OpKind int codes
    op_param_indptr: np.ndarray  # (n_ops + 1,) int -- CSR row pointers
    op_param_source_kinds: np.ndarray  # (total_op_params,) ParamSourceKind int codes
    op_param_indices: np.ndarray  # (total_op_params,) int -- source indices
    op_needs_spectrum: np.ndarray  # (n_ops,) bool
    op_is_pre_spectrum: np.ndarray  # (n_ops,) bool
    op_is_profiled: np.ndarray  # (n_ops,) bool
    op_is_constant: np.ndarray  # (n_ops,) bool
    cached_result: np.ndarray  # (n_energy,)
    cached_peak_sum: np.ndarray  # (n_energy,)
_PACKAGE_SHORT_NAMES: dict[str, str] = {
    "fcts_energy": "energy",
    "fcts_time": "time",
    "fcts_profile": "profile",
}


def _package_short_name(comp: Component) -> str:
    """Return the short package label (``"energy"``, ``"time"``, ``"profile"``).

    The last dotted segment of ``comp.package_name`` is looked up in
    ``_PACKAGE_SHORT_NAMES``; unknown modules fall back to that segment.
    """
    _, _, module = comp.package_name.rpartition(".")
    return _PACKAGE_SHORT_NAMES.get(module, module)


def _par_initial_value(par: Par) -> float:
    """Return the first value of ``par.lmfit_par.valuesdict()`` as a float (0.0 if empty)."""
    for current in par.lmfit_par.valuesdict().values():
        return float(current)
    return 0.0


def _par_bounds(par: Par) -> tuple[float, float] | None:
    """Return ``(min, max)`` of the first lmfit parameter, or None if there is none.

    A missing bound (``None``) becomes ``-inf`` / ``+inf``.
    """
    for lmfit_param in par.lmfit_par.values():
        lower = -np.inf if lmfit_param.min is None else lmfit_param.min
        upper = np.inf if lmfit_param.max is None else lmfit_param.max
        return float(lower), float(upper)
    return None


def _is_expression_par(par: Par) -> bool:
    """True when the Par is defined by a single expression string."""
    info = par.info
    return len(info) == 1 and isinstance(info[0], str)


def _par_expression_string(par: Par) -> str | None:
    """Return the expression string to use for graph wiring, or None.

    The lmfit expression is preferred because Dynamics models auto-prefix
    references there (e.g. ``expFun_01_A`` -> ``GLP_01_A_expFun_01_A`` or
    ``parTEST_expFun_01_A``).  Without one, the raw ``par.info[0]``
    expression is used, as for plain energy-model expressions.
    """
    if not _is_expression_par(par):
        return None
    lmfit_expr = next(
        (str(p.expr) for p in par.lmfit_par.values() if p.expr), None
    )
    return str(par.info[0]) if lmfit_expr is None else lmfit_expr


def _extract_expression_references(expr_string: str) -> list[str]:
    """Return the distinct identifier-like tokens of *expr_string*, in order of first use.

    Deliberately generic: function names such as ``exp`` are returned too.
    Callers filter semantically, e.g. by membership in ``resolved_param``
    or ``node_by_name``.
    """
    identifier = r"\b[A-Za-z_][A-Za-z0-9_]*\b"
    # dict.fromkeys keeps the first occurrence of each token, in order.
    return list(dict.fromkeys(re.findall(identifier, expr_string)))


def _component_param_names(comp: Component) -> list[str]:
    """Return function parameter names without the axis argument or trailing ``spectrum``."""
    names = comp.fct_args[1:]  # first argument is the axis (x or t)
    ends_with_spectrum = bool(names) and names[-1] == "spectrum"
    return names[:-1] if ends_with_spectrum else names


def _par_is_vary(par: Par) -> bool:
    """True if any of the Par's lmfit parameters has ``vary=True``."""
    if not par.lmfit_par.valuesdict():
        return False
    return any(p.vary for p in par.lmfit_par.values())


class _GraphBuilder:
    """Mutable accumulator of nodes and edges while a GraphIR is built.

    Until ``finalize`` runs, a node id always equals its index in
    ``nodes``, because ids are handed out sequentially as nodes are
    appended.
    """

    def __init__(self) -> None:
        self.nodes: list[GraphNode] = []
        self.edges: list[GraphEdge] = []
        self.node_by_name: dict[str, int] = {}
        self._next_id: int = 0
        self._source_order: int = 0
        self._removed: set[int] = set()

    def add_node(self, kind: NodeKind, name: str, **kwargs) -> int:
        """Append a new node and return its id.

        The node gets the next sequential id and source_order.  Its name is
        registered in ``node_by_name``; a later node with the same name
        replaces the earlier mapping.
        """
        new_id, order = self._next_id, self._source_order
        self._next_id = new_id + 1
        self._source_order = order + 1
        self.nodes.append(
            GraphNode(id=new_id, kind=kind, name=name, source_order=order, **kwargs)
        )
        self.node_by_name[name] = new_id
        return new_id

    def add_edge(
        self, source: int, target: int, kind: EdgeKind, *, position: int | None = None
    ) -> None:
        """Append an edge ``source -> target`` of the given kind."""
        edge = GraphEdge(source=source, target=target, kind=kind, position=position)
        self.edges.append(edge)

    def mark_removed(self, nid: int) -> None:
        """Schedule node *nid* (and every edge touching it) for removal in ``finalize``."""
        self._removed.add(nid)

    def finalize(self) -> tuple[list[GraphNode], list[GraphEdge], dict[str, int]]:
        """Drop removed nodes/edges and renumber so that node id == list position.

        Returns ``(nodes, edges, node_by_name)``.  Node and edge objects are
        renumbered in place.
        """
        removed = self._removed
        if not removed:
            return self.nodes, self.edges, dict(self.node_by_name)

        survivors = [node for node in self.nodes if node.id not in removed]
        surviving_edges = [
            edge
            for edge in self.edges
            if edge.source not in removed and edge.target not in removed
        ]

        # old id -> new dense id
        remap = {node.id: dense for dense, node in enumerate(survivors)}
        for node in survivors:
            node.id = remap[node.id]
        for edge in surviving_edges:
            edge.source = remap[edge.source]
            edge.target = remap[edge.target]

        return survivors, surviving_edges, {node.name: node.id for node in survivors}
def build_graph(model: Model) -> GraphIR:
    """Walk the OOP Model tree and emit a GraphIR.

    Parameters
    ----------
    model : Model
        A fully-constructed model (components, pars, and any dynamics /
        profile models already attached).

    Returns
    -------
    GraphIR
        The semantic DAG representation.  Node ids are dense and equal to
        each node's index in ``GraphIR.nodes``.
    """
    b = _GraphBuilder()

    # ----- determine domain -----
    has_energy = model.energy is not None
    has_time = model.time is not None
    if has_energy and has_time:
        domain = DomainKind.ENERGY_TIME_2D
    elif has_time:
        domain = DomainKind.TIME_1D
    else:
        domain = DomainKind.ENERGY_1D

    # ------------------------------------------------------------------ #
    # 1. Create parameter nodes for every component                      #
    # ------------------------------------------------------------------ #
    # Maps par.name -> node id of the *resolved* value (which might be a
    # PARAM_PLUS_TRACE node if time-dependent, or an EXPRESSION node).
    resolved_param: dict[str, int] = {}
    for comp in model.components:
        for par in comp.pars:
            _emit_par_nodes(b, par, resolved_param)

    # ------------------------------------------------------------------ #
    # 2. Create component nodes and wire PARAM_INPUT edges               #
    # ------------------------------------------------------------------ #
    # comp_nodes: list of (comp, node_id, is_spectrum_fed)
    comp_nodes: list[tuple[Component, int, bool]] = []
    for comp in model.components:
        if comp.comp_type == "none":
            continue
        pkg_name = _package_short_name(comp)
        is_shirley = comp.fct_str == "Shirley"

        if comp.comp_type == "conv":
            nid = b.add_node(
                NodeKind.CONVOLUTION,
                comp.comp_name,
                function_name=comp.fct_str,
                package=pkg_name,
            )
            # Keep the kernel time axis when the component has one.
            if comp.time is not None:
                b.nodes[nid].arrays["kernel_time"] = comp.time
        else:
            kind = NodeKind.SPECTRUM_FED_OP if is_shirley else NodeKind.COMPONENT_EVAL
            nid = b.add_node(
                kind,
                comp.comp_name,
                function_name=comp.fct_str,
                package=pkg_name,
            )

        # Params are matched to function arguments by position.
        for pos, _pname in enumerate(_component_param_names(comp)):
            src = resolved_param[comp.pars[pos].name]
            b.add_edge(src, nid, EdgeKind.PARAM_INPUT, position=pos)

        comp_nodes.append((comp, nid, is_shirley))

    # ------------------------------------------------------------------ #
    # 3. Emit PROFILE_*, SUBCYCLE_* nodes for components that use them   #
    # ------------------------------------------------------------------ #
    profiled_comps = _emit_profiles_for_components(b, comp_nodes, resolved_param)

    for comp in model.components:
        if comp.comp_type == "none":
            continue
        _emit_subcycle_nodes(b, comp, resolved_param)

    # ------------------------------------------------------------------ #
    # 4. Create expression edges                                         #
    # ------------------------------------------------------------------ #
    # Runs after profile nodes.  expr_refs_profile_dep params of components
    # that actually received profile treatment are already wired per
    # sample inside _emit_profile_nodes, so they are skipped here.  If a
    # component could not be profiled, its expressions are wired normally
    # so the EXPRESSION node is not left without inputs.
    for comp in model.components:
        comp_was_profiled = id(comp) in profiled_comps
        for par in comp.pars:
            if not _is_expression_par(par):
                continue
            if par.expr_refs_profile_dep and comp_was_profiled:
                continue  # handled per-sample in _emit_profile_nodes
            expr_nid = b.node_by_name[par.name]
            expr_str = _par_expression_string(par)
            if expr_str is None:
                continue
            for ref_name in _extract_expression_references(expr_str):
                if ref_name in resolved_param:
                    b.add_edge(resolved_param[ref_name], expr_nid, EdgeKind.EXPR_REF)

    # ------------------------------------------------------------------ #
    # 5. Create SUM / combination nodes                                  #
    # ------------------------------------------------------------------ #
    _emit_combination_nodes(b, comp_nodes)

    nodes, edges, node_by_name = b.finalize()
    return GraphIR(
        nodes=nodes,
        edges=edges,
        domain=domain,
        energy=model.energy,
        time=model.time,
        node_by_name=node_by_name,
    )


def _emit_profiles_for_components(
    b: _GraphBuilder,
    comp_nodes: list[tuple[Component, int, bool]],
    resolved_param: dict[str, int],
) -> set[int]:
    """Run ``_emit_profile_nodes`` for every component; return ``id()`` of profiled comps.

    ``comp_nodes`` is updated in place: a profiled component's node is
    replaced by its PROFILE_AVERAGE node, and the original node is marked
    for removal.

    ``profile_samples`` (par_name -> per-sample PROFILE_SAMPLE ids) is
    filled by ``p_vary`` components and read by ``expr_refs_profile_dep``
    components.  Sharing it across components keeps per-sample context, so
    expressions referencing a profiled param get per-sample EXPRESSION
    nodes rather than one averaged value.  Because it fills in component
    order, a dependent component listed *before* the component it
    references cannot infer its sample count on the first pass.  Such
    components are retried once after every component has been visited.
    """
    profile_samples: dict[str, list[int]] = {}
    profiled: set[int] = set()

    def try_profile(i: int) -> bool:
        """Profile comp_nodes[i]; return True if it was replaced."""
        comp, nid, is_shirley = comp_nodes[i]
        new_nid = _emit_profile_nodes(b, comp, nid, resolved_param, profile_samples)
        if new_nid == nid:
            # _emit_profile_nodes returns the original id only before it
            # creates any node, so retrying later has no leftover state.
            return False
        b.mark_removed(nid)
        comp_nodes[i] = (comp, new_nid, is_shirley)
        profiled.add(id(comp))
        return True

    deferred: list[int] = []
    for i, (comp, _nid, _is_shirley) in enumerate(comp_nodes):
        if not try_profile(i) and any(p.expr_refs_profile_dep for p in comp.pars):
            deferred.append(i)
    # Only p_vary components populate profile_samples, and none of the
    # deferred ones did, so a single retry pass is enough.
    for i in deferred:
        try_profile(i)
    return profiled
# def _emit_par_nodes( b: _GraphBuilder, par: Par, resolved_param: dict[str, int], ) -> None: """Create graph nodes for one Par (and its dynamics/profile subgraph).""" # Expression parameter if _is_expression_par(par): nid = b.add_node( NodeKind.EXPRESSION, par.name, expr_string=par.info[0], value=_par_initial_value(par), ) resolved_param[par.name] = nid return # Base parameter node if _par_is_vary(par): base_nid = b.add_node( NodeKind.OPT_PARAM, par.name, value=_par_initial_value(par), vary=True, bounds=_par_bounds(par), ) else: base_nid = b.add_node( NodeKind.STATIC_PARAM, par.name, value=_par_initial_value(par), ) # Default: resolved value is the base node itself resolved_param[par.name] = base_nid # Time-dependent parameter (Dynamics subgraph) if par.t_vary and par.t_model is not None: _emit_dynamics_subgraph(b, par, base_nid, resolved_param) # def _emit_dynamics_subgraph( b: _GraphBuilder, par: Par, base_nid: int, resolved_param: dict[str, int], ) -> None: """Emit DYNAMICS_TRACE + PARAM_PLUS_TRACE nodes for a time-dep par.""" t_model = par.t_model assert t_model is not None # Create parameter nodes for each dynamics component's parameters. # Dynamics pars can be expressions (e.g. multi-cycle subcycle models # where expFun_02_A = "-expFun_01_A"). dyn_param_nids: list[list[int]] = [] dyn_expr_pars: list[tuple[Par, int]] = [] # (par, node_id) for deferred wiring for dyn_comp in t_model.components: if dyn_comp.comp_type == "none": dyn_param_nids.append([]) continue comp_nids: list[int] = [] for dyn_par in dyn_comp.pars: if _is_expression_par(dyn_par): # Store the canonical (lmfit-prefixed) expression, not # the raw YAML text, so expr_string matches the EXPR_REF # edges wired below. 
canonical_expr = _par_expression_string(dyn_par) or dyn_par.info[0] dnid = b.add_node( NodeKind.EXPRESSION, dyn_par.name, expr_string=canonical_expr, value=_par_initial_value(dyn_par), ) dyn_expr_pars.append((dyn_par, dnid)) elif _par_is_vary(dyn_par): dnid = b.add_node( NodeKind.OPT_PARAM, dyn_par.name, value=_par_initial_value(dyn_par), vary=True, bounds=_par_bounds(dyn_par), ) else: dnid = b.add_node( NodeKind.STATIC_PARAM, dyn_par.name, value=_par_initial_value(dyn_par), ) comp_nids.append(dnid) dyn_param_nids.append(comp_nids) # Wire EXPR_REF edges for dynamics expression pars. # Dynamics expressions may be auto-prefixed by lmfit (e.g. # "expFun_01_A" -> "GLP_01_A_expFun_01_A" or "parTEST_expFun_01_A"). # Use the canonical expression string and then match identifier refs # against graph node names. for dyn_par, expr_nid in dyn_expr_pars: expr_str = _par_expression_string(dyn_par) if expr_str is None: continue refs = _extract_expression_references(expr_str) for ref_name in refs: if ref_name in b.node_by_name: b.add_edge(b.node_by_name[ref_name], expr_nid, EdgeKind.EXPR_REF) # Create DYNAMICS_TRACE or CONVOLUTION nodes per dynamics component. # Convolution components (gaussCONV, etc.) are CONVOLUTION nodes that # wrap the resolved trace (conv(trace, kernel)), not addends. 
trace_nids: list[int] = [] conv_nids: list[int] = [] for i, dyn_comp in enumerate(t_model.components): if dyn_comp.comp_type == "none": continue node_name = f"{par.name}_dynamics" if len(t_model.components) > 1: node_name = f"{par.name}_{dyn_comp.comp_name}_dynamics" if dyn_comp.comp_type == "conv": nid = b.add_node( NodeKind.CONVOLUTION, node_name, function_name=dyn_comp.fct_str, package="time", ) if dyn_comp.time is not None: b.nodes[nid].arrays["kernel_time"] = dyn_comp.time # Wire dynamics params -> conv node for pos, dnid in enumerate(dyn_param_nids[i]): b.add_edge(dnid, nid, EdgeKind.PARAM_INPUT, position=pos) conv_nids.append(nid) else: nid = b.add_node( NodeKind.DYNAMICS_TRACE, node_name, function_name=dyn_comp.fct_str, package="time", ) # Wire dynamics params -> trace node for pos, dnid in enumerate(dyn_param_nids[i]): b.add_edge(dnid, nid, EdgeKind.PARAM_INPUT, position=pos) trace_nids.append(nid) # Create PARAM_PLUS_TRACE node (base + sum of traces) resolved_name = f"{par.name}_resolved" resolved_nid = b.add_node( NodeKind.PARAM_PLUS_TRACE, resolved_name, ) b.add_edge(base_nid, resolved_nid, EdgeKind.BASE_INPUT) for trace_nid in trace_nids: b.add_edge(trace_nid, resolved_nid, EdgeKind.TRACE_INPUT) # Convolution nodes wrap the resolved trace: conv(resolved, kernel). # Each CONVOLUTION takes TRACE_INPUT from the current resolved node # and produces the new resolved value. 
for conv_nid in conv_nids: b.add_edge(resolved_nid, conv_nid, EdgeKind.TRACE_INPUT) resolved_nid = conv_nid resolved_param[par.name] = resolved_nid # def _rewire_param_input( b: _GraphBuilder, target: int, *, old_source: int, new_source: int ) -> None: """Replace the source of a PARAM_INPUT edge targeting *target*.""" for edge in b.edges: if ( edge.target == target and edge.kind == EdgeKind.PARAM_INPUT and edge.source == old_source ): edge.source = new_source return # def _rewire_trace_input( b: _GraphBuilder, target: int, *, old_source: int, new_source: int ) -> None: """Replace the source of a TRACE_INPUT edge targeting *target*.""" for edge in b.edges: if ( edge.target == target and edge.kind == EdgeKind.TRACE_INPUT and edge.source == old_source ): edge.source = new_source return # def _emit_profile_nodes( b: _GraphBuilder, comp: Component, comp_nid: int, resolved_param: dict[str, int], profile_samples: dict[str, list[int]], ) -> int: """Emit per-sample evaluation subgraph for profiled components. A component needs profile treatment when any of its params is ``p_vary`` (directly profiled) or ``expr_refs_profile_dep`` (an expression referencing a profiled param on another component). The interpreter evaluates the full component at each aux-axis point and averages the resulting traces: ``mean_i(f(p_0, ..., p_i, ...))``. Graph structure per aux point *i*: - For ``p_vary`` params: a PROFILE_SAMPLE node computes the parameter value at aux point *i* from the base value and the profile function. - For ``expr_refs_profile_dep`` params: a per-sample EXPRESSION node evaluates the expression using the PROFILE_SAMPLE of the referenced profiled param at the same aux point. - A per-sample COMPONENT_EVAL evaluates the component function with these per-sample inputs. After all aux points: - A component-level PROFILE_AVERAGE averages the per-sample COMPONENT_EVAL traces and replaces the original component in the combination graph. 
``profile_samples`` is shared across components so that ``expr_refs_profile_dep`` params on one component can reference the PROFILE_SAMPLE nodes created for a ``p_vary`` param on a different component. Returns ------- int Node id that replaces *comp_nid* in the combination graph, or *comp_nid* unchanged if no profile treatment is needed. """ p_vary_pars = [p for p in comp.pars if p.p_vary and p.p_model is not None] expr_dep_pars = [p for p in comp.pars if p.expr_refs_profile_dep] if not p_vary_pars and not expr_dep_pars: return comp_nid # --- Determine aux_axis length --- n_aux: int | None = None if p_vary_pars: aux_axis = p_vary_pars[0].p_model.aux_axis # type: ignore[union-attr] if aux_axis is None: return comp_nid n_aux = len(aux_axis) if n_aux is None: # expr_dep_pars only — infer n_aux from an already-populated # profile_samples entry. for ref_name in _expr_dep_profile_refs(expr_dep_pars): if ref_name in profile_samples: n_aux = len(profile_samples[ref_name]) aux_axis = None # not needed for expr-dep-only components break if n_aux is None: return comp_nid pkg_name = _package_short_name(comp) # --- Profile function parameter nodes for p_vary pars --- prof_param_nids_by_par: dict[str, list[int]] = {} for par in p_vary_pars: p_model = par.p_model assert p_model is not None prof_param_nids: list[int] = [] prof_resolved: dict[str, int] = {} for prof_comp in p_model.components: if prof_comp.comp_type == "none": continue for prof_par in prof_comp.pars: _emit_par_nodes(b, prof_par, prof_resolved) prof_param_nids.append(prof_resolved[prof_par.name]) prof_param_nids_by_par[par.name] = prof_param_nids # --- Per-sample nodes --- sample_comp_nids: list[int] = [] for aux_i in range(n_aux): # PROFILE_SAMPLE for each p_vary param sample_nids: dict[str, int] = {} # par.name -> sample nid for par in p_vary_pars: base_nid = resolved_param[par.name] p_aux = p_vary_pars[0].p_model.aux_axis # type: ignore[union-attr] sample_nid = b.add_node( NodeKind.PROFILE_SAMPLE, 
f"{par.name}_profile_sample_{aux_i}", arrays={"aux_axis": p_aux}, ) b.add_edge(base_nid, sample_nid, EdgeKind.PARAM_INPUT, position=0) for pi, pnid in enumerate(prof_param_nids_by_par[par.name]): b.add_edge(pnid, sample_nid, EdgeKind.PARAM_INPUT, position=pi + 1) sample_nids[par.name] = sample_nid profile_samples.setdefault(par.name, []).append(sample_nid) # Per-sample EXPRESSION for each expr_refs_profile_dep param expr_sample_nids: dict[str, int] = {} # par.name -> expr nid for par in expr_dep_pars: expr_str = _par_expression_string(par) if expr_str is None: continue expr_nid = b.add_node( NodeKind.EXPRESSION, f"{par.name}_profile_expr_{aux_i}", expr_string=expr_str, value=_par_initial_value(par), ) # Wire EXPR_REF to the per-sample PROFILE_SAMPLE of the # referenced profiled param (not the averaged value). refs = _extract_expression_references(expr_str) for ref_name in refs: if ref_name in profile_samples: b.add_edge( profile_samples[ref_name][aux_i], expr_nid, EdgeKind.EXPR_REF, ) elif ref_name in resolved_param: b.add_edge( resolved_param[ref_name], expr_nid, EdgeKind.EXPR_REF, ) expr_sample_nids[par.name] = expr_nid # Per-sample COMPONENT_EVAL sample_eval_nid = b.add_node( NodeKind.COMPONENT_EVAL, f"{comp.comp_name}_sample_{aux_i}", function_name=comp.fct_str, package=pkg_name, ) # Wire params into the sample component eval param_names = _component_param_names(comp) for pos, _pname in enumerate(param_names): par = comp.pars[pos] if par.name in sample_nids: src = sample_nids[par.name] elif par.name in expr_sample_nids: src = expr_sample_nids[par.name] else: src = resolved_param[par.name] b.add_edge(src, sample_eval_nid, EdgeKind.PARAM_INPUT, position=pos) sample_comp_nids.append(sample_eval_nid) # --- Component-level PROFILE_AVERAGE over sample traces --- comp_avg_nid = b.add_node( NodeKind.PROFILE_AVERAGE, f"{comp.comp_name}_profile_avg", ) for sc_nid in sample_comp_nids: b.add_edge(sc_nid, comp_avg_nid, EdgeKind.ADDEND) return comp_avg_nid # def 
_expr_dep_profile_refs(expr_dep_pars: list[Par]) -> list[str]: """Collect profiled param names referenced by expr_refs_profile_dep pars.""" refs: list[str] = [] seen: set[str] = set() for par in expr_dep_pars: expr_str = _par_expression_string(par) if expr_str is None: continue for ref in _extract_expression_references(expr_str): if ref not in seen: refs.append(ref) seen.add(ref) return refs # def _emit_subcycle_nodes( b: _GraphBuilder, comp: Component, resolved_param: dict[str, int], ) -> None: """Emit SUBCYCLE_REMAP and SUBCYCLE_MASK nodes for subcycle dynamics. SUBCYCLE_REMAP feeds into the DYNAMICS_TRACE (provides the remapped time axis). SUBCYCLE_MASK consumes the DYNAMICS_TRACE output (zeroes inactive regions). """ # Standalone TIME_1D dynamics model: subcycle info lives directly on the # time component itself rather than on a parent Par.t_model. comp_nid = b.node_by_name.get(comp.comp_name) if ( comp_nid is not None and comp.subcycle != 0 and comp.time_norm is not None and comp.time_n_sub is not None ): remap_nid = b.add_node( NodeKind.SUBCYCLE_REMAP, f"{comp.comp_name}_remap", arrays={"time_norm": comp.time_norm}, ) b.add_edge(remap_nid, comp_nid, EdgeKind.TRACE_INPUT) mask_nid = b.add_node( NodeKind.SUBCYCLE_MASK, f"{comp.comp_name}_mask", arrays={"time_n_sub": comp.time_n_sub}, ) b.add_edge(comp_nid, mask_nid, EdgeKind.TRACE_INPUT) for par in comp.pars: if not par.t_vary or par.t_model is None: continue t_model = par.t_model for dyn_comp in t_model.components: if dyn_comp.subcycle == 0: continue if dyn_comp.time_norm is None or dyn_comp.time_n_sub is None: continue # Find the DYNAMICS_TRACE node for this dynamics component trace_name = f"{par.name}_dynamics" if len(t_model.components) > 1: trace_name = f"{par.name}_{dyn_comp.comp_name}_dynamics" trace_nid = b.node_by_name.get(trace_name) if trace_nid is None: continue # SUBCYCLE_REMAP before the dynamics trace remap_nid = b.add_node( NodeKind.SUBCYCLE_REMAP, f"{par.name}_{dyn_comp.comp_name}_remap", 
arrays={"time_norm": dyn_comp.time_norm}, ) b.add_edge(remap_nid, trace_nid, EdgeKind.TRACE_INPUT) # SUBCYCLE_MASK after the dynamics trace mask_nid = b.add_node( NodeKind.SUBCYCLE_MASK, f"{par.name}_{dyn_comp.comp_name}_mask", arrays={"time_n_sub": dyn_comp.time_n_sub}, ) b.add_edge(trace_nid, mask_nid, EdgeKind.TRACE_INPUT) # Rewire PARAM_PLUS_TRACE to consume the masked trace # instead of the raw DYNAMICS_TRACE. resolved_nid = resolved_param.get(par.name) if resolved_nid is not None: _rewire_trace_input( b, resolved_nid, old_source=trace_nid, new_source=mask_nid ) # def _emit_combination_nodes( b: _GraphBuilder, comp_nodes: list[tuple[Component, int, bool]], ) -> None: """Create SUM nodes that mirror the model's LIFO combine semantics. Classification: - **peaks**: ``comp_type == "add"`` — feed ``peak_sum`` - **backgrounds**: ``comp_type == "back"`` but *not* spectrum-fed (Offset, LinBack) — feed ``total`` directly, *not* ``peak_sum`` - **spectrum-fed**: Shirley — receives ``SPECTRUM_INPUT`` from ``peak_sum``, feeds ``total`` - **convolution**: ``comp_type == "conv"`` — receives ``ADDEND`` from ``peak_sum``, feeds ``total`` This matches the spec example (lowered_evaluator.md lines 298-305): only peaks contribute to ``peak_sum``; Offset/LinBack are added at the ``total`` level. 
""" if not comp_nodes: return # Classify nodes peaks: list[int] = [] # comp_type == "add" backgrounds: list[int] = [] # Offset, LinBack (comp_type == "back", not Shirley) spectrum_fed: list[int] = [] # Shirley convolution: list[int] = [] # conv components for comp, nid, is_shirley in comp_nodes: if comp.comp_type == "conv": convolution.append(nid) elif is_shirley: spectrum_fed.append(nid) elif comp.comp_type == "add": peaks.append(nid) else: # comp_type == "back" but not Shirley -> Offset, LinBack backgrounds.append(nid) # peak_sum: accumulates only peak (comp_type == "add") components peak_sum_nid: int | None = None if peaks: peak_sum_nid = b.add_node(NodeKind.SUM, "peak_sum") for nid in peaks: b.add_edge(nid, peak_sum_nid, EdgeKind.ADDEND) # Wire SPECTRUM_INPUT edges from peak_sum to spectrum-fed ops if peak_sum_nid is not None: for nid in spectrum_fed: b.add_edge(peak_sum_nid, nid, EdgeKind.SPECTRUM_INPUT) # Convolution nodes get ADDEND edge from peak_sum (signal to convolve) if peak_sum_nid is not None: for nid in convolution: b.add_edge(peak_sum_nid, nid, EdgeKind.ADDEND) # total: final sum of everything # Collect all addends for total. Use peak_sum (not individual peaks) # so the graph reflects the semantic grouping. total_addends: list[int] = [] if peak_sum_nid is not None: total_addends.append(peak_sum_nid) total_addends.extend(backgrounds) total_addends.extend(spectrum_fed) total_addends.extend(convolution) if len(total_addends) > 1: total_nid = b.add_node(NodeKind.SUM, "total") for nid in total_addends: b.add_edge(nid, total_nid, EdgeKind.ADDEND) # --------------------------------------------------------------------------- # can_lower_2d # --------------------------------------------------------------------------- _LOWERABLE_2D_FUNCTIONS: frozenset[str] = frozenset( { "Gauss", "GaussAsym", "Lorentz", "Voigt", "GLS", "GLP", "DS", "Offset", "LinBack", "Shirley", } ) # Base set of non-lowerable node kinds shared across backends. 
# Backend-specific sets below carve out nodes as each backend gains # support (e.g. profile nodes for 1D and 2D). _NON_LOWERABLE_NODE_KINDS_BASE: frozenset[NodeKind] = frozenset( { NodeKind.CONVOLUTION, NodeKind.PROFILE_SAMPLE, NodeKind.PROFILE_AVERAGE, NodeKind.SUBCYCLE_MASK, NodeKind.SUBCYCLE_REMAP, } ) # 2D backend: profiles, resolved-trace time-domain convolution, and # subcycle dynamics are compilable. Convolution lowerability is # additionally gated per-node by ``_is_lowerable_convolution_2d``. # Subcycle substeps are compiled away into dyn_sub_time_axes / # dyn_sub_masks schedule arrays. _NON_LOWERABLE_2D_NODE_KINDS: frozenset[NodeKind] = ( _NON_LOWERABLE_NODE_KINDS_BASE - frozenset( { NodeKind.PROFILE_SAMPLE, NodeKind.PROFILE_AVERAGE, NodeKind.CONVOLUTION, NodeKind.SUBCYCLE_MASK, NodeKind.SUBCYCLE_REMAP, } ) ) # def _is_lowerable_convolution_2d(node: GraphNode, graph: GraphIR) -> bool: """Return True if a CONVOLUTION node can be lowered by the 2D backend. Lowering contract -- all of: 1. Time-domain kernel (``package == "time"``) with a registered kernel function and a ``kernel_time`` array populated at graph-build time. 2. Resolved-trace shape: the node has exactly one ``TRACE_INPUT`` ancestor, and walking that chain terminates at a ``PARAM_PLUS_TRACE``. The structural check (item 2) is the same walk the scheduler performs in :func:`_resolve_convolution_target_row`, kept in sync so that unsupported conv flavours (spectrum-level ``comp_type="conv"`` components that receive ``ADDEND`` from ``peak_sum``, externally produced graphs with unusual conv topology, etc.) fall back to MCP cleanly instead of raising inside ``schedule_2d``. """ if ( node.kind != NodeKind.CONVOLUTION or node.package != "time" or node.function_name not in _FUNCTION_NAME_TO_CONV_KERNEL or "kernel_time" not in node.arrays ): return False id_to_node = {n.id: n for n in graph.nodes} current = node # Bounded walk: guards against malformed chains and cycles. 
for _ in range(len(id_to_node)): trace_parents = [ e for e in graph.edges if e.target == current.id and e.kind == EdgeKind.TRACE_INPUT ] if len(trace_parents) != 1: return False parent = id_to_node[trace_parents[0].source] if parent.kind == NodeKind.PARAM_PLUS_TRACE: return True if parent.kind != NodeKind.CONVOLUTION: return False current = parent return False # def _is_lowerable_subcycle_2d(node: GraphNode, graph: GraphIR) -> bool: """Return True if a SUBCYCLE_REMAP / MASK node meets the lowering contract. The scheduler replaces per-substep ``graph.time`` / all-ones defaults with the payload carried by these nodes; once ``can_lower_2d`` vouches, ``schedule_2d`` reads the arrays unconditionally. All of: 1. The required payload array is present on the node and is a 1D ``float64``-compatible ndarray of length ``len(graph.time)``. 2. Canonical topology: SUBCYCLE_REMAP -> DYNAMICS_TRACE (TRACE_INPUT) DYNAMICS_TRACE -> SUBCYCLE_MASK (TRACE_INPUT, if present) SUBCYCLE_MASK -> PARAM_PLUS_TRACE (TRACE_INPUT) Anything off-pattern (multiple consumers, chain doesn't end at a PPT, etc.) is rejected so malformed graphs fall back to MCP cleanly instead of silently producing wrong output. """ if node.kind not in (NodeKind.SUBCYCLE_REMAP, NodeKind.SUBCYCLE_MASK): return False if graph.time is None: return False n_time = len(graph.time) key = "time_norm" if node.kind == NodeKind.SUBCYCLE_REMAP else "time_n_sub" arr = node.arrays.get(key) if arr is None or not isinstance(arr, np.ndarray): return False if arr.ndim != 1 or arr.shape[0] != n_time: return False id_to_node = {n.id: n for n in graph.nodes} if node.kind == NodeKind.SUBCYCLE_REMAP: outs = [ e for e in graph.edges if e.source == node.id and e.kind == EdgeKind.TRACE_INPUT ] if len(outs) != 1: return False return id_to_node[outs[0].target].kind == NodeKind.DYNAMICS_TRACE # SUBCYCLE_MASK: one TRACE_INPUT in from DYNAMICS_TRACE, one out to PPT. 
ins = [ e for e in graph.edges if e.target == node.id and e.kind == EdgeKind.TRACE_INPUT ] outs = [ e for e in graph.edges if e.source == node.id and e.kind == EdgeKind.TRACE_INPUT ] if len(ins) != 1 or len(outs) != 1: return False return ( id_to_node[ins[0].source].kind == NodeKind.DYNAMICS_TRACE and id_to_node[outs[0].target].kind == NodeKind.PARAM_PLUS_TRACE ) #
def can_lower_2d(graph: GraphIR) -> bool:
    """Report whether the 2D NumPy backend (``schedule_2d``) accepts *graph*.

    Only ``ENERGY_TIME_2D`` graphs qualify.  Every node must then pass the
    gate for its kind: blocklisted kinds are refused outright, convolution
    and subcycle nodes go through their dedicated structural checks,
    component / spectrum-fed / dynamics nodes must name a supported
    function, profile samples need an ``aux_axis`` array, and expressions
    must be plain arithmetic.

    Parameters
    ----------
    graph : GraphIR
        The model graph to inspect.

    Returns
    -------
    bool
        True when ``schedule_2d`` is able to compile the graph.
    """
    if graph.domain != DomainKind.ENERGY_TIME_2D:
        return False

    def _supported(node: GraphNode) -> bool:
        kind = node.kind
        if kind in _NON_LOWERABLE_2D_NODE_KINDS:
            return False
        if kind == NodeKind.CONVOLUTION:
            # Time-domain kernel + resolved-trace topology only.
            return _is_lowerable_convolution_2d(node, graph)
        if kind in (NodeKind.SUBCYCLE_REMAP, NodeKind.SUBCYCLE_MASK):
            # Payload arrays + REMAP -> DYNAMICS_TRACE -> MASK? -> PPT.
            return _is_lowerable_subcycle_2d(node, graph)
        if kind in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP):
            return node.function_name in _LOWERABLE_2D_FUNCTIONS
        if kind == NodeKind.DYNAMICS_TRACE:
            return node.function_name in _FUNCTION_NAME_TO_DYN_FUNC
        if kind == NodeKind.PROFILE_SAMPLE:
            return "aux_axis" in node.arrays
        if kind == NodeKind.EXPRESSION and node.expr_string is not None:
            # Cheap pre-filter; the expression compiler does the full check.
            return _is_arithmetic_expression(node.expr_string)
        return True

    # all() stops at the first unsupported node, in graph.nodes order.
    return all(_supported(node) for node in graph.nodes)
# Node kinds that are never valid in 1D energy models.  1D models have
# no time axis, so DYNAMICS_TRACE / PARAM_PLUS_TRACE should not appear.
# Start from the base blocklist shared across backends
# (``_NON_LOWERABLE_NODE_KINDS_BASE``, not the 2D blocklist) so future
# unsupported node kinds propagate automatically, carve out the profile
# nodes that 1D lowers explicitly, then add the time-axis-only kinds.
# The parentheses only make the evaluation order explicit: ``-`` already
# binds tighter than ``|``.
_NON_LOWERABLE_1D_NODE_KINDS: frozenset[NodeKind] = (
    _NON_LOWERABLE_NODE_KINDS_BASE
    - frozenset({NodeKind.PROFILE_SAMPLE, NodeKind.PROFILE_AVERAGE})
) | frozenset(
    {
        NodeKind.DYNAMICS_TRACE,
        NodeKind.PARAM_PLUS_TRACE,
    }
)
#
def can_lower_1d(graph: GraphIR) -> bool:
    """Report whether the 1D NumPy backend (``schedule_1d``) accepts *graph*.

    The graph must be an ``ENERGY_1D`` model, contain no node kind from the
    1D blocklist, use only supported component / spectrum-fed functions,
    give every profile sample an ``aux_axis`` array, and keep expressions
    arithmetic-only.

    Parameters
    ----------
    graph : GraphIR
        The model graph to inspect.

    Returns
    -------
    bool
        True when ``schedule_1d`` is able to compile the graph.
    """
    if graph.domain != DomainKind.ENERGY_1D:
        return False

    for node in graph.nodes:
        kind = node.kind
        if kind in _NON_LOWERABLE_1D_NODE_KINDS:
            return False
        if kind in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP):
            # 1D reuses the 2D component whitelist.
            ok = node.function_name in _LOWERABLE_2D_FUNCTIONS
        elif kind == NodeKind.PROFILE_SAMPLE:
            ok = "aux_axis" in node.arrays
        elif kind == NodeKind.EXPRESSION and node.expr_string is not None:
            ok = _is_arithmetic_expression(node.expr_string)
        else:
            ok = True
        if not ok:
            return False
    return True
#
def _is_arithmetic_expression(expr_string: str) -> bool:
    """Return True if the expression uses only arithmetic ops and param refs.

    Lightweight check via the ``ast`` module, kept in step with what
    ``compile_expr_symbolic`` accepts: binary ``+ - * / **``, unary
    ``+``/``-``, bare names, and numeric (``int``/``float``) constants.
    Rejects function calls, attribute access, subscripts, any other
    operator, non-numeric constants (strings, bytes, complex, ``None``,
    ``...``), and text that does not parse.
    """
    import ast

    allowed_nodes = (
        ast.Expression,
        ast.BinOp,
        ast.UnaryOp,
        ast.Constant,
        ast.Name,
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.Pow,
        ast.USub,
        ast.UAdd,
        ast.Load,
    )
    try:
        tree = ast.parse(expr_string, mode="eval")
    except (SyntaxError, ValueError):
        # Some Python versions raise ValueError (not SyntaxError) for
        # source containing NUL bytes.
        return False
    for node in ast.walk(tree):
        if not isinstance(node, allowed_nodes):
            return False
        # The compiler only accepts int/float constants; mirror that here
        # so the lowering gate never vouches for an uncompilable expression.
        if isinstance(node, ast.Constant) and not isinstance(
            node.value, (int, float)
        ):
            return False
    return True


# ---------------------------------------------------------------------------
# Symbolic expression compiler
# ---------------------------------------------------------------------------


#
#
@dataclass(eq=True, frozen=True)
class SymbolicRPN:
    """Name-based postfix program produced by the expression frontend.

    ``schedule_2d`` later maps the parameter names onto trace-matrix rows
    to obtain the final ``ExprProgram``.

    Every entry of ``instructions`` is an ``(ExprNodeKind, operand)``
    tuple whose operand depends on the opcode:

    - ``CONST``: the literal value as a float
    - ``PARAM_REF``: the referenced parameter name (str)
    - arithmetic operators: ``None``
    """

    # Postfix instruction stream (operands precede their operator).
    instructions: list[tuple[ExprNodeKind, float | str | None]]
    # Distinct parameter names, ordered by first appearance in the stream.
    referenced_names: list[str]
#
def compile_expr_symbolic(expr_string: str) -> SymbolicRPN:
    """Lower an arithmetic expression string to name-based RPN.

    The text is parsed with :mod:`ast` in ``eval`` mode and the tree is
    emitted in post-order, so each operator follows its operands.
    Identifiers are kept as ``PARAM_REF`` name strings; binding them to
    row indices is left to the scheduler.

    Parameters
    ----------
    expr_string : str
        Arithmetic expression such as ``"3/4*GLP_01_A"``.

    Returns
    -------
    SymbolicRPN
        The instruction list plus the distinct names it references, in
        order of first appearance.

    Raises
    ------
    ValueError
        If the expression contains unsupported AST nodes.
    """
    import ast

    tree = ast.parse(expr_string, mode="eval")

    program: list[tuple[ExprNodeKind, float | str | None]] = []
    first_seen: dict[str, None] = {}  # insertion-ordered set of names
    binary_kinds: dict[type, ExprNodeKind] = {
        ast.Add: ExprNodeKind.ADD,
        ast.Sub: ExprNodeKind.SUB,
        ast.Mult: ExprNodeKind.MUL,
        ast.Div: ExprNodeKind.DIV,
        ast.Pow: ExprNodeKind.POW,
    }

    def emit(node: ast.AST) -> None:
        if isinstance(node, ast.Expression):
            emit(node.body)
            return
        if isinstance(node, ast.Name):
            program.append((ExprNodeKind.PARAM_REF, node.id))
            first_seen.setdefault(node.id, None)
            return
        if isinstance(node, ast.Constant):
            if not isinstance(node.value, (int, float)):
                raise ValueError(
                    f"Unsupported constant {node.value!r}"
                    f" in expression: {expr_string!r}"
                )
            program.append((ExprNodeKind.CONST, float(node.value)))
            return
        if isinstance(node, ast.BinOp):
            # Operands first (post-order), then the operator itself.
            emit(node.left)
            emit(node.right)
            op_kind = binary_kinds.get(type(node.op))
            if op_kind is None:
                raise ValueError(
                    f"Unsupported binary op {type(node.op).__name__!r}"
                    f" in expression: {expr_string!r}"
                )
            program.append((op_kind, None))
            return
        if isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.USub):
                emit(node.operand)
                program.append((ExprNodeKind.NEG, None))
                return
            if isinstance(node.op, ast.UAdd):
                emit(node.operand)  # unary plus is the identity
                return
            raise ValueError(
                f"Unsupported unary op {type(node.op).__name__!r}"
                f" in expression: {expr_string!r}"
            )
        raise ValueError(
            f"Unsupported AST node {type(node).__name__!r}"
            f" in expression: {expr_string!r}"
        )

    emit(tree)
    return SymbolicRPN(
        instructions=program,
        referenced_names=list(first_seen),
    )
#
def _bind_expr_to_rows(
    symbolic: SymbolicRPN,
    name_to_row: dict[str, int],
) -> ExprProgram:
    """Resolve a symbolic RPN program's names to trace-matrix rows.

    Every instruction is flattened into two int64 slots, the opcode followed
    by its payload: for ``CONST`` the float64 bit pattern reinterpreted as
    int64, for ``PARAM_REF`` the row from *name_to_row*, otherwise ``0``.

    Parameters
    ----------
    symbolic : SymbolicRPN
        Program produced by ``compile_expr_symbolic``.
    name_to_row : dict[str, int]
        Parameter name -> trace-matrix row index.

    Returns
    -------
    ExprProgram
        Row-bound program for the evaluator.
    """
    flat: list[int] = []
    for kind, operand in symbolic.instructions:
        if kind == ExprNodeKind.CONST:
            assert isinstance(operand, (int, float))
            payload = int(np.float64(operand).view(np.int64))
        elif kind == ExprNodeKind.PARAM_REF:
            assert isinstance(operand, str)
            payload = name_to_row[operand]
        else:
            payload = 0
        flat.extend((int(kind), payload))
    return ExprProgram(instructions=np.array(flat, dtype=np.int64))


# ---------------------------------------------------------------------------
# schedule_2d
# ---------------------------------------------------------------------------


#
def _walk_convolution_to_param_plus_trace(
    conv_node: GraphNode,
    edges: list[GraphEdge],
    id_to_node: dict[int, GraphNode],
) -> GraphNode:
    """Follow TRACE_INPUT edges upstream from *conv_node* to its PPT node.

    Intermediate nodes must themselves be CONVOLUTION nodes.  Raises
    ``ValueError`` when a step lacks exactly one TRACE_INPUT parent, passes
    through any other node kind, or the walk exceeds the node count.
    """
    node = conv_node
    # Cap the walk at the node count so a cycle cannot spin forever.
    for _step in range(len(id_to_node)):
        incoming = [
            e
            for e in edges
            if e.target == node.id and e.kind == EdgeKind.TRACE_INPUT
        ]
        if len(incoming) != 1:
            raise ValueError(
                f"CONVOLUTION node {conv_node.name!r} does not have a"
                f" single TRACE_INPUT ancestor (found {len(incoming)})"
            )
        upstream = id_to_node[incoming[0].source]
        if upstream.kind == NodeKind.PARAM_PLUS_TRACE:
            return upstream
        if upstream.kind != NodeKind.CONVOLUTION:
            raise ValueError(
                f"CONVOLUTION chain for {conv_node.name!r} walks through"
                f" unsupported node kind {upstream.kind.name!r}"
            )
        node = upstream
    raise ValueError(f"CONVOLUTION chain for {conv_node.name!r} exceeds graph size")


#
def _resolve_convolution_target_row(
    conv_node: GraphNode,
    edges: list[GraphEdge],
    id_to_node: dict[int, GraphNode],
    name_to_row: dict[str, int],
) -> int:
    """Return the PARAM_PLUS_TRACE row that a conv chain rewrites.

    A CONVOLUTION's TRACE_INPUT comes either straight from the
    PARAM_PLUS_TRACE (single conv) or from an earlier CONVOLUTION (chained
    conv).  The lowered plan updates the PPT row in place, so every conv in
    one chain maps to the same row.
    """
    target = _walk_convolution_to_param_plus_trace(conv_node, edges, id_to_node)
    return name_to_row[target.name]


#
def _topological_sort(graph: GraphIR) -> list[int]:
    """Order node ids so every edge source precedes its target.

    Among nodes that are ready at the same time, the one with the smaller
    ``node.source_order`` goes first (node id breaks remaining ties), which
    keeps the schedule deterministic.  Node ids are looked up by value and
    need not match list positions.

    Raises
    ------
    ValueError
        If the graph contains a cycle.
    """
    import heapq

    by_id = {node.id: node for node in graph.nodes}
    pending = {node.id: 0 for node in graph.nodes}
    successors: dict[int, list[int]] = {node.id: [] for node in graph.nodes}
    for edge in graph.edges:
        successors[edge.source].append(edge.target)
        pending[edge.target] += 1

    # Min-heap keyed on (source_order, node_id).
    ready = [
        (node.source_order, node.id) for node in graph.nodes if pending[node.id] == 0
    ]
    heapq.heapify(ready)

    order: list[int] = []
    while ready:
        _, nid = heapq.heappop(ready)
        order.append(nid)
        for child in successors[nid]:
            pending[child] -= 1
            if pending[child] == 0:
                heapq.heappush(ready, (by_id[child].source_order, child))

    total = len(graph.nodes)
    if len(order) != total:
        raise ValueError(f"Graph has a cycle: sorted {len(order)} of {total} nodes")
    return order


#
[docs] def schedule_2d(graph: GraphIR) -> ScheduledPlan2D: """Compile a GraphIR into a flat 2D execution schedule. Parameters ---------- graph : GraphIR Must pass ``can_lower_2d(graph)``. Returns ------- ScheduledPlan2D Packed-array execution schedule for ``evaluate_2d``. Raises ------ ValueError If the graph cannot be lowered (domain, unsupported nodes, etc.). """ if not can_lower_2d(graph): raise ValueError("Graph cannot be lowered to 2D backend") assert graph.energy is not None assert graph.time is not None n_time = len(graph.time) # ------------------------------------------------------------------ # # 1. Topological sort # # ------------------------------------------------------------------ # topo_order = _topological_sort(graph) # Build id -> node lookup. Do NOT assume node.id == list index; # external graph producers may use arbitrary ids. id_to_node: dict[int, GraphNode] = {n.id: n for n in graph.nodes} # ------------------------------------------------------------------ # # 2. Assign trace-matrix rows # # ------------------------------------------------------------------ # # Nodes that occupy a row in the trace matrix: all parameter-like # nodes (STATIC_PARAM, OPT_PARAM, PARAM_PLUS_TRACE, EXPRESSION). 
_ROW_KINDS = frozenset( { NodeKind.STATIC_PARAM, NodeKind.OPT_PARAM, NodeKind.PARAM_PLUS_TRACE, NodeKind.EXPRESSION, } ) # Collect nodes that need rows, bucketed for ordering: # opt params first, then static, then computed (PPT + EXPR) opt_nodes: list[GraphNode] = [] static_nodes: list[GraphNode] = [] computed_nodes: list[GraphNode] = [] for nid in topo_order: node = id_to_node[nid] if node.kind not in _ROW_KINDS or _is_profile_expr_node(node): continue if node.kind == NodeKind.OPT_PARAM and node.vary: opt_nodes.append(node) elif node.kind in (NodeKind.STATIC_PARAM, NodeKind.OPT_PARAM): # OPT_PARAM with vary=False is treated like static static_nodes.append(node) else: computed_nodes.append(node) all_row_nodes = opt_nodes + static_nodes + computed_nodes n_params = len(all_row_nodes) name_to_row: dict[str, int] = {} for row, node in enumerate(all_row_nodes): name_to_row[node.name] = row row_is_constant = np.zeros(n_params, dtype=np.bool_) for node in static_nodes: row_is_constant[name_to_row[node.name]] = True # CONVOLUTION nodes rewrite a PARAM_PLUS_TRACE row in place, so they # alias their underlying PPT's row rather than occupying a new one. # Downstream edges (component PARAM_INPUT) reference the last node in # the conv chain; this aliasing lets name_to_row resolve them to the # PPT row. row_is_constant becomes False for conv-touched rows because # kernel params may vary per call. for node in graph.nodes: if node.kind != NodeKind.CONVOLUTION: continue ppt_row = _resolve_convolution_target_row( node, graph.edges, id_to_node, name_to_row ) name_to_row[node.name] = ppt_row row_is_constant[ppt_row] = False # opt_indices and opt_param_names n_opt = len(opt_nodes) opt_indices = np.arange(n_opt, dtype=np.intp) opt_param_names = [n.name for n in opt_nodes] # ------------------------------------------------------------------ # # 3. 
Compile dynamics subgraphs (grouped by PARAM_PLUS_TRACE) # # ------------------------------------------------------------------ # # First pass: extract per-DYNAMICS_TRACE info and find which PPT # each trace feeds. dyn_trace_nodes = [ id_to_node[nid] for nid in topo_order if id_to_node[nid].kind == NodeKind.DYNAMICS_TRACE ] # Per-trace (substep) info, indexed by position in dyn_trace_nodes. sub_func_ids: list[int] = [] sub_param_row_lists: list[list[int]] = [] sub_ppt_row: list[int] = [] # which PPT row this trace targets sub_base_row: list[int] = [] # Per-substep subcycle arrays. ``None`` marks "no subcycle" # (use defaults: ``graph.time`` axis, all-ones mask). sub_time_axis: list[np.ndarray | None] = [] sub_mask: list[np.ndarray | None] = [] for dyn_node in dyn_trace_nodes: assert dyn_node.function_name is not None func_kind = _FUNCTION_NAME_TO_DYN_FUNC.get(dyn_node.function_name) if func_kind is None: raise ValueError(f"Unknown dynamics function: {dyn_node.function_name!r}") sub_func_ids.append(int(func_kind)) # Param rows (from PARAM_INPUT edges, ordered by position) param_edges = sorted( ( e for e in graph.edges if e.target == dyn_node.id and e.kind == EdgeKind.PARAM_INPUT ), key=lambda e: e.position or 0, ) sub_param_row_lists.append( [name_to_row[id_to_node[e.source].name] for e in param_edges] ) # Subcycle REMAP upstream: incoming TRACE_INPUT from SUBCYCLE_REMAP # carries the ``time_norm`` axis that replaces ``graph.time`` for # this substep. Invariants (array present, (n_time,) shape) are # enforced by ``_is_lowerable_subcycle_2d`` before we get here. remap_time_axis: np.ndarray | None = None for e in graph.edges: if e.target == dyn_node.id and e.kind == EdgeKind.TRACE_INPUT: src = id_to_node[e.source] if src.kind == NodeKind.SUBCYCLE_REMAP: remap_time_axis = src.arrays["time_norm"] break sub_time_axis.append(remap_time_axis) # Target: the PARAM_PLUS_TRACE node this trace feeds. 
For # subcycle-bearing traces, a SUBCYCLE_MASK sits between the # DYNAMICS_TRACE and its PARAM_PLUS_TRACE consumer; walk past it # and record the mask array for the evaluator. ppt_edges = [ e for e in graph.edges if e.source == dyn_node.id and e.kind == EdgeKind.TRACE_INPUT ] assert len(ppt_edges) == 1, ( f"DYNAMICS_TRACE '{dyn_node.name}' must feed exactly one" f" downstream node, found {len(ppt_edges)}" ) downstream = id_to_node[ppt_edges[0].target] mask_array: np.ndarray | None = None if downstream.kind == NodeKind.SUBCYCLE_MASK: mask_array = downstream.arrays["time_n_sub"] mask_out = [ e for e in graph.edges if e.source == downstream.id and e.kind == EdgeKind.TRACE_INPUT ] assert len(mask_out) == 1, ( f"SUBCYCLE_MASK '{downstream.name}' must feed exactly one" f" PARAM_PLUS_TRACE, found {len(mask_out)}" ) ppt_node = id_to_node[mask_out[0].target] else: ppt_node = downstream sub_mask.append(mask_array) assert ppt_node.kind == NodeKind.PARAM_PLUS_TRACE, ( f"DYNAMICS_TRACE '{dyn_node.name}' chain does not terminate at" f" PARAM_PLUS_TRACE (got {ppt_node.kind.name})" ) sub_ppt_row.append(name_to_row[ppt_node.name]) # Base row: the BASE_INPUT to the PARAM_PLUS_TRACE base_edges = [ e for e in graph.edges if e.target == ppt_node.id and e.kind == EdgeKind.BASE_INPUT ] assert len(base_edges) == 1 sub_base_row.append(name_to_row[id_to_node[base_edges[0].source].name]) # Second pass: group substeps by target PPT row, preserving topo order. # Each unique PPT row becomes one dynamics group. 
seen_ppt: dict[int, int] = {} # ppt_row -> group index group_target_rows: list[int] = [] group_base_rows: list[int] = [] group_substeps: list[list[int]] = [] # group_idx -> [substep indices] for sub_idx, ppt_row in enumerate(sub_ppt_row): if ppt_row not in seen_ppt: gid = len(group_target_rows) seen_ppt[ppt_row] = gid group_target_rows.append(ppt_row) group_base_rows.append(sub_base_row[sub_idx]) group_substeps.append([]) group_substeps[seen_ppt[ppt_row]].append(sub_idx) n_dyn_groups = len(group_target_rows) n_substeps = len(dyn_trace_nodes) # Pack substep arrays (flat, ordered by group then topo within group) flat_sub_indices: list[int] = [] indptr: list[int] = [0] for subs in group_substeps: flat_sub_indices.extend(subs) indptr.append(indptr[-1] + len(subs)) max_dyn_params = ( max(len(r) for r in sub_param_row_lists) if sub_param_row_lists else 0 ) dyn_sub_param_rows = np.full((n_substeps, max_dyn_params), -1, dtype=np.intp) dyn_sub_n_params_list: list[int] = [] dyn_sub_func_id_list: list[int] = [] # Per-substep subcycle schedule arrays. Default: ``graph.time`` and # an all-ones mask. Overridden from upstream SUBCYCLE_REMAP / MASK # nodes. Float64 kept uniform so the evaluator hot loop has a single # dtype contract. 
dyn_sub_time_axes = np.empty((n_substeps, n_time), dtype=np.float64) dyn_sub_time_axes[:, :] = graph.time dyn_sub_masks = np.ones((n_substeps, n_time), dtype=np.float64) for flat_i, orig_i in enumerate(flat_sub_indices): rows = sub_param_row_lists[orig_i] dyn_sub_param_rows[flat_i, : len(rows)] = rows dyn_sub_n_params_list.append(len(rows)) dyn_sub_func_id_list.append(sub_func_ids[orig_i]) if sub_time_axis[orig_i] is not None: dyn_sub_time_axes[flat_i, :] = sub_time_axis[orig_i] if sub_mask[orig_i] is not None: dyn_sub_masks[flat_i, :] = sub_mask[orig_i] dyn_group_target_row = np.array(group_target_rows, dtype=np.intp) dyn_group_base_row = np.array(group_base_rows, dtype=np.intp) dyn_group_indptr = np.array(indptr, dtype=np.intp) dyn_sub_func_id = np.array(dyn_sub_func_id_list, dtype=np.intp) dyn_sub_n_params = np.array(dyn_sub_n_params_list, dtype=np.intp) # ------------------------------------------------------------------ # # 4. Compile expressions (topological order) # # ------------------------------------------------------------------ # expr_nodes_topo = [ id_to_node[nid] for nid in topo_order if id_to_node[nid].kind == NodeKind.EXPRESSION and not _is_profile_expr_node(id_to_node[nid]) ] n_expressions = len(expr_nodes_topo) # Build per-expression name→row maps from EXPR_REF edges. # The graph's EXPR_REF edges are the single source of truth for # which node each name in the expression resolves to. We walk # those edges to build a per-node override map, then compile the # symbolic RPN (for operator structure) and bind names to rows # using the edge-derived map. # # Index: expr_node.id -> {identifier_in_expr_string -> row} expr_ref_maps: dict[int, dict[str, int]] = {} for expr_node in expr_nodes_topo: ref_map: dict[str, int] = {} for edge in graph.edges: if edge.target != expr_node.id or edge.kind != EdgeKind.EXPR_REF: continue src_node = id_to_node[edge.source] src_row = name_to_row[src_node.name] # The expr_string references names that appear as identifiers. 
# The EXPR_REF source node's name is the canonical form (may # include "_resolved" suffix or prefixed dynamics names). # Walk the expression's identifier tokens to find which one # this edge satisfies. assert expr_node.expr_string is not None for token in _extract_expression_references(expr_node.expr_string): if token == src_node.name: ref_map[token] = src_row break else: # The identifier in the expression doesn't match the # source node name literally. This shouldn't happen if # build_graph stored the canonical expr_string, but # fall back: try matching any identifier that names a # parameter whose resolved row IS this source row. for token in _extract_expression_references(expr_node.expr_string): if token in name_to_row and name_to_row[token] == src_row: ref_map[token] = src_row break # Also try: source is a resolved node, token is the # base param name if token not in ref_map and src_node.name.endswith("_resolved"): base_name = src_node.name.removesuffix("_resolved") if token == base_name: ref_map[token] = src_row break expr_ref_maps[expr_node.id] = ref_map # Compile and bind expressions expr_programs: list[ExprProgram] = [] expr_target_rows_list: list[int] = [] for expr_node in expr_nodes_topo: assert expr_node.expr_string is not None symbolic = compile_expr_symbolic(expr_node.expr_string) # Build the binding map: start with the base name_to_row, then # overlay the edge-derived overrides for this expression. 
binding = dict(name_to_row) binding.update(expr_ref_maps.get(expr_node.id, {})) program = _bind_expr_to_rows(symbolic, binding) expr_programs.append(program) target_row = name_to_row[expr_node.name] expr_target_rows_list.append(target_row) row_is_constant[target_row] = all( row_is_constant[binding[name]] for name in symbolic.referenced_names ) expr_target_rows = np.array(expr_target_rows_list, dtype=np.intp) # Build resolution schedule: map DYNAMICS_TRACE node ids to their # group index, and emit each group exactly once (on the *last* trace # in that group, so all expression deps are resolved first). _dyn_id_to_ppt = {n.id: sub_ppt_row[i] for i, n in enumerate(dyn_trace_nodes)} _ppt_to_group = seen_ppt # ppt_row -> group index _expr_id_to_idx = {n.id: i for i, n in enumerate(expr_nodes_topo)} # Find which DYNAMICS_TRACE is the last in each group (in topo order). # We emit the group step at that point so all substep expressions are # resolved before the group evaluates. _last_dyn_in_group: dict[int, int] = {} # group_idx -> last dyn node id for nid in topo_order: if nid in _dyn_id_to_ppt: gid = _ppt_to_group[_dyn_id_to_ppt[nid]] _last_dyn_in_group[gid] = nid # Compile CONVOLUTION nodes into a conv program. Walk topo order # so chained conv steps stay in the right sequence and the # kind=2 resolution entries are emitted after the dyn/expr step that # populates the target PARAM_PLUS_TRACE row. 
conv_nodes_topo: list[GraphNode] = [ id_to_node[nid] for nid in topo_order if id_to_node[nid].kind == NodeKind.CONVOLUTION ] _conv_id_to_idx: dict[int, int] = {n.id: i for i, n in enumerate(conv_nodes_topo)} n_conv_steps = len(conv_nodes_topo) conv_target_rows_list: list[int] = [] conv_func_ids_list: list[int] = [] conv_param_indptr_list: list[int] = [0] conv_param_rows_list: list[int] = [] conv_support_indptr_list: list[int] = [0] conv_support_values_list: list[float] = [] for conv_node in conv_nodes_topo: assert conv_node.function_name is not None kernel_kind = _FUNCTION_NAME_TO_CONV_KERNEL.get(conv_node.function_name) if kernel_kind is None: raise ValueError(f"Unknown convolution kernel: {conv_node.function_name!r}") conv_func_ids_list.append(int(kernel_kind)) target_row = _resolve_convolution_target_row( conv_node, graph.edges, id_to_node, name_to_row ) conv_target_rows_list.append(target_row) kernel_param_edges = sorted( ( e for e in graph.edges if e.target == conv_node.id and e.kind == EdgeKind.PARAM_INPUT ), key=lambda e: e.position or 0, ) for edge in kernel_param_edges: src_node = id_to_node[edge.source] conv_param_rows_list.append(name_to_row[src_node.name]) conv_param_indptr_list.append(len(conv_param_rows_list)) support = conv_node.arrays.get("kernel_time") if support is None: raise ValueError( f"CONVOLUTION node {conv_node.name!r} is missing" " kernel_time array (required for lowered convolution)" ) conv_support_values_list.extend(float(v) for v in np.asarray(support)) conv_support_indptr_list.append(len(conv_support_values_list)) conv_target_rows = np.array(conv_target_rows_list, dtype=np.intp) conv_func_ids = np.array(conv_func_ids_list, dtype=np.intp) conv_param_indptr = np.array(conv_param_indptr_list, dtype=np.intp) conv_param_rows = np.array(conv_param_rows_list, dtype=np.intp) conv_support_indptr = np.array(conv_support_indptr_list, dtype=np.intp) conv_support_values = np.array(conv_support_values_list, dtype=np.float64) resolution_kinds_list: 
list[int] = [] resolution_indices_list: list[int] = [] emitted_groups: set[int] = set() for nid in topo_order: if nid in _dyn_id_to_ppt: gid = _ppt_to_group[_dyn_id_to_ppt[nid]] if _last_dyn_in_group[gid] == nid and gid not in emitted_groups: resolution_kinds_list.append(0) # dynamics group resolution_indices_list.append(gid) emitted_groups.add(gid) elif nid in _expr_id_to_idx: resolution_kinds_list.append(1) # expression resolution_indices_list.append(_expr_id_to_idx[nid]) elif nid in _conv_id_to_idx: resolution_kinds_list.append(2) # convolution step resolution_indices_list.append(_conv_id_to_idx[nid]) resolution_kinds = np.array(resolution_kinds_list, dtype=np.int8) resolution_indices = np.array(resolution_indices_list, dtype=np.intp) # ------------------------------------------------------------------ # # 4b. Compile PROFILE_SAMPLE groups # # ------------------------------------------------------------------ # # Build edge indexes for profile compilation (mirrors schedule_1d). param_edges_by_target: dict[int, list[GraphEdge]] = {} expr_ref_edges_by_target: dict[int, list[GraphEdge]] = {} addend_edges_by_target: dict[int, list[GraphEdge]] = {} spectrum_input_targets: set[int] = set() for edge in graph.edges: if edge.kind == EdgeKind.PARAM_INPUT: param_edges_by_target.setdefault(edge.target, []).append(edge) elif edge.kind == EdgeKind.EXPR_REF: expr_ref_edges_by_target.setdefault(edge.target, []).append(edge) elif edge.kind == EdgeKind.ADDEND: addend_edges_by_target.setdefault(edge.target, []).append(edge) elif edge.kind == EdgeKind.SPECTRUM_INPUT: spectrum_input_targets.add(edge.target) profile_sample_groups: dict[str, list[GraphNode]] = {} for nid in topo_order: node = id_to_node[nid] if node.kind != NodeKind.PROFILE_SAMPLE: continue group_name = _profile_group_name(node.name, "profile_sample") profile_sample_groups.setdefault(group_name, []).append(node) plan_aux_axis = np.zeros(0, dtype=np.float64) n_aux = 0 profile_sample_base_rows_list: list[int] = [] 
profile_sample_component_indptr_list: list[int] = [0] profile_component_func_ids_list: list[int] = [] profile_component_param_indptr_list: list[int] = [0] profile_component_param_rows_list: list[int] = [] profile_sample_is_constant_list: list[bool] = [] profile_sample_group_idx: dict[str, int] = {} for group_name, sample_nodes in profile_sample_groups.items(): sample_nodes_sorted = sorted( sample_nodes, key=lambda node: _profile_group_index(node.name, "profile_sample"), ) aux_indices = [ _profile_group_index(node.name, "profile_sample") for node in sample_nodes_sorted ] if aux_indices != list(range(len(sample_nodes_sorted))): raise ValueError( f"PROFILE_SAMPLE nodes for {group_name!r} do not cover " "a contiguous aux-axis range" ) aux_axis = sample_nodes_sorted[0].arrays.get("aux_axis") if aux_axis is None: raise ValueError(f"PROFILE_SAMPLE {group_name!r} is missing aux_axis") aux_axis = np.asarray(aux_axis, dtype=np.float64) if n_aux == 0: n_aux = len(aux_axis) plan_aux_axis = aux_axis.copy() elif len(aux_axis) != n_aux or not np.array_equal(aux_axis, plan_aux_axis): raise ValueError("All lowered profile groups must share one fixed aux_axis") if len(sample_nodes_sorted) != n_aux: raise ValueError( f"PROFILE_SAMPLE group {group_name!r} has " f"{len(sample_nodes_sorted)} samples but aux_axis length {n_aux}" ) rep_node = sample_nodes_sorted[0] rep_param_edges = sorted( param_edges_by_target.get(rep_node.id, []), key=lambda edge: edge.position or 0, ) if not rep_param_edges: raise ValueError(f"PROFILE_SAMPLE {group_name!r} has no PARAM_INPUT edges") base_node = id_to_node[rep_param_edges[0].source] base_row = name_to_row[base_node.name] is_constant = bool(row_is_constant[base_row]) profile_sample_base_rows_list.append(base_row) component_func_by_name: dict[str, int] = {} component_param_rows_by_name: dict[str, list[int]] = {} component_order: list[str] = [] for edge in rep_param_edges[1:]: src_node = id_to_node[edge.source] src_row = name_to_row[src_node.name] # 
PARAM_PLUS_TRACE nodes have a "_resolved" suffix; CONVOLUTION # nodes wrap a PPT for profile-time-dynamics and carry their own # "<par>_<kernel>_dynamics" name. In both cases, walk back to # the underlying PPT so profile component parsing sees the # original profile parameter name. parse_name = src_node.name if src_node.kind == NodeKind.PARAM_PLUS_TRACE: parse_name = parse_name.removesuffix("_resolved") elif src_node.kind == NodeKind.CONVOLUTION: ppt_node = _walk_convolution_to_param_plus_trace( src_node, graph.edges, id_to_node ) parse_name = ppt_node.name.removesuffix("_resolved") comp_name, func_name = _parse_profile_component_param_name( group_name, parse_name, ) prof_func_kind = _FUNCTION_NAME_TO_PROFILE_FUNC.get(func_name) if prof_func_kind is None: raise ValueError(f"Unknown profile function: {func_name!r}") if comp_name not in component_func_by_name: component_order.append(comp_name) component_func_by_name[comp_name] = int(prof_func_kind) component_param_rows_by_name[comp_name] = [] component_param_rows_by_name[comp_name].append(src_row) is_constant = is_constant and bool(row_is_constant[src_row]) for comp_name in component_order: profile_component_func_ids_list.append(component_func_by_name[comp_name]) profile_component_param_rows_list.extend( component_param_rows_by_name[comp_name] ) profile_component_param_indptr_list.append( len(profile_component_param_rows_list) ) profile_sample_component_indptr_list.append( len(profile_component_func_ids_list) ) group_idx = len(profile_sample_base_rows_list) - 1 profile_sample_group_idx[group_name] = group_idx profile_sample_is_constant_list.append(is_constant) n_profile_samples = len(profile_sample_base_rows_list) profile_sample_base_rows = np.array(profile_sample_base_rows_list, dtype=np.intp) profile_sample_component_indptr = np.array( profile_sample_component_indptr_list, dtype=np.intp ) profile_component_func_ids = np.array( profile_component_func_ids_list, dtype=np.intp ) profile_component_param_indptr = 
np.array( profile_component_param_indptr_list, dtype=np.intp ) profile_component_param_rows = np.array( profile_component_param_rows_list, dtype=np.intp ) profile_sample_is_constant = np.array( profile_sample_is_constant_list, dtype=np.bool_ ) # ------------------------------------------------------------------ # # 4c. Compile per-sample profile expressions # # ------------------------------------------------------------------ # profile_expr_groups: dict[str, list[GraphNode]] = {} for nid in topo_order: node = id_to_node[nid] if _is_profile_expr_node(node): group_name = _profile_group_name(node.name, "profile_expr") profile_expr_groups.setdefault(group_name, []).append(node) profile_expr_programs_2d: list[ExprProgram] = [] profile_expr_is_constant_list: list[bool] = [] profile_expr_group_idx: dict[str, int] = {} for group_name, p_expr_nodes in profile_expr_groups.items(): p_expr_nodes_sorted = sorted( p_expr_nodes, key=lambda node: _profile_group_index(node.name, "profile_expr"), ) aux_indices = [ _profile_group_index(node.name, "profile_expr") for node in p_expr_nodes_sorted ] if aux_indices != list(range(len(p_expr_nodes_sorted))): raise ValueError( f"Profile expression nodes for {group_name!r} do not cover " "a contiguous aux-axis range" ) if len(p_expr_nodes_sorted) != n_aux: raise ValueError( f"Profile expression group {group_name!r} has " f"{len(p_expr_nodes_sorted)} samples but aux_axis length {n_aux}" ) rep_node = p_expr_nodes_sorted[0] if rep_node.expr_string is None: raise ValueError( f"Profile expression {group_name!r} is missing expr_string" ) expr_refs = set(_extract_expression_references(rep_node.expr_string)) prof_ref_map: dict[str, int] = {} for edge in expr_ref_edges_by_target.get(rep_node.id, []): src_node = id_to_node[edge.source] if src_node.kind == NodeKind.PROFILE_SAMPLE: sample_name = _profile_group_name(src_node.name, "profile_sample") src_idx = n_params + profile_sample_group_idx[sample_name] match_name = sample_name else: src_idx = 
name_to_row[src_node.name] match_name = src_node.name # PARAM_PLUS_TRACE nodes carry a "_resolved" suffix; # the expression string uses the bare param name. if match_name not in expr_refs and src_node.name.endswith("_resolved"): match_name = src_node.name.removesuffix("_resolved") if match_name in expr_refs: prof_ref_map[match_name] = src_idx symbolic = compile_expr_symbolic(rep_node.expr_string) prof_binding: dict[str, int] = dict(name_to_row) prof_binding.update(prof_ref_map) program = _bind_expr_to_rows(symbolic, prof_binding) profile_expr_programs_2d.append(program) is_constant = True for name in symbolic.referenced_names: bound_idx = int(prof_binding[name]) if bound_idx < n_params: is_constant = is_constant and bool(row_is_constant[bound_idx]) else: is_constant = is_constant and bool( profile_sample_is_constant[bound_idx - n_params] ) profile_expr_is_constant_list.append(is_constant) profile_expr_group_idx[group_name] = len(profile_expr_programs_2d) - 1 n_profile_exprs = len(profile_expr_programs_2d) profile_expr_is_constant = np.array(profile_expr_is_constant_list, dtype=np.bool_) # ------------------------------------------------------------------ # # 5. Schedule component ops # # ------------------------------------------------------------------ # # Identify peak_sum contributors: nodes with ADDEND edges into the # "peak_sum" SUM node (if it exists). peak_sum_sources: set[int] = set() peak_sum_nid = graph.node_by_name.get("peak_sum") if peak_sum_nid is not None: for e in addend_edges_by_target.get(peak_sum_nid, []): peak_sum_sources.add(e.source) # Collect per-sample component inputs for PROFILE_AVERAGE nodes. 
profile_avg_sample_inputs: dict[int, list[GraphNode]] = {} sample_component_ids: set[int] = set() for nid in topo_order: node = id_to_node[nid] if node.kind != NodeKind.PROFILE_AVERAGE: continue sample_nodes = [ id_to_node[edge.source] for edge in addend_edges_by_target.get(node.id, []) if id_to_node[edge.source].kind in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP) ] profile_avg_sample_inputs[node.id] = sample_nodes sample_component_ids.update(sample.id for sample in sample_nodes) comp_op_kinds = {NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP} comp_nodes_topo = [ id_to_node[nid] for nid in topo_order if ( (id_to_node[nid].kind in comp_op_kinds and nid not in sample_component_ids) or id_to_node[nid].kind == NodeKind.PROFILE_AVERAGE ) ] n_ops = len(comp_nodes_topo) op_schedule = np.arange(n_ops, dtype=np.intp) op_kinds_list: list[int] = [] op_param_indptr_list: list[int] = [0] op_param_source_kinds_list: list[int] = [] op_param_indices_list: list[int] = [] op_needs_spectrum_list: list[bool] = [] op_is_pre_spectrum_list: list[bool] = [] op_is_profiled_list: list[bool] = [] op_is_constant_list: list[bool] = [] for comp_node in comp_nodes_topo: if comp_node.kind == NodeKind.PROFILE_AVERAGE: # Profiled component: gather from per-sample COMPONENT_EVAL inputs. 
sample_nodes = sorted( profile_avg_sample_inputs.get(comp_node.id, []), key=lambda node: _profile_component_sample_index(node.name), ) if not sample_nodes: raise ValueError( f"PROFILE_AVERAGE {comp_node.name!r} has no sample component inputs" ) if len(sample_nodes) != n_aux: raise ValueError( f"PROFILE_AVERAGE {comp_node.name!r} has " f"{len(sample_nodes)} samples " f"but aux_axis length {n_aux}" ) rep_node = sample_nodes[0] assert rep_node.function_name is not None op = _FUNCTION_NAME_TO_OP.get(rep_node.function_name) if op is None: raise ValueError( f"Unknown component function: {rep_node.function_name!r}" ) op_kinds_list.append(int(op)) op_is_profiled_list.append(True) rep_param_edges = sorted( param_edges_by_target.get(rep_node.id, []), key=lambda edge: edge.position or 0, ) sample_param_edges = [ sorted( param_edges_by_target.get(sample_node.id, []), key=lambda edge: edge.position or 0, ) for sample_node in sample_nodes ] is_constant = True for pos, rep_edge in enumerate(rep_param_edges): src_node = id_to_node[rep_edge.source] if src_node.kind == NodeKind.PROFILE_SAMPLE: group_name = _profile_group_name(src_node.name, "profile_sample") source_kind = int(ParamSourceKind.PROFILE_SAMPLE) source_idx = profile_sample_group_idx[group_name] is_constant = is_constant and bool( profile_sample_is_constant[source_idx] ) for aux_i, edges in enumerate(sample_param_edges): sample_src = id_to_node[edges[pos].source] if sample_src.kind != NodeKind.PROFILE_SAMPLE: raise ValueError( "Mixed parameter source kinds " f"in profiled op {comp_node.name!r}" ) if ( _profile_group_name(sample_src.name, "profile_sample") != group_name or _profile_group_index(sample_src.name, "profile_sample") != aux_i ): raise ValueError( "Inconsistent PROFILE_SAMPLE wiring " f"in {comp_node.name!r}" ) elif _is_profile_expr_node(src_node): group_name = _profile_group_name(src_node.name, "profile_expr") source_kind = int(ParamSourceKind.PROFILE_EXPR) source_idx = profile_expr_group_idx[group_name] 
is_constant = is_constant and bool( profile_expr_is_constant[source_idx] ) for aux_i, edges in enumerate(sample_param_edges): sample_src = id_to_node[edges[pos].source] if not _is_profile_expr_node(sample_src): raise ValueError( "Mixed expression source kinds " f"in profiled op {comp_node.name!r}" ) if ( _profile_group_name(sample_src.name, "profile_expr") != group_name or _profile_group_index(sample_src.name, "profile_expr") != aux_i ): raise ValueError( "Inconsistent profile-expression " f"wiring in {comp_node.name!r}" ) else: source_kind = int(ParamSourceKind.SCALAR) source_idx = name_to_row[src_node.name] is_constant = is_constant and bool(row_is_constant[source_idx]) for edges in sample_param_edges[1:]: if id_to_node[edges[pos].source].id != src_node.id: raise ValueError( "Scalar parameter source changed " f"across samples in {comp_node.name!r}" ) op_param_source_kinds_list.append(source_kind) op_param_indices_list.append(source_idx) op_param_indptr_list.append(len(op_param_indices_list)) else: # Non-profiled component: standard param row wiring. 
assert comp_node.function_name is not None op = _FUNCTION_NAME_TO_OP.get(comp_node.function_name) if op is None: raise ValueError( f"Unknown component function: {comp_node.function_name!r}" ) op_kinds_list.append(int(op)) op_is_profiled_list.append(False) param_edges = sorted( param_edges_by_target.get(comp_node.id, []), key=lambda edge: edge.position or 0, ) is_constant = True for pe in param_edges: src_node = id_to_node[pe.source] src_row = name_to_row[src_node.name] op_param_source_kinds_list.append(int(ParamSourceKind.SCALAR)) op_param_indices_list.append(src_row) is_constant = is_constant and bool(row_is_constant[src_row]) op_param_indptr_list.append(len(op_param_indices_list)) has_spec_input = comp_node.id in spectrum_input_targets op_needs_spectrum_list.append(has_spec_input) op_is_pre_spectrum_list.append(comp_node.id in peak_sum_sources) op_is_constant_list.append((not has_spec_input) and is_constant) op_kinds = np.array(op_kinds_list, dtype=np.intp) op_param_indptr = np.array(op_param_indptr_list, dtype=np.intp) op_param_source_kinds = np.array(op_param_source_kinds_list, dtype=np.int8) op_param_indices = np.array(op_param_indices_list, dtype=np.intp) op_needs_spectrum = np.array(op_needs_spectrum_list, dtype=np.bool_) op_is_pre_spectrum = np.array(op_is_pre_spectrum_list, dtype=np.bool_) op_is_profiled = np.array(op_is_profiled_list, dtype=np.bool_) op_is_constant = np.array(op_is_constant_list, dtype=np.bool_) # ------------------------------------------------------------------ # # 6. Initialize trace matrix # # ------------------------------------------------------------------ # param_traces_init = np.zeros((n_params, n_time), dtype=np.float64) # Static and opt params: broadcast initial value for node in opt_nodes + static_nodes: row = name_to_row[node.name] param_traces_init[row, :] = node.value if node.value is not None else 0.0 # PARAM_PLUS_TRACE: base + dynamics trace at initial values. 
# We need to evaluate dynamics functions at initial parameter values # to populate these rows. from trspecfit.functions import time as fcts_time _DYN_DISPATCH: dict[int, Callable[..., Any]] = { int(DynFuncKind.EXPFUN): fcts_time.expFun, int(DynFuncKind.SINFUN): fcts_time.sinFun, int(DynFuncKind.LINFUN): fcts_time.linFun, int(DynFuncKind.SINDIVX): fcts_time.sinDivX, int(DynFuncKind.ERFFUN): fcts_time.erfFun, int(DynFuncKind.SQRTFUN): fcts_time.sqrtFun, } _CONV_KERNEL_DISPATCH: dict[int, Callable[..., Any]] = { int(ConvKernelKind.GAUSSCONV): fcts_time.gaussCONV, int(ConvKernelKind.LORENTZCONV): fcts_time.lorentzCONV, int(ConvKernelKind.VOIGTCONV): fcts_time.voigtCONV, int(ConvKernelKind.EXPSYMCONV): fcts_time.expSymCONV, int(ConvKernelKind.EXPDECAYCONV): fcts_time.expDecayCONV, int(ConvKernelKind.EXPRISECONV): fcts_time.expRiseCONV, int(ConvKernelKind.BOXCONV): fcts_time.boxCONV, } from trspecfit.eval_2d import eval_expr_program from trspecfit.utils.arrays import my_conv # Dynamics groups, expressions, and resolved-trace convolutions are # interleaved in topological order so that downstream consumers see # the fully resolved trace (base + dynamics + expressions + IRF). 
for step in range(len(resolution_kinds)): kind = int(resolution_kinds[step]) idx = int(resolution_indices[step]) if kind == 0: # dynamics group target = int(dyn_group_target_row[idx]) base = int(dyn_group_base_row[idx]) param_traces_init[target, :] = param_traces_init[base, :] for s in range(int(dyn_group_indptr[idx]), int(dyn_group_indptr[idx + 1])): n_dp = int(dyn_sub_n_params[s]) p_vals = [ float(param_traces_init[dyn_sub_param_rows[s, j], 0]) for j in range(n_dp) ] func = _DYN_DISPATCH[int(dyn_sub_func_id[s])] param_traces_init[target, :] += ( func(dyn_sub_time_axes[s], *p_vals) * dyn_sub_masks[s] ) elif kind == 1: # expression target_row = int(expr_target_rows[idx]) program = expr_programs[idx] param_traces_init[target_row, :] = eval_expr_program( program, param_traces_init ) else: # kind == 2: resolved-trace convolution target_row = int(conv_target_rows[idx]) kernel_func = _CONV_KERNEL_DISPATCH[int(conv_func_ids[idx])] p_start = int(conv_param_indptr[idx]) p_end = int(conv_param_indptr[idx + 1]) kernel_params = [ float(param_traces_init[int(conv_param_rows[j]), 0]) for j in range(p_start, p_end) ] s_start = int(conv_support_indptr[idx]) s_end = int(conv_support_indptr[idx + 1]) support = conv_support_values[s_start:s_end] kernel = kernel_func(support, *kernel_params) param_traces_init[target_row, :] = my_conv( graph.time, param_traces_init[target_row, :], kernel ) # ------------------------------------------------------------------ # # 6b. Precompute constant component contributions # # ------------------------------------------------------------------ # # Evaluate profile sample/expr values at initial traces for caching. 
from trspecfit.eval_2d import ( _evaluate_profile_expr_values_2d, _evaluate_profile_sample_values_2d, ) profile_sample_values_init = _evaluate_profile_sample_values_2d( plan_aux_axis, param_traces_init, profile_sample_base_rows, profile_sample_component_indptr, profile_component_func_ids, profile_component_param_indptr, profile_component_param_rows, ) profile_expr_values_init = _evaluate_profile_expr_values_2d( param_traces_init, profile_sample_values_init, n_params, profile_expr_programs_2d, ) energy = graph.energy[np.newaxis, :] cached_result = np.zeros((n_time, len(graph.energy)), dtype=np.float64) cached_peak_sum = np.zeros_like(cached_result) for op_idx in range(n_ops): if not op_is_constant[op_idx]: continue start = int(op_param_indptr[op_idx]) end = int(op_param_indptr[op_idx + 1]) if op_is_profiled[op_idx]: from trspecfit.eval_2d import _evaluate_profiled_op_2d component = _evaluate_profiled_op_2d( energy, int(op_kinds[op_idx]), op_param_source_kinds[start:end], op_param_indices[start:end], param_traces_init, profile_sample_values_init, profile_expr_values_init, cached_peak_sum, needs_spectrum=bool(op_needs_spectrum[op_idx]), n_aux=n_aux, ) else: param_rows = op_param_indices[start:end] params = [ param_traces_init[int(row), :][:, np.newaxis] for row in param_rows ] func, _needs = OP_DISPATCH[int(op_kinds[op_idx])] if op_needs_spectrum[op_idx]: component = func(energy, *params, cached_peak_sum) else: component = func(energy, *params) cached_result += component if op_is_pre_spectrum[op_idx]: cached_peak_sum += component # ------------------------------------------------------------------ # # 7. 
Pack into ScheduledPlan2D # # ------------------------------------------------------------------ # return ScheduledPlan2D( energy=graph.energy, time=graph.time, n_params=n_params, n_time=n_time, param_traces_init=param_traces_init, opt_indices=opt_indices, opt_param_names=opt_param_names, n_dyn_groups=n_dyn_groups, dyn_group_target_row=dyn_group_target_row, dyn_group_base_row=dyn_group_base_row, dyn_group_indptr=dyn_group_indptr, dyn_sub_func_id=dyn_sub_func_id, dyn_sub_param_rows=dyn_sub_param_rows, dyn_sub_n_params=dyn_sub_n_params, dyn_sub_time_axes=dyn_sub_time_axes, dyn_sub_masks=dyn_sub_masks, n_expressions=n_expressions, expr_target_rows=expr_target_rows, expr_programs=expr_programs, resolution_kinds=resolution_kinds, resolution_indices=resolution_indices, n_aux=n_aux, aux_axis=plan_aux_axis, n_profile_samples=n_profile_samples, profile_sample_base_rows=profile_sample_base_rows, profile_sample_component_indptr=profile_sample_component_indptr, profile_component_func_ids=profile_component_func_ids, profile_component_param_indptr=profile_component_param_indptr, profile_component_param_rows=profile_component_param_rows, n_profile_exprs=n_profile_exprs, profile_expr_programs=profile_expr_programs_2d, n_ops=n_ops, op_schedule=op_schedule, op_kinds=op_kinds, op_param_indptr=op_param_indptr, op_param_source_kinds=op_param_source_kinds, op_param_indices=op_param_indices, op_needs_spectrum=op_needs_spectrum, op_is_pre_spectrum=op_is_pre_spectrum, op_is_profiled=op_is_profiled, op_is_constant=op_is_constant, cached_result=cached_result, cached_peak_sum=cached_peak_sum, n_conv_steps=n_conv_steps, conv_target_rows=conv_target_rows, conv_func_ids=conv_func_ids, conv_param_indptr=conv_param_indptr, conv_param_rows=conv_param_rows, conv_support_indptr=conv_support_indptr, conv_support_values=conv_support_values, )
# ---------------------------------------------------------------------------
# schedule_1d
# ---------------------------------------------------------------------------


def _eval_expr_scalar(program: ExprProgram, values: np.ndarray) -> float:
    """Evaluate an RPN ExprProgram against a scalar parameter vector.

    Parameters
    ----------
    program
        Compiled RPN program.  ``program.instructions`` stores flat
        ``(opcode, operand)`` pairs, so its length is twice the number of
        instructions.
    values
        ``(n_params,)`` scalar parameter vector indexed by PARAM_REF operands.

    Returns
    -------
    float
        Scalar result.

    Raises
    ------
    ValueError
        If the program contains an opcode this scalar evaluator does not
        handle, or does not leave exactly one value on the stack.
    """
    stack: list[float] = []
    instr = program.instructions
    for i in range(len(instr) // 2):
        kind = ExprNodeKind(instr[2 * i])
        operand = instr[2 * i + 1]
        if kind == ExprNodeKind.CONST:
            # Float constants are bit-cast into the integer operand slot.
            stack.append(float(np.int64(operand).view(np.float64)))
        elif kind == ExprNodeKind.PARAM_REF:
            stack.append(float(values[int(operand)]))
        elif kind == ExprNodeKind.NEG:
            stack.append(-stack.pop())
        elif kind == ExprNodeKind.ADD:
            b, a = stack.pop(), stack.pop()
            stack.append(a + b)
        elif kind == ExprNodeKind.SUB:
            b, a = stack.pop(), stack.pop()
            stack.append(a - b)
        elif kind == ExprNodeKind.MUL:
            b, a = stack.pop(), stack.pop()
            stack.append(a * b)
        elif kind == ExprNodeKind.DIV:
            b, a = stack.pop(), stack.pop()
            stack.append(a / b)
        elif kind == ExprNodeKind.POW:
            b, a = stack.pop(), stack.pop()
            stack.append(a**b)
        else:
            # Previously such opcodes were skipped silently, which yields a
            # wrong value instead of an error.
            raise ValueError(f"Unsupported opcode in scalar expression: {kind!r}")
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if len(stack) != 1:
        raise ValueError(
            f"Malformed expression program: {len(stack)} values left on the "
            "stack after evaluation (expected 1)"
        )
    return stack[0]


#
def _match_profile_node_name(node_name: str, label: str) -> re.Match[str]:
    """Match ``<group>_<label>_<index>`` against *node_name* or raise.

    ``label`` is escaped so it is always matched literally.

    Raises
    ------
    ValueError
        If *node_name* does not follow the per-sample naming pattern.
    """
    match = re.fullmatch(rf"(.+)_{re.escape(label)}_(\d+)", node_name)
    if match is None:
        raise ValueError(f"Malformed profile node name: {node_name!r}")
    return match


#
def _profile_group_name(node_name: str, label: str) -> str:
    """Return the shared profile group base name from a per-sample node name."""
    return _match_profile_node_name(node_name, label).group(1)


#
def _profile_group_index(node_name: str, label: str) -> int:
    """Return the aux-axis sample index encoded in a per-sample node name."""
    return int(_match_profile_node_name(node_name, label).group(2))


#
def _profile_component_sample_index(node_name: str) -> int:
    """Return the aux-axis sample index from ``<component>_sample_<i>``."""
    return _profile_group_index(node_name, "sample")


#
def _is_profile_expr_node(node: GraphNode) -> bool:
    """Return True for per-sample profile-expression nodes.

    These are EXPRESSION nodes named ``<group>_profile_expr_<i>``.
    """
    return (
        node.kind == NodeKind.EXPRESSION
        and re.fullmatch(r"(.+)_profile_expr_(\d+)", node.name) is not None
    )


#
def _parse_profile_component_param_name(
    target_param_name: str,
    source_param_name: str,
) -> tuple[str, str]:
    """Parse ``<target>_<profile_comp>_<par>`` into component + function name.

    ``<profile_comp>`` must itself look like ``<func>_<digits>``.

    Returns
    -------
    tuple[str, str]
        ``(profile_comp, func)``.

    Raises
    ------
    ValueError
        If the name does not start with ``<target>_``, or the component or
        parameter part is malformed (including an empty parameter name).
    """
    prefix = f"{target_param_name}_"
    if not source_param_name.startswith(prefix):
        raise ValueError(
            f"Profile parameter {source_param_name!r} does not match "
            f"target parameter {target_param_name!r}"
        )
    remainder = source_param_name[len(prefix) :]
    comp_name, sep, par_name = remainder.rpartition("_")
    if not sep or not comp_name or not par_name:
        raise ValueError(f"Malformed profile parameter name: {source_param_name!r}")
    func_name, sep2, comp_idx = comp_name.rpartition("_")
    if not sep2 or not comp_idx.isdigit():
        raise ValueError(f"Malformed profile component name: {comp_name!r}")
    return comp_name, func_name


#
def _eval_expr_vector(program: ExprProgram, traces: np.ndarray) -> np.ndarray:
    """Evaluate an RPN ExprProgram against an aux-resolved trace matrix."""
    from trspecfit.eval_2d import eval_expr_program

    return eval_expr_program(program, traces)


#
def _evaluate_profile_sample_values(
    aux_axis: np.ndarray,
    scalar_values: np.ndarray,
    profile_sample_base_indices: np.ndarray,
    profile_sample_component_indptr: np.ndarray,
    profile_component_func_ids: np.ndarray,
    profile_component_param_indptr: np.ndarray,
    profile_component_param_indices: np.ndarray,
) -> np.ndarray:
    """Evaluate lowered PROFILE_SAMPLE groups into ``(n_groups, n_aux)`` values.

    Each group starts from its base scalar value, broadcast over the aux
    axis.  It then adds every profile component assigned to it via the
    CSR-style ``*_indptr`` arrays.  Each component is evaluated with
    ``PROFILE_DISPATCH`` on ``aux_axis``.
    """
    n_groups = len(profile_sample_base_indices)
    n_aux = len(aux_axis)
    if n_groups == 0:
        return np.zeros((0, n_aux), dtype=np.float64)
    sample_values = np.empty((n_groups, n_aux), dtype=np.float64)
    for group_idx in range(n_groups):
        base_idx = int(profile_sample_base_indices[group_idx])
        values = np.full(n_aux, scalar_values[base_idx], dtype=np.float64)
        comp_start = int(profile_sample_component_indptr[group_idx])
        comp_end = int(profile_sample_component_indptr[group_idx + 1])
        for comp_idx in range(comp_start, comp_end):
            func = PROFILE_DISPATCH[int(profile_component_func_ids[comp_idx])]
            param_start = int(profile_component_param_indptr[comp_idx])
            param_end = int(profile_component_param_indptr[comp_idx + 1])
            params = [
                float(scalar_values[int(idx)])
                for idx in profile_component_param_indices[param_start:param_end]
            ]
            values += np.asarray(func(aux_axis, *params), dtype=np.float64)
        sample_values[group_idx, :] = values
    return sample_values


#
def _evaluate_profile_expr_values(
    scalar_values: np.ndarray,
    profile_sample_values: np.ndarray,
    n_params: int,
    profile_expr_programs: list[ExprProgram],
) -> np.ndarray:
    """Evaluate lowered per-sample profile expressions over the aux axis.

    The programs address a stacked trace matrix.  Rows ``[0, n_params)``
    hold the scalar parameters, broadcast along the aux axis.  The following
    rows hold the PROFILE_SAMPLE group values.

    Returns
    -------
    np.ndarray
        ``(n_exprs, n_aux)`` values, one row per program.
    """
    n_exprs = len(profile_expr_programs)
    if n_exprs == 0:
        n_aux = profile_sample_values.shape[1] if profile_sample_values.size else 0
        return np.zeros((0, n_aux), dtype=np.float64)
    n_aux = profile_sample_values.shape[1]
    traces = np.empty(
        (n_params + profile_sample_values.shape[0], n_aux), dtype=np.float64
    )
    traces[:n_params, :] = scalar_values[:, np.newaxis]
    if profile_sample_values.size:
        traces[n_params:, :] = profile_sample_values
    expr_values = np.empty((n_exprs, n_aux), dtype=np.float64)
    for expr_idx, program in enumerate(profile_expr_programs):
        expr_values[expr_idx, :] = _eval_expr_vector(program, traces)
    return expr_values


#
def _evaluate_scheduled_op_1d(
    energy: np.ndarray,
    kind: int,
    param_source_kinds: np.ndarray,
    param_indices: np.ndarray,
    scalar_values: np.ndarray,
    profile_sample_values: np.ndarray,
    profile_expr_values: np.ndarray,
    peak_sum: np.ndarray,
    *,
    needs_spectrum: bool,
    is_profiled: bool,
    n_aux: int,
) -> np.ndarray:
    """Evaluate one scheduled 1D op in either scalar or profiled mode.

    Scalar mode calls the op with scalar parameters taken from
    ``scalar_values[param_indices]``.

    Profiled mode builds ``(n_aux, 1)`` parameter columns from the three
    source kinds: SCALAR (broadcast), PROFILE_SAMPLE and PROFILE_EXPR.  It
    evaluates the op on ``energy[np.newaxis, :]`` and returns the mean
    over the aux axis.

    When ``needs_spectrum`` is set, ``peak_sum`` is passed as the final
    argument.

    Raises
    ------
    ValueError
        If a profiled parameter has an unknown ``ParamSourceKind``.
    """
    func, _needs = OP_DISPATCH[kind]
    if not is_profiled:
        scalar_params = [float(scalar_values[int(row)]) for row in param_indices]
        if needs_spectrum:
            return np.asarray(func(energy, *scalar_params, peak_sum), dtype=np.float64)
        return np.asarray(func(energy, *scalar_params), dtype=np.float64)

    energy_2d = energy[np.newaxis, :]
    scalar_kind = int(ParamSourceKind.SCALAR)
    sample_kind = int(ParamSourceKind.PROFILE_SAMPLE)
    expr_kind = int(ParamSourceKind.PROFILE_EXPR)
    params_2d: list[np.ndarray] = []
    for source_kind, source_idx in zip(
        param_source_kinds,
        param_indices,
        strict=True,
    ):
        src_kind = int(source_kind)
        idx = int(source_idx)
        if src_kind == scalar_kind:
            param = np.full((n_aux, 1), scalar_values[idx], dtype=np.float64)
        elif src_kind == sample_kind:
            param = profile_sample_values[idx, :][:, np.newaxis]
        elif src_kind == expr_kind:
            param = profile_expr_values[idx, :][:, np.newaxis]
        else:
            raise ValueError(f"Unknown parameter source kind: {src_kind}")
        params_2d.append(param)
    if needs_spectrum:
        profiled = func(energy_2d, *params_2d, peak_sum[np.newaxis, :])
    else:
        profiled = func(energy_2d, *params_2d)
    result: np.ndarray = np.asarray(profiled, dtype=np.float64).mean(axis=0)
    return result


#
def schedule_1d(graph: GraphIR) -> ScheduledPlan1D:
    """Compile a GraphIR into a flat 1D execution schedule.

    Scalar parameter rows are laid out in this order:

    1. varying OPT_PARAM nodes;
    2. fixed parameters (STATIC_PARAM and non-varying OPT_PARAM);
    3. computed scalar EXPRESSION nodes, in topological order.

    ``opt_indices`` is therefore simply ``arange(n_opt)``.

    Profile machinery is packed into CSR-style arrays.  PROFILE_SAMPLE
    nodes are grouped per profile; per-sample profile expressions are
    compiled once per group.

    Component ops are emitted in topological order.  Each PROFILE_AVERAGE
    becomes a single "profiled" op.  Ops whose inputs are all constant and
    that take no spectrum input are pre-evaluated into ``cached_result``.
    Those that also feed ``peak_sum`` are added to ``cached_peak_sum``.

    Parameters
    ----------
    graph : GraphIR
        Must pass ``can_lower_1d(graph)``.

    Returns
    -------
    ScheduledPlan1D
        Packed-array execution schedule for ``evaluate_1d``.

    Raises
    ------
    ValueError
        If the graph cannot be lowered (domain, unsupported nodes, etc.),
        or if profile groups / profiled ops are wired inconsistently
        (non-contiguous sample indices, mismatched aux axes, differing
        parameter arity across samples, ...).
    """
    if not can_lower_1d(graph):
        raise ValueError("Graph cannot be lowered to 1D backend")
    assert graph.energy is not None

    # ------------------------------------------------------------------ #
    # 1. Topological sort + helper lookups                                #
    # ------------------------------------------------------------------ #
    topo_order = _topological_sort(graph)
    id_to_node: dict[int, GraphNode] = {n.id: n for n in graph.nodes}

    # Incoming edges bucketed by kind and keyed by target node id.
    param_edges_by_target: dict[int, list[GraphEdge]] = {}
    expr_ref_edges_by_target: dict[int, list[GraphEdge]] = {}
    addend_edges_by_target: dict[int, list[GraphEdge]] = {}
    spectrum_input_targets: set[int] = set()
    for edge in graph.edges:
        if edge.kind == EdgeKind.PARAM_INPUT:
            param_edges_by_target.setdefault(edge.target, []).append(edge)
        elif edge.kind == EdgeKind.EXPR_REF:
            expr_ref_edges_by_target.setdefault(edge.target, []).append(edge)
        elif edge.kind == EdgeKind.ADDEND:
            addend_edges_by_target.setdefault(edge.target, []).append(edge)
        elif edge.kind == EdgeKind.SPECTRUM_INPUT:
            spectrum_input_targets.add(edge.target)

    # ------------------------------------------------------------------ #
    # 2. Assign scalar parameter indices                                  #
    # ------------------------------------------------------------------ #
    _ROW_KINDS = frozenset(
        {
            NodeKind.STATIC_PARAM,
            NodeKind.OPT_PARAM,
            NodeKind.EXPRESSION,
        }
    )
    opt_nodes: list[GraphNode] = []
    static_nodes: list[GraphNode] = []
    computed_nodes: list[GraphNode] = []
    for nid in topo_order:
        node = id_to_node[nid]
        # Per-sample profile expressions are compiled separately (step 5).
        if node.kind not in _ROW_KINDS or _is_profile_expr_node(node):
            continue
        if node.kind == NodeKind.OPT_PARAM and node.vary:
            opt_nodes.append(node)
        elif node.kind in (NodeKind.STATIC_PARAM, NodeKind.OPT_PARAM):
            # Non-varying OPT_PARAMs are treated exactly like static params.
            static_nodes.append(node)
        else:
            computed_nodes.append(node)

    all_param_nodes = opt_nodes + static_nodes + computed_nodes
    n_params = len(all_param_nodes)
    name_to_idx = {node.name: idx for idx, node in enumerate(all_param_nodes)}

    # Rows that never change between evaluations (used for constant folding).
    idx_is_constant = np.zeros(n_params, dtype=np.bool_)
    for node in static_nodes:
        idx_is_constant[name_to_idx[node.name]] = True

    n_opt = len(opt_nodes)
    opt_indices = np.arange(n_opt, dtype=np.intp)
    opt_param_names = [n.name for n in opt_nodes]

    # ------------------------------------------------------------------ #
    # 3. Compile scalar expressions                                       #
    # ------------------------------------------------------------------ #
    expr_nodes_topo = [
        id_to_node[nid]
        for nid in topo_order
        if id_to_node[nid].kind == NodeKind.EXPRESSION
        and not _is_profile_expr_node(id_to_node[nid])
    ]
    expr_programs: list[ExprProgram] = []
    expr_target_indices_list: list[int] = []
    for expr_node in expr_nodes_topo:
        assert expr_node.expr_string is not None
        symbolic = compile_expr_symbolic(expr_node.expr_string)
        expr_refs = set(_extract_expression_references(expr_node.expr_string))
        ref_map: dict[str, int] = {}
        for edge in expr_ref_edges_by_target.get(expr_node.id, []):
            src_node = id_to_node[edge.source]
            src_idx = name_to_idx[src_node.name]
            if src_node.name in expr_refs:
                ref_map[src_node.name] = src_idx
        binding = dict(name_to_idx)
        binding.update(ref_map)
        program = _bind_expr_to_rows(symbolic, binding)
        expr_programs.append(program)
        target_idx = name_to_idx[expr_node.name]
        expr_target_indices_list.append(target_idx)
        # An expression is constant iff every row it references is constant;
        # topological order guarantees referenced expressions were visited.
        idx_is_constant[target_idx] = all(
            idx_is_constant[int(binding[name])] for name in symbolic.referenced_names
        )
    n_expressions = len(expr_programs)
    expr_target_indices = np.array(expr_target_indices_list, dtype=np.intp)

    # ------------------------------------------------------------------ #
    # 4. Compile PROFILE_SAMPLE groups                                    #
    # ------------------------------------------------------------------ #
    profile_sample_groups: dict[str, list[GraphNode]] = {}
    for nid in topo_order:
        node = id_to_node[nid]
        if node.kind != NodeKind.PROFILE_SAMPLE:
            continue
        group_name = _profile_group_name(node.name, "profile_sample")
        profile_sample_groups.setdefault(group_name, []).append(node)

    # n_aux == 0 doubles as "no aux axis seen yet".
    plan_aux_axis = np.zeros(0, dtype=np.float64)
    n_aux = 0
    profile_sample_base_indices_list: list[int] = []
    profile_sample_component_indptr_list: list[int] = [0]
    profile_component_func_ids_list: list[int] = []
    profile_component_param_indptr_list: list[int] = [0]
    profile_component_param_indices_list: list[int] = []
    profile_sample_is_constant_list: list[bool] = []
    profile_sample_group_idx: dict[str, int] = {}

    for group_name, sample_nodes in profile_sample_groups.items():
        sample_nodes_sorted = sorted(
            sample_nodes,
            key=lambda node: _profile_group_index(node.name, "profile_sample"),
        )
        aux_indices = [
            _profile_group_index(node.name, "profile_sample")
            for node in sample_nodes_sorted
        ]
        if aux_indices != list(range(len(sample_nodes_sorted))):
            raise ValueError(
                f"PROFILE_SAMPLE nodes for {group_name!r} do not cover "
                "a contiguous aux-axis range"
            )

        aux_axis = sample_nodes_sorted[0].arrays.get("aux_axis")
        if aux_axis is None:
            raise ValueError(f"PROFILE_SAMPLE {group_name!r} is missing aux_axis")
        aux_axis = np.asarray(aux_axis, dtype=np.float64)
        if n_aux == 0:
            n_aux = len(aux_axis)
            plan_aux_axis = aux_axis.copy()
        elif len(aux_axis) != n_aux or not np.array_equal(aux_axis, plan_aux_axis):
            raise ValueError("All lowered profile groups must share one fixed aux_axis")
        if len(sample_nodes_sorted) != n_aux:
            raise ValueError(
                f"PROFILE_SAMPLE group {group_name!r} has {len(sample_nodes_sorted)} "
                f"samples but aux_axis length {n_aux}"
            )

        # Only the first sample's wiring is inspected; it represents the group.
        rep_node = sample_nodes_sorted[0]
        param_edges = sorted(
            param_edges_by_target.get(rep_node.id, []),
            key=lambda edge: edge.position or 0,
        )
        if not param_edges:
            raise ValueError(f"PROFILE_SAMPLE {group_name!r} has no PARAM_INPUT edges")

        # First PARAM_INPUT is the base value; the rest are component params.
        base_node = id_to_node[param_edges[0].source]
        if base_node.name not in name_to_idx:
            raise ValueError(
                f"PROFILE_SAMPLE base source {base_node.name!r} is not scalar-lowerable"
            )
        base_idx = name_to_idx[base_node.name]
        is_constant = bool(idx_is_constant[base_idx])
        profile_sample_base_indices_list.append(base_idx)

        component_func_by_name: dict[str, int] = {}
        component_param_indices_by_name: dict[str, list[int]] = {}
        component_order: list[str] = []
        for edge in param_edges[1:]:
            src_node = id_to_node[edge.source]
            if src_node.name not in name_to_idx:
                raise ValueError(
                    f"Profile parameter source {src_node.name!r} "
                    "is not scalar-lowerable"
                )
            comp_name, func_name = _parse_profile_component_param_name(
                group_name,
                src_node.name,
            )
            prof_func_kind = _FUNCTION_NAME_TO_PROFILE_FUNC.get(func_name)
            if prof_func_kind is None:
                raise ValueError(f"Unknown profile function: {func_name!r}")
            if comp_name not in component_func_by_name:
                component_order.append(comp_name)
                component_func_by_name[comp_name] = int(prof_func_kind)
                component_param_indices_by_name[comp_name] = []
            src_idx = name_to_idx[src_node.name]
            component_param_indices_by_name[comp_name].append(src_idx)
            is_constant = is_constant and bool(idx_is_constant[src_idx])

        # Components are packed in order of first appearance.
        for comp_name in component_order:
            profile_component_func_ids_list.append(component_func_by_name[comp_name])
            profile_component_param_indices_list.extend(
                component_param_indices_by_name[comp_name]
            )
            profile_component_param_indptr_list.append(
                len(profile_component_param_indices_list)
            )
        profile_sample_component_indptr_list.append(
            len(profile_component_func_ids_list)
        )

        group_idx = len(profile_sample_base_indices_list) - 1
        profile_sample_group_idx[group_name] = group_idx
        profile_sample_is_constant_list.append(is_constant)

    n_profile_samples = len(profile_sample_base_indices_list)
    profile_sample_base_indices = np.array(
        profile_sample_base_indices_list, dtype=np.intp
    )
    profile_sample_component_indptr = np.array(
        profile_sample_component_indptr_list, dtype=np.intp
    )
    profile_component_func_ids = np.array(
        profile_component_func_ids_list, dtype=np.intp
    )
    profile_component_param_indptr = np.array(
        profile_component_param_indptr_list, dtype=np.intp
    )
    profile_component_param_indices = np.array(
        profile_component_param_indices_list, dtype=np.intp
    )
    profile_sample_is_constant = np.array(
        profile_sample_is_constant_list,
        dtype=np.bool_,
    )

    # ------------------------------------------------------------------ #
    # 5. Compile per-sample profile expressions                           #
    # ------------------------------------------------------------------ #
    profile_expr_groups: dict[str, list[GraphNode]] = {}
    for nid in topo_order:
        node = id_to_node[nid]
        if _is_profile_expr_node(node):
            group_name = _profile_group_name(node.name, "profile_expr")
            profile_expr_groups.setdefault(group_name, []).append(node)

    profile_expr_programs: list[ExprProgram] = []
    profile_expr_is_constant_list: list[bool] = []
    profile_expr_group_idx: dict[str, int] = {}
    for group_name, expr_nodes in profile_expr_groups.items():
        expr_nodes_sorted = sorted(
            expr_nodes,
            key=lambda node: _profile_group_index(node.name, "profile_expr"),
        )
        aux_indices = [
            _profile_group_index(node.name, "profile_expr")
            for node in expr_nodes_sorted
        ]
        if aux_indices != list(range(len(expr_nodes_sorted))):
            raise ValueError(
                f"Profile expression nodes for {group_name!r} do not cover "
                "a contiguous aux-axis range"
            )
        if len(expr_nodes_sorted) != n_aux:
            raise ValueError(
                f"Profile expression group {group_name!r} has {len(expr_nodes_sorted)} "
                f"samples but aux_axis length {n_aux}"
            )

        rep_node = expr_nodes_sorted[0]
        if rep_node.expr_string is None:
            raise ValueError(
                f"Profile expression {group_name!r} is missing expr_string"
            )
        expr_refs = set(_extract_expression_references(rep_node.expr_string))

        # Profile-sample groups are addressed as extra rows after the n_params
        # scalar rows (see _evaluate_profile_expr_values).
        prof_ref_map: dict[str, int] = {}
        for edge in expr_ref_edges_by_target.get(rep_node.id, []):
            src_node = id_to_node[edge.source]
            if src_node.kind == NodeKind.PROFILE_SAMPLE:
                sample_name = _profile_group_name(src_node.name, "profile_sample")
                src_idx = n_params + profile_sample_group_idx[sample_name]
                match_name = sample_name
            else:
                src_idx = name_to_idx[src_node.name]
                match_name = src_node.name
            if match_name in expr_refs:
                prof_ref_map[match_name] = src_idx

        symbolic = compile_expr_symbolic(rep_node.expr_string)
        binding = dict(name_to_idx)
        binding.update(prof_ref_map)
        program = _bind_expr_to_rows(symbolic, binding)
        profile_expr_programs.append(program)

        is_constant = True
        for name in symbolic.referenced_names:
            bound_idx = int(binding[name])
            if bound_idx < n_params:
                is_constant = is_constant and bool(idx_is_constant[bound_idx])
            else:
                is_constant = is_constant and bool(
                    profile_sample_is_constant[bound_idx - n_params]
                )
        profile_expr_is_constant_list.append(is_constant)
        profile_expr_group_idx[group_name] = len(profile_expr_programs) - 1

    n_profile_exprs = len(profile_expr_programs)
    profile_expr_is_constant = np.array(profile_expr_is_constant_list, dtype=np.bool_)

    # ------------------------------------------------------------------ #
    # 6. Schedule component ops                                           #
    # ------------------------------------------------------------------ #
    # Ops that are ADDENDs of the "peak_sum" node contribute to the spectrum
    # consumed by spectrum-fed ops.
    peak_sum_sources: set[int] = set()
    peak_sum_nid = graph.node_by_name.get("peak_sum")
    if peak_sum_nid is not None:
        for edge in addend_edges_by_target.get(peak_sum_nid, []):
            peak_sum_sources.add(edge.source)

    # Per-sample component nodes feeding a PROFILE_AVERAGE are folded into
    # that average's single profiled op rather than scheduled on their own.
    profile_avg_sample_inputs: dict[int, list[GraphNode]] = {}
    sample_component_ids: set[int] = set()
    for nid in topo_order:
        node = id_to_node[nid]
        if node.kind != NodeKind.PROFILE_AVERAGE:
            continue
        sample_nodes = [
            id_to_node[edge.source]
            for edge in addend_edges_by_target.get(node.id, [])
            if id_to_node[edge.source].kind
            in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP)
        ]
        profile_avg_sample_inputs[node.id] = sample_nodes
        sample_component_ids.update(sample.id for sample in sample_nodes)

    comp_nodes_topo = [
        id_to_node[nid]
        for nid in topo_order
        if (
            (
                id_to_node[nid].kind
                in (NodeKind.COMPONENT_EVAL, NodeKind.SPECTRUM_FED_OP)
                and nid not in sample_component_ids
            )
            or id_to_node[nid].kind == NodeKind.PROFILE_AVERAGE
        )
    ]

    op_kinds_list: list[int] = []
    op_param_indptr_list: list[int] = [0]
    op_param_source_kinds_list: list[int] = []
    op_param_indices_list: list[int] = []
    op_needs_spectrum_list: list[bool] = []
    op_is_pre_spectrum_list: list[bool] = []
    op_is_profiled_list: list[bool] = []
    op_is_constant_list: list[bool] = []

    for comp_node in comp_nodes_topo:
        if comp_node.kind == NodeKind.PROFILE_AVERAGE:
            sample_nodes = sorted(
                profile_avg_sample_inputs.get(comp_node.id, []),
                key=lambda node: _profile_component_sample_index(node.name),
            )
            if not sample_nodes:
                raise ValueError(
                    f"PROFILE_AVERAGE {comp_node.name!r} has no sample component inputs"
                )
            if len(sample_nodes) != n_aux:
                raise ValueError(
                    f"PROFILE_AVERAGE {comp_node.name!r} has "
                    f"{len(sample_nodes)} samples "
                    f"but aux_axis length {n_aux}"
                )
            rep_node = sample_nodes[0]
            assert rep_node.function_name is not None
            op = _FUNCTION_NAME_TO_OP.get(rep_node.function_name)
            if op is None:
                raise ValueError(
                    f"Unknown component function: {rep_node.function_name!r}"
                )
            op_kinds_list.append(int(op))
            op_is_profiled_list.append(True)

            rep_param_edges = sorted(
                param_edges_by_target.get(rep_node.id, []),
                key=lambda edge: edge.position or 0,
            )
            sample_param_edges = [
                sorted(
                    param_edges_by_target.get(sample_node.id, []),
                    key=lambda edge: edge.position or 0,
                )
                for sample_node in sample_nodes
            ]
            # Every sample must expose the same parameter arity as the
            # representative; otherwise ``edges[pos]`` below would raise an
            # opaque IndexError, or silently ignore extra parameters.
            n_rep_params = len(rep_param_edges)
            for sample_node, edges in zip(
                sample_nodes, sample_param_edges, strict=True
            ):
                if len(edges) != n_rep_params:
                    raise ValueError(
                        f"Sample component {sample_node.name!r} has {len(edges)} "
                        f"parameter inputs but {rep_node.name!r} has "
                        f"{n_rep_params} in profiled op {comp_node.name!r}"
                    )

            is_constant = True
            for pos, rep_edge in enumerate(rep_param_edges):
                src_node = id_to_node[rep_edge.source]
                if src_node.kind == NodeKind.PROFILE_SAMPLE:
                    group_name = _profile_group_name(src_node.name, "profile_sample")
                    source_kind = int(ParamSourceKind.PROFILE_SAMPLE)
                    source_idx = profile_sample_group_idx[group_name]
                    is_constant = is_constant and bool(
                        profile_sample_is_constant[source_idx]
                    )
                    # Sample i must read sample i of the same profile group.
                    for aux_i, edges in enumerate(sample_param_edges):
                        sample_src = id_to_node[edges[pos].source]
                        if sample_src.kind != NodeKind.PROFILE_SAMPLE:
                            raise ValueError(
                                "Mixed parameter source kinds "
                                f"in profiled op {comp_node.name!r}"
                            )
                        if (
                            _profile_group_name(sample_src.name, "profile_sample")
                            != group_name
                            or _profile_group_index(sample_src.name, "profile_sample")
                            != aux_i
                        ):
                            raise ValueError(
                                "Inconsistent PROFILE_SAMPLE wiring "
                                f"in {comp_node.name!r}"
                            )
                elif _is_profile_expr_node(src_node):
                    group_name = _profile_group_name(src_node.name, "profile_expr")
                    source_kind = int(ParamSourceKind.PROFILE_EXPR)
                    source_idx = profile_expr_group_idx[group_name]
                    is_constant = is_constant and bool(
                        profile_expr_is_constant[source_idx]
                    )
                    # Sample i must read expression sample i of the same group.
                    for aux_i, edges in enumerate(sample_param_edges):
                        sample_src = id_to_node[edges[pos].source]
                        if not _is_profile_expr_node(sample_src):
                            raise ValueError(
                                "Mixed expression source kinds "
                                f"in profiled op {comp_node.name!r}"
                            )
                        if (
                            _profile_group_name(sample_src.name, "profile_expr")
                            != group_name
                            or _profile_group_index(sample_src.name, "profile_expr")
                            != aux_i
                        ):
                            raise ValueError(
                                "Inconsistent profile-expression "
                                f"wiring in {comp_node.name!r}"
                            )
                else:
                    if src_node.name not in name_to_idx:
                        raise ValueError(
                            f"Non-scalar parameter source {src_node.name!r} in 1D op"
                        )
                    source_kind = int(ParamSourceKind.SCALAR)
                    source_idx = name_to_idx[src_node.name]
                    is_constant = is_constant and bool(idx_is_constant[source_idx])
                    # A scalar input must be the same node for every sample.
                    for edges in sample_param_edges[1:]:
                        if id_to_node[edges[pos].source].id != src_node.id:
                            raise ValueError(
                                "Scalar parameter source changed "
                                f"across samples in {comp_node.name!r}"
                            )
                op_param_source_kinds_list.append(source_kind)
                op_param_indices_list.append(source_idx)
            op_param_indptr_list.append(len(op_param_indices_list))

            # SPECTRUM_INPUT edges may be wired to the per-sample op nodes
            # rather than to the average itself; honour either wiring.
            has_spec_input = comp_node.id in spectrum_input_targets or any(
                sample_node.id in spectrum_input_targets
                for sample_node in sample_nodes
            )
        else:
            assert comp_node.function_name is not None
            op = _FUNCTION_NAME_TO_OP.get(comp_node.function_name)
            if op is None:
                raise ValueError(
                    f"Unknown component function: {comp_node.function_name!r}"
                )
            op_kinds_list.append(int(op))
            op_is_profiled_list.append(False)
            param_edges = sorted(
                param_edges_by_target.get(comp_node.id, []),
                key=lambda edge: edge.position or 0,
            )
            is_constant = True
            for edge in param_edges:
                src_node = id_to_node[edge.source]
                if src_node.name not in name_to_idx:
                    raise ValueError(
                        f"Non-scalar parameter source {src_node.name!r} in 1D op"
                    )
                src_idx = name_to_idx[src_node.name]
                op_param_source_kinds_list.append(int(ParamSourceKind.SCALAR))
                op_param_indices_list.append(src_idx)
                is_constant = is_constant and bool(idx_is_constant[src_idx])
            op_param_indptr_list.append(len(op_param_indices_list))
            has_spec_input = comp_node.id in spectrum_input_targets

        op_needs_spectrum_list.append(has_spec_input)
        op_is_pre_spectrum_list.append(comp_node.id in peak_sum_sources)
        # Spectrum-fed ops depend on the running peak sum, so they are never
        # folded into the constant cache.
        op_is_constant_list.append((not has_spec_input) and is_constant)

    n_ops = len(comp_nodes_topo)
    op_kinds = np.array(op_kinds_list, dtype=np.intp)
    op_param_indptr = np.array(op_param_indptr_list, dtype=np.intp)
    op_param_source_kinds = np.array(op_param_source_kinds_list, dtype=np.int8)
    op_param_indices = np.array(op_param_indices_list, dtype=np.intp)
    op_needs_spectrum = np.array(op_needs_spectrum_list, dtype=np.bool_)
    op_is_pre_spectrum = np.array(op_is_pre_spectrum_list, dtype=np.bool_)
    op_is_profiled = np.array(op_is_profiled_list, dtype=np.bool_)
    op_is_constant = np.array(op_is_constant_list, dtype=np.bool_)

    # ------------------------------------------------------------------ #
    # 7. Initialize scalar + profile values                               #
    # ------------------------------------------------------------------ #
    param_values_init = np.zeros(n_params, dtype=np.float64)
    for node in opt_nodes + static_nodes:
        param_values_init[name_to_idx[node.name]] = (
            node.value if node.value is not None else 0.0
        )
    # Expressions are stored in topological order, so dependencies resolve first.
    for i in range(n_expressions):
        target_idx = int(expr_target_indices[i])
        param_values_init[target_idx] = _eval_expr_scalar(
            expr_programs[i], param_values_init
        )

    profile_sample_values_init = _evaluate_profile_sample_values(
        plan_aux_axis,
        param_values_init,
        profile_sample_base_indices,
        profile_sample_component_indptr,
        profile_component_func_ids,
        profile_component_param_indptr,
        profile_component_param_indices,
    )
    profile_expr_values_init = _evaluate_profile_expr_values(
        param_values_init,
        profile_sample_values_init,
        n_params,
        profile_expr_programs,
    )

    # ------------------------------------------------------------------ #
    # 8. Precompute constant component contributions                      #
    # ------------------------------------------------------------------ #
    energy = graph.energy
    cached_result = np.zeros(len(energy), dtype=np.float64)
    cached_peak_sum = np.zeros_like(cached_result)
    for op_idx in range(n_ops):
        if not op_is_constant[op_idx]:
            continue
        start = int(op_param_indptr[op_idx])
        end = int(op_param_indptr[op_idx + 1])
        component = _evaluate_scheduled_op_1d(
            energy,
            int(op_kinds[op_idx]),
            op_param_source_kinds[start:end],
            op_param_indices[start:end],
            param_values_init,
            profile_sample_values_init,
            profile_expr_values_init,
            cached_peak_sum,
            needs_spectrum=bool(op_needs_spectrum[op_idx]),
            is_profiled=bool(op_is_profiled[op_idx]),
            n_aux=n_aux,
        )
        cached_result += component
        if op_is_pre_spectrum[op_idx]:
            cached_peak_sum += component

    # ------------------------------------------------------------------ #
    # 9. Pack into ScheduledPlan1D                                        #
    # ------------------------------------------------------------------ #
    return ScheduledPlan1D(
        energy=energy,
        n_params=n_params,
        param_values_init=param_values_init,
        opt_indices=opt_indices,
        opt_param_names=opt_param_names,
        n_expressions=n_expressions,
        expr_target_indices=expr_target_indices,
        expr_programs=expr_programs,
        n_aux=n_aux,
        aux_axis=plan_aux_axis,
        n_profile_samples=n_profile_samples,
        profile_sample_base_indices=profile_sample_base_indices,
        profile_sample_component_indptr=profile_sample_component_indptr,
        profile_component_func_ids=profile_component_func_ids,
        profile_component_param_indptr=profile_component_param_indptr,
        profile_component_param_indices=profile_component_param_indices,
        n_profile_exprs=n_profile_exprs,
        profile_expr_programs=profile_expr_programs,
        n_ops=n_ops,
        op_kinds=op_kinds,
        op_param_indptr=op_param_indptr,
        op_param_source_kinds=op_param_source_kinds,
        op_param_indices=op_param_indices,
        op_needs_spectrum=op_needs_spectrum,
        op_is_pre_spectrum=op_is_pre_spectrum,
        op_is_profiled=op_is_profiled,
        op_is_constant=op_is_constant,
        cached_result=cached_result,
        cached_peak_sum=cached_peak_sum,
    )