"""2D evaluator for the compiled backend.
All component functions live in ``trspecfit.functions.energy`` as the
single source of truth. Peak functions broadcast naturally with
``(n_time, 1)`` params and ``(1, n_energy)`` energy. Background
functions (Offset, LinBack, Shirley) accept optional or axis-agnostic
signatures that work for both 1D and 2D evaluation.
"""
from __future__ import annotations
import numpy as np
from trspecfit.functions import time as fcts_time
from trspecfit.graph_ir import (
OP_DISPATCH,
PROFILE_DISPATCH,
ConvKernelKind,
DynFuncKind,
ExprNodeKind,
ExprProgram,
ParamSourceKind,
ScheduledPlan2D,
)
from trspecfit.utils.arrays import my_conv
# ---------------------------------------------------------------------------
# Dynamics dispatch table
# ---------------------------------------------------------------------------
# Each entry maps a DynFuncKind to ``(time_function, n_params)``.
DYNAMICS_DISPATCH: dict[int, tuple] = {
    dyn_kind: (dyn_func, dyn_n_params)
    for dyn_kind, dyn_func, dyn_n_params in (
        (DynFuncKind.EXPFUN, fcts_time.expFun, 4),
        (DynFuncKind.SINFUN, fcts_time.sinFun, 5),
        (DynFuncKind.LINFUN, fcts_time.linFun, 3),
        (DynFuncKind.SINDIVX, fcts_time.sinDivX, 4),
        (DynFuncKind.ERFFUN, fcts_time.erfFun, 4),
        (DynFuncKind.SQRTFUN, fcts_time.sqrtFun, 3),
    )
}
# Convolution kernel dispatch: maps a ConvKernelKind to
# ``(kernel_function, n_params)``. The kernel function is evaluated on the
# frozen kernel-time support with per-theta kernel parameters. Mirrors
# MCP's ``Model.combine(...)`` path.
CONV_KERNEL_DISPATCH: dict[int, tuple] = {
    conv_kind: (kernel_func, kernel_n_params)
    for conv_kind, kernel_func, kernel_n_params in (
        (ConvKernelKind.GAUSSCONV, fcts_time.gaussCONV, 1),
        (ConvKernelKind.LORENTZCONV, fcts_time.lorentzCONV, 1),
        (ConvKernelKind.VOIGTCONV, fcts_time.voigtCONV, 2),
        (ConvKernelKind.EXPSYMCONV, fcts_time.expSymCONV, 1),
        (ConvKernelKind.EXPDECAYCONV, fcts_time.expDecayCONV, 1),
        (ConvKernelKind.EXPRISECONV, fcts_time.expRiseCONV, 1),
        (ConvKernelKind.BOXCONV, fcts_time.boxCONV, 1),
    )
}
# ---------------------------------------------------------------------------
# Shared RPN expression evaluator
# ---------------------------------------------------------------------------
#
# Binary RPN opcodes -> NumPy ufunc applied as ``op(lhs, rhs)``.
_EXPR_BINARY_OPS = {
    ExprNodeKind.ADD: np.add,
    ExprNodeKind.SUB: np.subtract,
    ExprNodeKind.MUL: np.multiply,
    ExprNodeKind.DIV: np.divide,
    ExprNodeKind.POW: np.power,
}


