# Source code for trspecfit.eval_1d

"""1D evaluator for the compiled backend.

All component functions live in ``trspecfit.functions.energy`` as the
single source of truth.  For 1D evaluation, parameters are plain
scalars (no ``(n_time, 1)`` broadcasting needed).
"""

from __future__ import annotations

import numpy as np

from trspecfit.graph_ir import (
    ExprNodeKind,
    ExprProgram,
    ScheduledPlan1D,
    _evaluate_profile_expr_values,
    _evaluate_profile_sample_values,
    _evaluate_scheduled_op_1d,
)

# ---------------------------------------------------------------------------
# Scalar RPN expression evaluator
# ---------------------------------------------------------------------------


#
def eval_expr_program_1d(
    program: ExprProgram,
    values: np.ndarray,
) -> float:
    """Run a compiled RPN expression on a vector of scalar parameters.

    ``program.instructions`` is read as a flat sequence of
    ``(kind, operand)`` pairs.  Each pair either pushes a value onto an
    operand stack or combines the topmost stack entries.

    Parameters
    ----------
    program
        Compiled RPN instruction array.
    values
        ``(n_params,)`` scalar parameter vector.

    Returns
    -------
    float
        Scalar result.

    Notes
    -----
    Instruction kinds other than CONST, PARAM_REF, NEG, ADD, SUB, MUL,
    DIV and POW are skipped without touching the stack.  A trailing
    unpaired element in ``program.instructions`` is ignored.
    """
    code = program.instructions
    operands: list[float] = []
    push = operands.append
    pop = operands.pop

    binary_kinds = (
        ExprNodeKind.ADD,
        ExprNodeKind.SUB,
        ExprNodeKind.MUL,
        ExprNodeKind.DIV,
        ExprNodeKind.POW,
    )

    # Even slots hold the node kind, odd slots hold its operand.
    for raw_kind, arg in zip(code[0::2], code[1::2]):
        op = ExprNodeKind(raw_kind)

        if op == ExprNodeKind.CONST:
            # The operand is reinterpreted bit-for-bit as a float64.
            push(float(np.int64(arg).view(np.float64)))
            continue

        if op == ExprNodeKind.PARAM_REF:
            # The operand is an index into the parameter vector.
            push(float(values[int(arg)]))
            continue

        if op == ExprNodeKind.NEG:
            push(-pop())
            continue

        if op not in binary_kinds:
            # Not handled by this evaluator: leave the stack untouched.
            continue

        # Binary operators: the right-hand operand sits on top of the stack.
        rhs = pop()
        lhs = pop()
        if op == ExprNodeKind.ADD:
            push(lhs + rhs)
        elif op == ExprNodeKind.SUB:
            push(lhs - rhs)
        elif op == ExprNodeKind.MUL:
            push(lhs * rhs)
        elif op == ExprNodeKind.DIV:
            push(lhs / rhs)
        else:  # ExprNodeKind.POW
            push(lhs**rhs)

    # A well-formed program leaves exactly one value behind.
    assert len(operands) == 1
    return operands[0]
# ---------------------------------------------------------------------------
# Core 1D evaluator
# ---------------------------------------------------------------------------
def evaluate_1d(plan: ScheduledPlan1D, theta: np.ndarray) -> np.ndarray:
    """Compute the 1D model spectrum for optimizer parameters *theta*.

    The evaluation proceeds in two stages: first the full parameter
    vector is assembled (initial values, optimizer values, then dependent
    expressions and profile values), then every non-constant scheduled
    op is evaluated and added onto the cached result.

    Parameters
    ----------
    plan
        Immutable compiled execution schedule from ``schedule_1d``.
    theta
        ``(n_opt,)`` optimizer parameter vector.  Order must match
        ``plan.opt_param_names``.

    Returns
    -------
    ndarray
        ``(n_energy,)`` model spectrum.

    Raises
    ------
    ValueError
        If ``len(theta) != len(plan.opt_indices)``.
    """
    n_given = len(theta)
    n_expected = len(plan.opt_indices)
    if n_given != n_expected:
        raise ValueError(
            f"theta length {n_given} does not match "
            f"plan.opt_indices length {n_expected}"
        )

    # --- Stage 1: assemble the parameter vector ---------------------------
    # Work on a copy so the plan's initial values are never modified.
    scratch = plan.param_values_init.copy()
    scratch[plan.opt_indices] = theta

    # Expressions are evaluated in stored order (topological, per the plan),
    # each one reading the scratch vector updated by the previous ones.
    for expr_idx in range(plan.n_expressions):
        dest = int(plan.expr_target_indices[expr_idx])
        scratch[dest] = eval_expr_program_1d(plan.expr_programs[expr_idx], scratch)

    sample_values = _evaluate_profile_sample_values(
        plan.aux_axis,
        scratch,
        plan.profile_sample_base_indices,
        plan.profile_sample_component_indptr,
        plan.profile_component_func_ids,
        plan.profile_component_param_indptr,
        plan.profile_component_param_indices,
    )
    expr_values = _evaluate_profile_expr_values(
        scratch,
        sample_values,
        plan.n_params,
        plan.profile_expr_programs,
    )

    # --- Stage 2: evaluate the scheduled ops ------------------------------
    energy_axis = plan.energy
    # Start from copies of the cached arrays; ops flagged constant are
    # skipped below.
    spectrum = plan.cached_result.copy()
    pre_spectrum_sum = plan.cached_peak_sum.copy()

    active_ops = [idx for idx in range(plan.n_ops) if not plan.op_is_constant[idx]]
    for op_idx in active_ops:
        op_kind = int(plan.op_kinds[op_idx])
        # This op's parameters occupy the slice [lo, hi) of the flat arrays.
        lo = int(plan.op_param_indptr[op_idx])
        hi = int(plan.op_param_indptr[op_idx + 1])
        wants_spectrum = bool(plan.op_needs_spectrum[op_idx])
        feeds_pre_spectrum = bool(plan.op_is_pre_spectrum[op_idx])

        contribution = _evaluate_scheduled_op_1d(
            energy_axis,
            op_kind,
            plan.op_param_source_kinds[lo:hi],
            plan.op_param_indices[lo:hi],
            scratch,
            sample_values,
            expr_values,
            pre_spectrum_sum,
            needs_spectrum=wants_spectrum,
            is_profiled=bool(plan.op_is_profiled[op_idx]),
            n_aux=plan.n_aux,
        )

        spectrum += contribution
        # The running sum is updated only after the op has been evaluated,
        # so each op sees the sum of the pre-spectrum ops that preceded it.
        if feeds_pre_spectrum:
            pre_spectrum_sum += contribution

    return spectrum