"""
Round-trip tests for the fit archive: save → load → field equality.

For each (model family, fit type), run the fit, write the archive via
``Project.save_fits``, read it back via ``FitResults.load`` (and via
``Project.load_fits``), and verify every user-visible slot field is
reconstructed exactly. Also exercises the design invariant that
``observed - fit`` reproduces residuals for any fit_type without reading
``file.data``.

Coverage matrix
---------------
- F1  (basic):              baseline, spectrum, sbs
- F3  (basic + dynamics):   2d
- F6  (profile-only):       baseline, spectrum, sbs
- F8  (profile + dynamics): baseline, 2d

The matrix covers basic / profile / profile+dynamics × applicable fit
types — and aligns with
``tests/roundtrip/matrix.py`` for B/Sp/SbS on F1/F6 (the families that
support 1D fits). F6 has no top-level dynamics so its 2d slot would
have nothing extra to assert vs F3; F8 covers the full "profile + 2d
dynamics" payload.
"""

from __future__ import annotations

import matplotlib

matplotlib.use("Agg")

from typing import Any

import numpy as np
import pandas as pd
import pytest
from _utils import make_project, simulate_noisy
from roundtrip.families import FAMILIES

from trspecfit import FitResults
from trspecfit.utils.fit_io import SavedFitSlot


#
def _build_fit_file(family_id: str, *, spec_fun_str: str = "fit_model_gir"):
    """Return ``(truth_file, fit_file, family)`` for one model family.

    Follows the same recipe as roundtrip/test_focused.py: a truth file is
    built and simulated with noise, then a second project receives a fit
    file that carries the noisy data and the same model.

    The noise level is fixed at 0.01 — small enough for fits to converge,
    yet non-zero so chi2/AIC/BIC stay finite (noise-free data would give
    chi2=0 and a non-finite AIC).
    """

    fam = FAMILIES[family_id]

    # Truth side: build the model, then synthesize noisy data from it.
    truth = fam.build_truth(
        make_project(name="rt_truth", spec_fun_str=spec_fun_str), variant="default"
    )
    noisy = simulate_noisy(truth.model_active, noise_level=0.01)

    # Fit side: a fresh project that reuses the truth axes.
    target_project = make_project(name="rt_fit", spec_fun_str=spec_fun_str)
    build_args: dict[str, Any] = {
        "data": noisy,
        "energy": truth.energy,
        "time": truth.time,
        "variant": "default",
    }
    # Only families that declare ``needs_aux`` receive the auxiliary axis.
    if fam.needs_aux:
        build_args["aux"] = truth.aux_axis

    return truth, fam.build_fit(target_project, **build_args), fam


#
def _save_load_one(project, archive_path) -> tuple[SavedFitSlot, FitResults]:
    """Write the project's fits to ``archive_path`` and read them back.

    Returns ``(slot, results)``, where ``results`` is the loaded
    ``FitResults`` and ``slot`` is its first (and only) entry. The
    single-slot expectation is asserted here so callers can concentrate
    on comparing fields instead of counting slots.
    """

    project.save_fits(archive_path, show_output=0)
    results = FitResults.load(archive_path)

    n_loaded = len(results)
    assert n_loaded == 1, f"expected 1 loaded slot, got {n_loaded}"

    sole_slot = next(iter(results))
    return sole_slot, results


#
def _assert_exact(loaded_value: Any, original_value: Any, label: str) -> None:
    """Assert exact equality, treating NaN at matching positions as equal.

    ``np.testing.assert_array_equal`` accepts Python floats, numpy scalars
    of any width and arrays alike, so scalar metrics (baseline / spectrum /
    2d) and per-slice metric arrays (sbs) share one code path. ``label``
    names the offending field in the failure message.
    """

    np.testing.assert_array_equal(
        loaded_value, original_value, err_msg=f"{label} did not round-trip exactly"
    )