def eval_expr_program(
    program: ExprProgram,
    traces: np.ndarray,
) -> np.ndarray:
    """Evaluate an RPN ExprProgram against the trace matrix.

    Works for both plan initialization and hot-path evaluation.
    ``program.instructions`` is read as consecutive ``(opcode, operand)``
    pairs. Each PARAM_REF pushes a copy of a full ``(n_time,)`` row of
    *traces*; each CONST pushes its value broadcast to ``(n_time,)``.

    Parameters
    ----------
    program
        Compiled RPN instruction array.
    traces
        ``(n_params, n_time)`` trace matrix (current state).

    Returns
    -------
    ndarray
        ``(n_time,)`` result.

    Raises
    ------
    ValueError
        If the program contains an unsupported opcode, underflows the
        stack, or does not leave exactly one value on the stack.
    """
    n_time = traces.shape[1]
    instr = program.instructions
    stack: list[np.ndarray] = []
    # Pair opcodes with operands; a trailing unpaired entry is ignored.
    for pos, (code, operand) in enumerate(zip(instr[0::2], instr[1::2])):
        kind = ExprNodeKind(code)
        if kind == ExprNodeKind.CONST:
            # CONST operands carry the float64 bit pattern in an int64 slot.
            value = np.int64(operand).view(np.float64)
            stack.append(np.full(n_time, value, dtype=np.float64))
        elif kind == ExprNodeKind.PARAM_REF:
            # Copy so the result never aliases a row of the caller's matrix.
            stack.append(traces[int(operand), :].copy())
        elif kind == ExprNodeKind.NEG:
            if not stack:
                raise ValueError(
                    f"malformed ExprProgram: stack underflow at instruction {pos}"
                )
            stack.append(-stack.pop())
        elif kind in _EXPR_BINARY_OPS:
            if len(stack) < 2:
                raise ValueError(
                    f"malformed ExprProgram: stack underflow at instruction {pos}"
                )
            rhs = stack.pop()
            lhs = stack.pop()
            stack.append(_EXPR_BINARY_OPS[kind](lhs, rhs))
        else:
            # Silently skipping an opcode would leave the stack inconsistent.
            raise ValueError(
                f"unsupported ExprNodeKind {kind!r} at instruction {pos}"
            )
    if len(stack) != 1:
        raise ValueError(
            f"malformed ExprProgram: {len(stack)} values left on stack "
            "(expected 1)"
        )
    return stack[0]
