"""
Temporal dynamics functions for time-resolved spectroscopy.

Function Conventions
--------------------
Use CamelCase naming (UpperCamelCase or lowerCamelCase) for function names.

**Dynamics Functions:**
Signature: func(t, par1, par2, ..., t0, y0)
- t: Time axis (numpy array)
- par1, par2, ...: Function-specific parameters
- t0: Time zero (function starts at this time)
- y0: Offset value (baseline)
- Returns: f(t) = 0 for t < t0, dynamics for t >= t0

**Convolution Kernels:**
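Signature: funcCONV(x, par1, par2, ...)
- x: Time axis centered at zero (from create_t_kernel)
- par1, par2, ...: Kernel parameters
- Returns: Kernel values; overall scaling is arbitrary because the kernel
  is normalized when the convolution is applied
- Must have a companion funcCONV_kernel_width(...) helper returning the
  support-width multiplier that create_t_kernel applies to the first
  kernel parameter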

**Time Zero Convention:**
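All dynamics functions are zero before t0 and activate at t >= t0.
Exception: erfFun is a Gaussian-broadened step centered at t0, so it is
not forced to zero before t0.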
This reflects physical causality: response begins after excitation.

**Offset Convention:**
Parameter y0 sets the asymptotic value or baseline.
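- Decays (e.g. expFun with A > 0): start at y0 + A and approach y0 as t → ∞
- Rises: erfFun goes from y0 (t → -∞) to y0 + A (t → ∞); expFun with
  A < 0 rises from y0 + A toward y0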
- Oscillations: oscillate around y0

**Time Resolution:**
Functions inherit time axis from Dynamics model. Consider:
- Time step size relative to dynamics (dt << tau)
- Time range coverage (include full decay/rise)
- Kernel width appropriate for convolution

Parameter Naming
----------------
Common parameter names:
- A: Amplitude (change in signal)
- tau: Time constant (decay/rise time, 1/e point)
- t0: Time zero (start of dynamics)
- y0: Offset/baseline value
- f: Frequency (for oscillations)
- phi: Phase (for oscillations)
- SD: Standard deviation (for Gaussian kernels)
- W: FWHM (for Lorentzian kernels)

Adding New Functions
--------------------
To add a new dynamics or convolution function:

1. Implement following conventions above
2. Ensure f(t<t0) = 0 for dynamics functions
3. Add kernel_width function for convolution kernels
4. Test with realistic time-resolved data
"""

import numpy as np
from scipy.special import erf, wofz


#
def none(t: np.ndarray) -> np.ndarray:
    """
    No-op dynamics used as a placeholder for an empty subcycle.

    Multi-cycle mcp.Dynamics models number their subcycles by position,
    so a subcycle that should contribute nothing still needs an entry.
    This function fills that slot without adding any time dependence.

    Usage (in model YAML file)::

        model_sub2:
          none: {}

    Parameters
    ----------
    t : ndarray
        Time axis; only its shape and dtype are used

    Returns
    -------
    ndarray
        Zeros with the same shape and dtype as t
    """

    # Not expected to run in practice: mcp.Model.combine() recognises this
    # placeholder and skips it. Returning zeros keeps a direct call harmless.
    return np.full_like(t, 0)


#
def linFun(t: np.ndarray, m: float, t0: float, y0: float) -> np.ndarray:
    """
    Straight-line dynamics switched on at t0.

    Parameters
    ----------
    t : ndarray
        Time axis
    m : float
        Slope in [signal units]/[time units]; m > 0 gives an increasing
        line, m < 0 a decreasing one
    t0 : float
        Time at which the line starts
    y0 : float
        Value of the line at t0 (initial value)

    Returns
    -------
    ndarray
        0 for t<t0, m*(t-t0)+y0 for t>=t0
    """

    before_start = t < t0
    line = m * (t - t0) + y0
    return np.where(before_start, 0.0, line)


#
def expFun(t: np.ndarray, A: float, tau: float, t0: float, y0: float) -> np.ndarray:
    """
    Exponential decay or rise dynamics.

    Parameters
    ----------
    t : ndarray
        Time axis
    A : float
        Amplitude of the exponential term at t0 (the value at t0 is y0+A).
        - A > 0: Decay from y0+A down to y0
        - A < 0: Rise from y0+A (= y0-|A|) up to y0
    tau : float
        Time constant (1/e time). Units: [time units]
        At t = t0 + tau the exponential term has fallen to A/e (e ≈ 2.718)
    t0 : float
        Time zero (start of exponential)
    y0 : float
        Asymptotic value (baseline as t → ∞)

    Returns
    -------
    ndarray
        Exponential: 0 for t<t0, A*exp(-(t-t0)/tau)+y0 for t>=t0
    """

    # np.where evaluates both branches, so clamp the elapsed time at zero:
    # for t << t0 the unclamped exp(-(t-t0)/tau) overflows (emitting a
    # RuntimeWarning) even though that value is discarded.
    dt = np.maximum(t - t0, 0.0)
    return np.where(t < t0, 0.0, A * np.exp(-dt / tau) + y0)


#
def sinFun(
    t: np.ndarray, A: float, f: float, phi: float, t0: float, y0: float
) -> np.ndarray:
    """
    Sinusoidal oscillation (coherent dynamics) switched on at t0.

    Parameters
    ----------
    t : ndarray
        Time axis
    A : float
        Amplitude; the signal swings between y0-A and y0+A (peak-to-peak 2A)
    f : float
        Frequency in [1/time units]; the period is 1/f
    phi : float
        Phase at t0 in radians, e.g. (for A > 0)
        - phi = 0: starts at y0, rising (sine)
        - phi = π/2: starts at the maximum y0+A (cosine)
        - phi = π: starts at y0, falling
    t0 : float
        Time at which the oscillation starts
    y0 : float
        Center line of the oscillation

    Returns
    -------
    ndarray
        0 for t<t0, A*sin(2πf(t-t0)+phi)+y0 for t>=t0
    """

    omega = 2 * np.pi * f
    phase = omega * (t - t0) + phi
    wave = A * np.sin(phase) + y0
    return np.where(t < t0, 0.0, wave)


#
def sinDivX(t: np.ndarray, A: float, f: float, t0: float, y0: float) -> np.ndarray:
    """
    Sinc-shaped oscillation, sin(x)/x, whose envelope decays as 1/x.

    Parameters
    ----------
    t : ndarray
        Time axis
    A : float
        Amplitude scaling factor (the value at t0 is A+y0)
    f : float
        Frequency in [1/time units]
    t0 : float
        Time at which the oscillation starts
    y0 : float
        Offset value

    Returns
    -------
    ndarray
        0 for t<t0, A*sin(2πf(t-t0))/(2πf(t-t0))+y0 for t>=t0
        (the removable singularity at t=t0 evaluates to A+y0)
    """

    # numpy's normalized sinc is np.sinc(u) = sin(pi*u)/(pi*u) with
    # np.sinc(0) = 1, so u = 2*f*(t-t0) yields sin(2*pi*f*dt)/(2*pi*f*dt).
    u = 2 * f * (t - t0)
    return np.where(t < t0, 0.0, A * np.sinc(u) + y0)


#
def erfFun(t: np.ndarray, A: float, SD: float, t0: float, y0: float) -> np.ndarray:
    """
    Error function rise (step with Gaussian broadening).
    erfFun ≈ step ⊗ Gaussian(SD)

    Unlike the other dynamics functions in this module, the result is not
    forced to zero for t < t0: the broadened step is centered on t0 and
    already departs from y0 before it.

    Parameters
    ----------
    t : ndarray
        Time axis
    A : float
        Amplitude (total change from initial to final value)
    SD : float
        Standard deviation of the Gaussian broadening.
        The Gaussian FWHM is ≈2.355*SD; the 10–90% rise time is ≈2.563*SD.
        Smaller SD → sharper rise
    t0 : float
        Center of rise (50% point, value y0 + A/2)
    y0 : float
        Initial value (asymptote as t → -∞); the final value (t → ∞)
        is y0 + A

    Returns
    -------
    ndarray
        Error function: A/2 * (1 + erf((t-t0)/(SD*√2))) + y0
    """

    return np.asarray(A / 2 * (1 + erf((t - t0) / (SD * np.sqrt(2)))) + y0)


#
def sqrtFun(t: np.ndarray, A: float, t0: float, y0: float) -> np.ndarray:
    """
    Square root rise (diffusion dynamics).

    Parameters
    ----------
    t : ndarray
        Time axis (a scalar is also accepted)
    A : float
        Amplitude scaling factor
    t0 : float
        Time zero (start of diffusion)
    y0 : float
        Offset value (value at t0)

    Returns
    -------
    ndarray
        Square root rise: 0 for t<t0, A*√(t-t0)+y0 for t>=t0
    """

    # Clamp the elapsed time so sqrt never sees a negative argument (which
    # would give NaN plus a RuntimeWarning in the discarded branch), then
    # zero the pre-t0 region explicitly so the offset y0 is not added there.
    dt = np.maximum(t - t0, 0.0)
    return np.where(t < t0, 0.0, A * np.sqrt(dt) + y0)


#
# convolution functions
# kernels followed by respective recommended kernel width
#


#
def gaussCONV(x: np.ndarray, SD: float) -> np.ndarray:
    """
    Gaussian kernel, typically used as the instrumental response function.

    Parameters
    ----------
    x : ndarray
        Kernel time axis centered at 0 (typically from
        Component.create_t_kernel)
    SD : float
        Standard deviation of the Gaussian;
        FWHM = 2*√(2ln2) * SD ≈ 2.355 * SD

    Returns
    -------
    ndarray
        exp(-x²/(2·SD²)), equal to 1 at x = 0 (unnormalized; it is
        normalized in the convolution step)
    """

    z = x / SD
    return np.exp(-0.5 * z**2)


#
def gaussCONV_kernel_width(SD: float | None = None) -> int:
    """
    Kernel width multiplier for Gaussian convolution.
    Kernel extends to ±4*SD from center.
    At 4*SD, Gaussian has decayed to exp(-8) ≈ 3×10⁻⁴ of peak value.
    """

    return 4


#
def lorentzCONV(x: np.ndarray, W: float) -> np.ndarray:
    """
    Lorentzian convolution kernel.

    Parameters
    ----------
    x : ndarray
        Kernel time axis centered at 0
    W : float
        Full width at half maximum (FWHM) of the Lorentzian

    Returns
    -------
    ndarray
        1 / (1 + (2x/W)²): equal to 1 at x = 0 and 0.5 at x = ±W/2
        (unnormalized)
    """

    u = 2 * x / W
    return 1 / (1 + u**2)


#
def lorentzCONV_kernel_width(W: float | None = None) -> int:
    """Kernel width multiplier for Lorentzian (10×W)."""

    return 10


#
def voigtCONV(x: np.ndarray, SD: float, W: float) -> np.ndarray:
    """
    Voigt kernel (Gaussian and Lorentzian combined).

    Computed from the Faddeeva function w(z) as Re[w(z)] with
    z = (x + i·W/2) / (SD·√2), i.e. W/2 enters as the Lorentzian half
    width at half maximum.

    Parameters
    ----------
    x : ndarray
        Kernel time axis centered at 0
    SD : float
        Gaussian standard deviation
    W : float
        Lorentzian FWHM

    Returns
    -------
    ndarray
        Voigt profile scaled so its largest sampled value is 1
    """

    gamma = W / 2  # Lorentzian half width at half maximum
    z = (x + 1j * gamma) / SD / np.sqrt(2)
    profile = np.real(wofz(z))
    peak = np.max(profile)
    return np.asarray(profile / peak)


#
def voigtCONV_kernel_width(SD: float = 1.0, W: float = 0.0) -> float:
    """Support-width multiplier for voigtCONV.

    create_t_kernel scales the first kernel parameter (here SD) by this
    value, but the heavy Lorentzian tails are set by W. The multiplier is
    therefore chosen so that multiplier*SD >= max(12*SD, 10*W). A
    non-positive SD falls back to the plain multiplier 12.
    """

    if SD > 0:
        return max(12.0, 10.0 * W / SD)
    return 12.0


#
def expSymCONV(x: np.ndarray, tau: float) -> np.ndarray:
    """
    Two-sided (symmetric) exponential kernel, ``exp(-|x|/tau)``.

    Decays identically on both sides of x = 0, where it equals 1.

    Parameters
    ----------
    x : ndarray
        Kernel time axis centered at 0
    tau : float
        Decay time constant

    Returns
    -------
    ndarray
        exp(-|x|/tau)
    """

    rate = -1 / tau
    distance = np.abs(x)
    return np.asarray(np.exp(rate * distance))


#
def expSymCONV_kernel_width(tau: float | None = None) -> int:
    """Kernel width multiplier for symmetric exponential (6×tau)."""

    return 6


#
def expDecayCONV(x: np.ndarray, tau: float) -> np.ndarray:
    """
    One-sided exponential decay kernel, non-zero only for x ≥ 0.

    Built from the symmetric kernel expSymCONV by zeroing its x < 0 half.

    Parameters
    ----------
    x : ndarray
        Kernel time axis centered at 0
    tau : float
        Decay time constant

    Returns
    -------
    ndarray
        0 for x<0, exp(-x/tau) for x≥0
    """

    two_sided = expSymCONV(x, tau)
    # Zero out the negative-x half of the symmetric kernel.
    return np.where(x < 0, 0.0, two_sided)


#
def expDecayCONV_kernel_width(tau: float | None = None) -> int:
    """Kernel width multiplier for decay exponential (6×tau)."""

    return 6


#
def expRiseCONV(x: np.ndarray, tau: float) -> np.ndarray:
    """
    One-sided exponential rise kernel, non-zero only for x ≤ 0.

    Built from the symmetric kernel expSymCONV by zeroing its x > 0 half,
    so it grows toward 1 at x = 0 and vanishes for positive x.

    Parameters
    ----------
    x : ndarray
        Kernel time axis centered at 0
    tau : float
        Rise time constant

    Returns
    -------
    ndarray
        exp(x/tau) for x≤0, 0 for x>0
    """

    two_sided = expSymCONV(x, tau)
    # Zero out the positive-x half of the symmetric kernel.
    return np.where(x > 0, 0.0, two_sided)


#
def expRiseCONV_kernel_width(tau: float | None = None) -> int:
    """Kernel width multiplier for rise exponential (6×tau)."""

    return 6


#
def boxCONV(x: np.ndarray, width: float) -> np.ndarray:
    """
    Box (rectangular) convolution kernel.

    Parameters
    ----------
    x : ndarray
        Kernel time axis centered at 0
    width : float
        Full width of the rectangular window

    Returns
    -------
    ndarray
        Hard-edged rectangular function: 1.0 where |x| <= width/2 (edges
        included), 0.0 elsewhere. No edge smoothing is applied.
    """

    return np.where(np.abs(x) <= width / 2, 1.0, 0.0)


#
def boxCONV_kernel_width(width: float | None = None) -> int:
    """Kernel width multiplier for box (1×width)."""

    return 1