def _assert_slot_round_tripped(loaded: SavedFitSlot, original: SavedFitSlot) -> None:
    """Assert every persisted SavedFitSlot field round-trips exactly.

    Covers identity (fingerprint, hashes, selection), arrays, metrics,
    noise metadata, params, and provenance. Also verifies the design
    invariant that ``observed - fit`` reproduces residuals on the loaded
    slot alone (no ``file.data`` lookup), cross-checked against the stored
    ``chi2_raw``.

    Parameters
    ----------
    loaded
        Slot read back from the archive.
    original
        Reference slot the loaded one is compared against.

    Raises
    ------
    AssertionError
        On the first field that differs; the message names the field.
    """

    # --- identity ------------------------------------------------------
    for name in (
        "file_name",
        "model_name",
        "fit_type",
        "selection_json",
        "selection",
        "history_key",
        "observed_sha256",
        "file_fingerprint",
    ):
        assert getattr(loaded, name) == getattr(original, name), f"{name} mismatch"

    # --- arrays --------------------------------------------------------
    np.testing.assert_array_equal(loaded.observed, original.observed)
    np.testing.assert_array_equal(loaded.fit, original.fit)
    assert loaded.observed.shape == loaded.fit.shape
    assert loaded.observed.dtype == original.observed.dtype
    assert loaded.fit.dtype == original.fit.dtype

    # --- metrics -------------------------------------------------------
    assert set(loaded.metrics.keys()) == set(original.metrics.keys())
    # One exact comparison for both layouts (scalar per slot, or one value
    # per slice for sbs). NaN-valued calibrated metrics (when no σ was set
    # on the file at fit time) must come back as NaN; the check does not
    # depend on the concrete scalar type that holds the NaN.
    for key in ("chi2_raw", "chi2_red_raw", "chi2", "chi2_red", "r2", "aic", "bic"):
        _assert_exact(loaded.metrics[key], original.metrics[key], f"metric {key!r}")

    # --- noise metadata -----------------------------------------------
    for name in ("noise_type", "sigma_source", "sigma_type"):
        assert getattr(loaded, name) == getattr(original, name), f"{name} mismatch"
    for name in ("sigma_data", "sigma_eff"):
        _assert_exact(getattr(loaded, name), getattr(original, name), name)

    # --- params --------------------------------------------------------
    _assert_params_equal(loaded.params, original.params, fit_type=original.fit_type)

    # --- provenance ----------------------------------------------------
    for name in ("fit_alg", "yaml_filename", "timestamp"):
        assert getattr(loaded, name) == getattr(original, name), f"{name} mismatch"

    # --- residual reconstruction (design invariant) --------------------
    # chi2_raw is the lmfit-unweighted SSE diagnostic; chi2 is σ-calibrated
    # and NaN when no σ was set on the file, so we cross-check against the raw
    # column (always populated and grid-derived).
    residual = loaded.observed - loaded.fit
    assert residual.shape == loaded.observed.shape
    chi2_raw = loaded.metrics["chi2_raw"]
    if loaded.fit_type == "sbs":
        assert residual.ndim == 2
        for i, row in enumerate(residual):
            assert chi2_raw[i] == pytest.approx(float(np.sum(row**2))), (
                f"chi2_raw[{i}] does not match the SSE of observed - fit"
            )
    else:
        assert chi2_raw == pytest.approx(float(np.sum(residual**2))), (
            "chi2_raw does not match the SSE of observed - fit"
        )


#
def _assert_params_equal(
    loaded: pd.DataFrame, original: pd.DataFrame, *, fit_type: str
) -> None:
    """Compare params DataFrames column-wise.

    Handles two layouts:

    - **long-form** (baseline / spectrum / 2d): mixed-dtype columns
      including ``expr`` (str | None) and ``stderr`` (float | None).
      ``_restore_long_params_nones`` in the reader maps ``""`` → ``None``
      and ``NaN`` → ``None`` so the round-tripped frame matches the
      lmfit-original.
    - **wide-form** (sbs): all-float columns, one row per slice.

    Compared column-by-column rather than via ``assert_frame_equal``
    because the writer round-trips through structured arrays — minor
    dtype quirks (object vs string) on the ``expr`` column would
    otherwise fail an exact frame-equality check despite values matching.
    """

    assert list(loaded.columns) == list(original.columns)
    assert len(loaded) == len(original)
    for col in original.columns:
        orig_vals = original[col].to_list()
        load_vals = loaded[col].to_list()
        assert len(orig_vals) == len(load_vals)
        for o, ll in zip(orig_vals, load_vals, strict=True):
            if isinstance(o, float) and np.isnan(o):
                # Both should be NaN (or None ↔ None handled below).
                assert isinstance(ll, float) and np.isnan(ll), (
                    f"col {col!r}: orig=NaN, loaded={ll!r}"
                )
            elif o is None:
                assert ll is None, f"col {col!r}: orig=None, loaded={ll!r}"
            elif isinstance(o, float):
                assert ll == pytest.approx(o, rel=0, abs=0), (
                    f"col {col!r}: orig={o!r}, loaded={ll!r}"
                )
            else:
                assert ll == o, f"col {col!r}: orig={o!r}, loaded={ll!r}"