# ---------------------------------------------------------------------------
# Profile evaluation helpers (2D)
# ---------------------------------------------------------------------------
#
def _evaluate_profile_sample_values_2d(
aux_axis: np.ndarray,
traces: np.ndarray,
profile_sample_base_rows: np.ndarray,
profile_sample_component_indptr: np.ndarray,
profile_component_func_ids: np.ndarray,
profile_component_param_indptr: np.ndarray,
profile_component_param_rows: np.ndarray,
) -> np.ndarray:
"""Evaluate lowered PROFILE_SAMPLE groups into ``(n_groups, n_time, n_aux)``.
Profile functions broadcast naturally: ``aux_axis`` is shaped
``(1, n_aux)`` and each param trace is ``(n_time, 1)``, yielding
``(n_time, n_aux)`` per function call.
"""
n_groups = len(profile_sample_base_rows)
n_time = traces.shape[1]
n_aux = len(aux_axis)
if n_groups == 0:
return np.zeros((0, n_time, n_aux), dtype=np.float64)
aux_2d = aux_axis[np.newaxis, :] # (1, n_aux)
sample_values = np.empty((n_groups, n_time, n_aux), dtype=np.float64)
for group_idx in range(n_groups):
base_row = int(profile_sample_base_rows[group_idx])
# base trace -> (n_time, 1) -> broadcast to (n_time, n_aux)
values = np.broadcast_to(
traces[base_row, :][:, np.newaxis], (n_time, n_aux)
).copy()
comp_start = int(profile_sample_component_indptr[group_idx])
comp_end = int(profile_sample_component_indptr[group_idx + 1])
for comp_idx in range(comp_start, comp_end):
func = PROFILE_DISPATCH[int(profile_component_func_ids[comp_idx])]
param_start = int(profile_component_param_indptr[comp_idx])
param_end = int(profile_component_param_indptr[comp_idx + 1])
params = [
traces[int(row), :][:, np.newaxis] # (n_time, 1)
for row in profile_component_param_rows[param_start:param_end]
]
values += np.asarray(func(aux_2d, *params), dtype=np.float64)
sample_values[group_idx] = values
return sample_values
#
def _evaluate_profile_expr_values_2d(
traces: np.ndarray,
profile_sample_values: np.ndarray,
n_params: int,
profile_expr_programs: list[ExprProgram],
) -> np.ndarray:
"""Evaluate lowered per-sample profile expressions over (n_time, n_aux).
Builds a virtual trace matrix ``(n_params + n_groups, n_time * n_aux)``
so the standard RPN evaluator can be reused unchanged.
"""
n_exprs = len(profile_expr_programs)
if n_exprs == 0:
n_time = traces.shape[1]
n_aux = profile_sample_values.shape[2] if profile_sample_values.size else 0
return np.zeros((0, n_time, n_aux), dtype=np.float64)
n_time = traces.shape[1]
n_aux = profile_sample_values.shape[2]
n_groups = profile_sample_values.shape[0]
n_cols = n_time * n_aux
# Virtual trace: regular params repeated across aux, profile samples
# flattened from (n_groups, n_time, n_aux) -> (n_groups, n_time*n_aux).
virtual = np.empty((n_params + n_groups, n_cols), dtype=np.float64)
virtual[:n_params, :] = np.repeat(traces, n_aux, axis=1)
if n_groups > 0:
virtual[n_params:, :] = profile_sample_values.reshape(n_groups, n_cols)
expr_values = np.empty((n_exprs, n_time, n_aux), dtype=np.float64)
for expr_idx, program in enumerate(profile_expr_programs):
result = eval_expr_program(program, virtual) # (n_time*n_aux,)
expr_values[expr_idx] = result.reshape(n_time, n_aux)
return expr_values
#
def _evaluate_profiled_op_2d(
    energy: np.ndarray,
    kind: int,
    param_source_kinds: np.ndarray,
    param_indices: np.ndarray,
    traces: np.ndarray,
    profile_sample_values: np.ndarray,
    profile_expr_values: np.ndarray,
    peak_sum: np.ndarray,
    *,
    needs_spectrum: bool,
    n_aux: int,
) -> np.ndarray:
    """Evaluate one profiled 2D op as the mean over the aux grid.

    For each aux point the op function from ``OP_DISPATCH[kind]`` is called
    with every parameter as an ``(n_time, 1)`` column. Each parameter is
    taken from one of three sources, selected by its source kind:
    - SCALAR: a row of *traces*.
    - PROFILE_SAMPLE: a slice of *profile_sample_values*.
    - Anything else: a slice of *profile_expr_values*.

    If *needs_spectrum* is true, *peak_sum* is passed as a trailing
    argument. The per-point outputs are summed and divided by *n_aux*.

    Returns
    -------
    ndarray
        ``(n_time, n_energy)`` averaged component.

    Raises
    ------
    ValueError
        If ``n_aux`` is not positive (the average would be 0/0 -> NaN), or
        if *param_source_kinds* and *param_indices* differ in length.
    KeyError
        If *kind* is not a key of ``OP_DISPATCH``.
    """
    if n_aux <= 0:
        raise ValueError(f"profiled op requires n_aux > 0, got {n_aux}")
    func, _needs = OP_DISPATCH[kind]
    scalar_kind = int(ParamSourceKind.SCALAR)
    sample_kind = int(ParamSourceKind.PROFILE_SAMPLE)
    sources = [
        (int(source_kind), int(source_idx))
        for source_kind, source_idx in zip(
            param_source_kinds, param_indices, strict=True
        )
    ]
    # SCALAR params do not depend on the aux point: gather them once.
    scalar_columns = {
        slot: traces[si, :][:, np.newaxis]  # (n_time, 1)
        for slot, (sk, si) in enumerate(sources)
        if sk == scalar_kind
    }
    extra = (peak_sum,) if needs_spectrum else ()

    n_time = traces.shape[1]
    n_energy = energy.shape[-1]
    accumulated = np.zeros((n_time, n_energy), dtype=np.float64)
    for aux_i in range(n_aux):
        params: list[np.ndarray] = []
        for slot, (sk, si) in enumerate(sources):
            if sk == scalar_kind:
                params.append(scalar_columns[slot])
            elif sk == sample_kind:
                params.append(profile_sample_values[si, :, aux_i][:, np.newaxis])
            else:
                params.append(profile_expr_values[si, :, aux_i][:, np.newaxis])
        accumulated += func(energy, *params, *extra)
    accumulated /= n_aux
    return accumulated