# ---------------------------------------------------------------------------
# baseline round-trip
# ---------------------------------------------------------------------------


#
@pytest.mark.parametrize("family_id", ["F1", "F6", "F8"])
def test_baseline_roundtrip(family_id: str, tmp_path) -> None:
    """Baseline fit survives save → load for F1, F6 and F8.

    Per the module's coverage matrix these are the basic, profile-only
    and profile + dynamics families.
    """

    _truth, fit_file, family = _build_fit_file(family_id)
    model = family.model_name("default")
    fit_file.fit_baseline(model_name=model, stages=1, try_ci=0)

    owner = fit_file.p
    reloaded, _results = _save_load_one(owner, tmp_path / "baseline.fit.h5")

    # The baseline fit is the only entry in the project's history here.
    in_memory = owner._fit_history[0]
    assert in_memory.fit_type == "baseline"
    _assert_slot_round_tripped(reloaded, in_memory)


# ---------------------------------------------------------------------------
# spectrum round-trip
# ---------------------------------------------------------------------------


#
@pytest.mark.parametrize("family_id", ["F1", "F6"])
def test_spectrum_roundtrip(family_id: str, tmp_path) -> None:
    """Single-spectrum fit (1D, one time index) survives save → load.

    F1 is the basic family; F6 sends profiles through ``fit_spectrum``.
    In the F6 case the profiles expand into per-spectrum lmfit params
    (one ``pExpDecay`` / ``pLinear`` parameter set per profiled base
    parameter), so the serialized params DataFrame has to keep those rows
    together with their min/max/expr metadata.
    """

    _truth, fit_file, family = _build_fit_file(family_id)
    spectrum_options = {
        "time_point": 10,
        "time_type": "ind",
        "stages": 1,
        "try_ci": 0,
        "show_plot": False,
    }
    fit_file.fit_spectrum(family.model_name("default"), **spectrum_options)

    owner = fit_file.p
    reloaded, _results = _save_load_one(owner, tmp_path / "spectrum.fit.h5")

    in_memory = owner._fit_history[0]
    assert in_memory.fit_type == "spectrum"
    # The requested time selection must be readable from the loaded slot.
    assert reloaded.selection["time_point"] == 10
    assert reloaded.selection["time_type"] == "ind"
    _assert_slot_round_tripped(reloaded, in_memory)


# ---------------------------------------------------------------------------
# slice-by-slice round-trip
# ---------------------------------------------------------------------------


#
@pytest.mark.slow
@pytest.mark.parametrize("family_id", ["F1", "F6"])
def test_sbs_roundtrip(family_id: str, tmp_path) -> None:
    """Slice-by-slice fit survives save → load for F1 (basic) and F6 (profile).

    This is the layout with per-slice metric arrays and wide-form params.
    """

    _truth, fit_file, family = _build_fit_file(
        family_id, spec_fun_str="fit_model_mcp"
    )
    sbs_options = {
        "stages": 1,
        "n_workers": 1,
        "seed_source": "model",
        "seed_adapt": None,
        "try_ci": 0,
    }
    fit_file.fit_slice_by_slice(family.model_name("default"), **sbs_options)

    owner = fit_file.p
    reloaded, _results = _save_load_one(owner, tmp_path / "sbs.fit.h5")

    in_memory = owner._fit_history[0]
    assert in_memory.fit_type == "sbs"
    # One metric value per entry of the time axis.
    expected_shape = (len(fit_file.time),)
    assert reloaded.metrics["chi2"].shape == expected_shape
    _assert_slot_round_tripped(reloaded, in_memory)


# ---------------------------------------------------------------------------
# 2D round-trip
# ---------------------------------------------------------------------------