# ---------------------------------------------------------------------------
# Core 2D evaluator
# ---------------------------------------------------------------------------
#
# Step codes stored in ``plan.resolution_kinds``.
_RESOLVE_DYNAMICS = 0
_RESOLVE_EXPRESSION = 1
_RESOLVE_CONVOLUTION = 2


def _resolve_parameter_traces_2d(plan: ScheduledPlan2D, traces: np.ndarray) -> None:
    """Resolve dynamics groups, expressions and convolutions into *traces* in place.

    Steps run in the order of ``plan.resolution_kinds`` /
    ``plan.resolution_indices`` (interleaved topological order), so an
    expression-valued dynamics parameter is resolved before the group that
    consumes it.

    Raises
    ------
    ValueError
        If a resolution step has an unknown kind code.
    """
    for step, kind_code in enumerate(plan.resolution_kinds):
        kind = int(kind_code)
        idx = int(plan.resolution_indices[step])
        if kind == _RESOLVE_DYNAMICS:
            # A group sums its substeps (e.g. two expFun in a
            # bi-exponential) on top of the base: target = base + sum(subs).
            target = int(plan.dyn_group_target_row[idx])
            base = int(plan.dyn_group_base_row[idx])
            traces[target, :] = traces[base, :]
            sub_start = int(plan.dyn_group_indptr[idx])
            sub_end = int(plan.dyn_group_indptr[idx + 1])
            for s in range(sub_start, sub_end):
                func, _n_par = DYNAMICS_DISPATCH[int(plan.dyn_sub_func_id[s])]
                n_par = int(plan.dyn_sub_n_params[s])
                param_rows = plan.dyn_sub_param_rows[s, :n_par]
                # Dynamics parameters are read from time column 0, i.e. they
                # are assumed constant in time.
                dyn_params = [float(traces[int(row), 0]) for row in param_rows]
                traces[target, :] += (
                    func(plan.dyn_sub_time_axes[s], *dyn_params)
                    * plan.dyn_sub_masks[s]
                )
        elif kind == _RESOLVE_EXPRESSION:
            target = int(plan.expr_target_rows[idx])
            traces[target, :] = eval_expr_program(plan.expr_programs[idx], traces)
        elif kind == _RESOLVE_CONVOLUTION:
            # Convolve an already-resolved trace with a kernel evaluated on
            # its frozen support; kernel params are read from time column 0.
            target = int(plan.conv_target_rows[idx])
            kernel_func, _k_par = CONV_KERNEL_DISPATCH[int(plan.conv_func_ids[idx])]
            p_start = int(plan.conv_param_indptr[idx])
            p_end = int(plan.conv_param_indptr[idx + 1])
            kernel_params = [
                float(traces[int(plan.conv_param_rows[j]), 0])
                for j in range(p_start, p_end)
            ]
            s_start = int(plan.conv_support_indptr[idx])
            s_end = int(plan.conv_support_indptr[idx + 1])
            support = plan.conv_support_values[s_start:s_end]
            kernel = kernel_func(support, *kernel_params)
            traces[target, :] = my_conv(plan.time, traces[target, :], kernel)
        else:
            raise ValueError(f"unknown resolution kind {kind} at step {step}")


def _evaluate_ops_2d(
    plan: ScheduledPlan2D,
    traces: np.ndarray,
    profile_sample_values: np.ndarray,
    profile_expr_values: np.ndarray,
) -> np.ndarray:
    """Add every non-constant op onto a copy of the cached constant result.

    Ops run in index order. Components of ops flagged
    ``op_is_pre_spectrum`` are also added to a running ``peak_sum``, which
    is passed to subsequent ops flagged ``op_needs_spectrum``.
    """
    energy = plan.energy[np.newaxis, :]  # (1, n_energy)
    result = plan.cached_result.copy()
    peak_sum = plan.cached_peak_sum.copy()
    for op_idx in range(plan.n_ops):
        if plan.op_is_constant[op_idx]:
            continue  # already folded into cached_result / cached_peak_sum
        kind = int(plan.op_kinds[op_idx])
        start = int(plan.op_param_indptr[op_idx])
        end = int(plan.op_param_indptr[op_idx + 1])
        needs_spectrum = bool(plan.op_needs_spectrum[op_idx])
        is_pre = bool(plan.op_is_pre_spectrum[op_idx])
        if plan.op_is_profiled[op_idx]:
            component = _evaluate_profiled_op_2d(
                energy,
                kind,
                plan.op_param_source_kinds[start:end],
                plan.op_param_indices[start:end],
                traces,
                profile_sample_values,
                profile_expr_values,
                peak_sum,
                needs_spectrum=needs_spectrum,
                n_aux=plan.n_aux,
            )
        else:
            # Params as (n_time, 1) columns broadcast against (1, n_energy).
            params = [
                traces[int(row), :][:, np.newaxis]
                for row in plan.op_param_indices[start:end]
            ]
            func, _needs = OP_DISPATCH[kind]
            if needs_spectrum:
                component = func(energy, *params, peak_sum)
            else:
                component = func(energy, *params)
        result += component
        if is_pre:
            peak_sum += component
    return result


def evaluate_2d(plan: ScheduledPlan2D, theta: np.ndarray) -> np.ndarray:
    """Evaluate the compiled 2D model at optimizer parameters *theta*.

    Parameters
    ----------
    plan
        Immutable compiled execution schedule from ``schedule_2d``.
    theta
        ``(n_opt,)`` optimizer parameter vector (any 1-D array-like of
        reals). Order must match ``plan.opt_param_names``.

    Returns
    -------
    ndarray
        ``(n_time, n_energy)`` model spectrum.

    Raises
    ------
    ValueError
        If *theta* is not 1-D, if ``len(theta) != len(plan.opt_indices)``,
        or if the plan contains an unknown resolution step kind.
    """
    # --- theta contract check ---
    theta = np.asarray(theta, dtype=np.float64)
    if theta.ndim != 1:
        raise ValueError(f"theta must be 1-D, got shape {theta.shape}")
    if len(theta) != len(plan.opt_indices):
        raise ValueError(
            f"theta length {len(theta)} does not match "
            f"plan.opt_indices length {len(plan.opt_indices)}"
        )
    # 1a. Copy the trace matrix into scratch space (the plan is not mutated).
    traces = plan.param_traces_init.copy()
    # 1b. Broadcast each optimizer value across its whole time row.
    traces[plan.opt_indices, :] = theta[:, np.newaxis]
    # 1c+d. Dynamics groups, expressions and convolutions in topo order.
    _resolve_parameter_traces_2d(plan, traces)
    # 1e. Profile evaluation (after parameter resolution).
    profile_sample_values = _evaluate_profile_sample_values_2d(
        plan.aux_axis,
        traces,
        plan.profile_sample_base_rows,
        plan.profile_sample_component_indptr,
        plan.profile_component_func_ids,
        plan.profile_component_param_indptr,
        plan.profile_component_param_rows,
    )
    profile_expr_values = _evaluate_profile_expr_values_2d(
        traces,
        profile_sample_values,
        plan.n_params,
        plan.profile_expr_programs,
    )
    # 2. Component evaluation.
    return _evaluate_ops_2d(plan, traces, profile_sample_values, profile_expr_values)