#
def _fit_2d_with_dynamics(family_id: str):
    """Drive a 2D family through baseline → add_dynamics → fit_2d.

    Same sequence as the standard 2D workflow in test_focused.py: the
    baseline fit seeds the amplitudes, dynamics are then attached to the
    fit-side file, and finally the joint 2D fit runs.

    Returns ``(fit_file, family)``.
    """

    _truth, fit_file, family = _build_fit_file(family_id)
    model = family.model_name

    fit_file.fit_baseline(model_name=model("default"), stages=1, try_ci=0)

    attach_dynamics = family.add_dynamics
    assert attach_dynamics is not None  # type guard
    attach_dynamics(fit_file, "default")

    fit_file.fit_2d(model_name=model("default"), stages=1, try_ci=0)
    return fit_file, family


#
@pytest.mark.parametrize("family_id", ["F3", "F8"])
def test_2d_roundtrip(family_id: str, tmp_path) -> None:
    """2D fit survives save → load for F3 (basic+dynamics) and F8 (profile+dynamics).

    ``_fit_history`` contains the baseline slot as well as the 2d slot,
    so the archive is written with the ``fit_type`` filter; exactly one
    slot is then expected back, which keeps the comparison unambiguous.
    """

    fit_file, _family = _fit_2d_with_dynamics(family_id)
    owner = fit_file.p

    # History is [baseline, 2d]; persist only the 2d entry.
    archive = tmp_path / "2d.fit.h5"
    owner.save_fits(archive, fit_type="2d", show_output=0)

    results = FitResults.load(archive)
    assert len(results) == 1
    reloaded = next(iter(results))

    in_memory = next(slot for slot in owner._fit_history if slot.fit_type == "2d")
    assert reloaded.observed.ndim == 2
    _assert_slot_round_tripped(reloaded, in_memory)


# ---------------------------------------------------------------------------
# load entry-point parity
# ---------------------------------------------------------------------------


#
def test_project_load_fits_matches_fitresults_load(tmp_path) -> None:
    """Both load entry points must yield field-equal slots for one archive.

    ``Project.load_fits`` is documented as a thin delegate to
    ``FitResults.load``; comparing their outputs guards against the two
    paths drifting apart if either later gains an incidental
    transformation.
    """

    _truth, fit_file, family = _build_fit_file("F1")
    model = family.model_name("default")
    fit_file.fit_baseline(model_name=model, stages=1, try_ci=0)

    owner = fit_file.p
    archive = tmp_path / "parity.fit.h5"
    owner.save_fits(archive, show_output=0)

    from_class = list(FitResults.load(archive))
    from_project = list(owner.load_fits(archive, show_output=0))

    assert len(from_class) == len(from_project) == 1
    # The class-level load acts as the reference for the project-level load.
    _assert_slot_round_tripped(from_project[0], from_class[0])


# ---------------------------------------------------------------------------
# single-file, multi-fit-type archive round-trip
# ---------------------------------------------------------------------------


#
@pytest.mark.slow
def test_multi_slot_roundtrip(tmp_path) -> None:
    """Three slots from one file (baseline + spectrum + sbs) in one archive.

    Covers the writer's handling of a per-file slot list and the reader's
    flattening into ``FitResults``. Every slot is compared field by
    field; a matching slot count alone is not enough.
    """

    _truth, fit_file, family = _build_fit_file("F1", spec_fun_str="fit_model_mcp")
    model = family.model_name

    # Run the three fit types back to back on the same file.
    fit_file.fit_baseline(model_name=model("default"), stages=1, try_ci=0)
    fit_file.fit_spectrum(
        model("default"),
        time_point=10,
        time_type="ind",
        stages=1,
        try_ci=0,
        show_plot=False,
    )
    fit_file.fit_slice_by_slice(
        model("default"),
        stages=1,
        n_workers=1,
        seed_source="model",
        seed_adapt=None,
        try_ci=0,
    )

    owner = fit_file.p
    assert len(owner._fit_history) == 3

    archive = tmp_path / "multi.fit.h5"
    owner.save_fits(archive, show_output=0)

    results = FitResults.load(archive)
    assert len(results) == 3

    # Pair slots through history_key so the check ignores load order.
    reloaded_by_key = {slot.history_key: slot for slot in results}
    for in_memory in owner._fit_history:
        key = in_memory.history_key
        assert key in reloaded_by_key
        _assert_slot_round_tripped(reloaded_by_key[key], in_memory)
