"""Imposed gas temperature profile (PFRGasTemperatureProfile).

This example demonstrates the **Dirichlet-on-gas** reactor
:class:`~bloc.reactors.PFRGasTemperatureProfile`: instead of prescribing the
*wall* temperature and letting the gas lag behind it through a convective
film (as :class:`~bloc.reactors.PFRWallProfile` / ``TubeFurnace`` does), the
**gas** temperature itself is imposed along the tube as a known profile
``T_gas(x)``.  The energy equation is switched off, so only the kinetics
evolve under that prescribed thermal history.

When to use it
--------------
Use this model when you have a *measured* in-stream gas temperature (e.g. a
thermocouple trace down the axis of a tube furnace) and want the model to
predict only the species profiles, taking that temperature as given rather
than computing it.  Decoupling chemistry from the energy balance isolates
mechanism behaviour under a known thermal history, which is the standard
setup for kinetics validation against measured profiles.

What the script does
--------------------
A dilute C2H4-in-N2 feed is pushed through a 64 cm tube.  The gas temperature
follows a measured axial profile (the same tabular ``(x, T)`` data used in the
``ethylene_pyrolysis`` validation case in ``omnisoot-cases``).  The profile is
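loaded into NumPy arrays and passed to the network as a linearly interpolated
``T_gas(x)`` callable.  The script prints a short summary (residence time,
outlet temperature, C2H4 conversion and outlet H2 mole fraction).  It then
plots the imposed temperature against the measured points, together with the
resulting C2H4 / H2 / C2H2 mole-fraction profiles (plus C6H6 when the
mechanism contains it).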

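Validation cases typically keep the profile as a ``(x_m, T_K)`` array pair,
either defined inline or read from a CSV such as
``data/ethylene_temperature_profile.csv`` (resolved relative to this script,
i.e. ``examples/data/``).  Both options are shown in ``main()``.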

Run from the repository root::

    conda run -n bloc python examples/imposed_gas_temperature_profile.py
"""

from __future__ import annotations

from collections.abc import Callable, Sequence
from pathlib import Path

import cantera as ct
import matplotlib.pyplot as plt
import numpy as np

from bloc.reactor_models import PFRGasTemperatureProfileNet
from bloc.reactors import PFRGasTemperatureProfile
from bloc.utils import get_mechanism_path

# ---------------------------------------------------------------------------
# Geometry, feed and the measured gas-temperature profile
# ---------------------------------------------------------------------------
# Example data shipped alongside this script (``<script dir>/data``).
_DATA_DIR = Path(__file__).parent.joinpath("data")

# Tube geometry.
DIAMETER = 0.05  # [m] tube diameter
TOTAL_LENGTH = 0.64  # [m] reactor length in the ethylene_pyrolysis case

# Feed: dilute ethylene in nitrogen, given as a Cantera mole-fraction string.
# NOTE(review): ~0.4 kg/s through a 5 cm bore of ~1 atm N2 near 460 K implies
# an inlet velocity of order 10^2 m/s -- confirm the units against the
# omnisoot ethylene_pyrolysis case.
MASS_FLOW_RATE = 0.39915  # [kg/s]
COMPOSITION = "C2H4:0.01, N2:0.99"
P_INLET = 101_350.0  # [Pa] inlet pressure


def load_temperature_profile_csv(path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load a measured ``(x, T)`` profile from a two-column CSV.

    The file must have one header row followed by comma-separated rows whose
    first column is the axial position and whose second column is the gas
    temperature [K]; any further columns are ignored.

    The ethylene-pyrolysis reference file labels the position column ``z[cm]``
    but the values are axial positions in **metres** (0–0.64 m), matching
    ``reactor_length`` in the omnisoot validation case.

    Parameters
    ----------
    path
        Path to the CSV file.

    Returns
    -------
    x_m, T_K
        1-D float arrays of axial positions [m] and temperatures [K].

    Raises
    ------
    ValueError
        If the file has no data rows or fewer than two columns.
    """
    # ndmin=2 keeps a file with a single data row as shape (1, ncols); without
    # it loadtxt collapses that case to 1-D and the column slicing below fails.
    data = np.loadtxt(path, delimiter=",", skiprows=1, ndmin=2)
    if data.shape[0] == 0 or data.shape[1] < 2:
        raise ValueError(
            f"{path}: expected a header row followed by at least one row of "
            "'x, T' values."
        )
    x_m = data[:, 0]
    T_K = data[:, 1]
    return x_m, T_K


def make_T_gas_fn(
    x_m: Sequence[float], T_K: Sequence[float]
) -> Callable[[float], float]:
    """Build a piecewise-linear ``T_gas(x)`` from tabular breakpoints [m], [K].

    Parameters
    ----------
    x_m
        Axial breakpoint positions [m]; must be non-decreasing.
    T_K
        Gas temperature [K] at each breakpoint.

    Returns
    -------
    Callable[[float], float]
        ``T_gas(x)``: linear interpolation between breakpoints, held at the
        first/last temperature for ``x`` outside ``[x_m[0], x_m[-1]]``.

    Raises
    ------
    ValueError
        If the inputs are not 1-D and of equal length, have fewer than two
        points, contain non-finite values, or ``x_m`` decreases anywhere.
    """
    # np.array (rather than np.asarray) always copies, so later mutation of the
    # caller's arrays cannot silently change the profile captured below.
    x_arr = np.array(x_m, dtype=float)
    T_arr = np.array(T_K, dtype=float)
    if x_arr.ndim != 1 or T_arr.ndim != 1 or x_arr.size != T_arr.size:
        raise ValueError("x_m and T_K must be 1-D arrays of equal length.")
    if x_arr.size < 2:
        raise ValueError("At least two profile points are required.")
    if not (np.all(np.isfinite(x_arr)) and np.all(np.isfinite(T_arr))):
        raise ValueError("x_m and T_K must contain only finite values.")
    # np.interp does not check that xp is increasing and returns meaningless
    # values when it is not, so reject unsorted positions up front.
    if np.any(np.diff(x_arr) < 0.0):
        raise ValueError("x_m must be sorted in non-decreasing order.")

    T_first = T_arr[0]
    T_last = T_arr[-1]

    def T_gas(x: float) -> float:
        return float(np.interp(x, x_arr, T_arr, left=T_first, right=T_last))

    return T_gas


def build_network(
    x_m: np.ndarray, T_K: np.ndarray, *, total_length: float = TOTAL_LENGTH
) -> PFRGasTemperatureProfileNet:
    """Assemble a one-reactor network whose gas temperature is imposed.

    The breakpoints are turned into an interpolated ``T_gas(x)`` callable.  A
    GRI-3.0 gas is set to that profile's temperature at ``x = 0`` together
    with ``P_INLET`` and ``COMPOSITION``.  A single
    ``PFRGasTemperatureProfile`` reactor is built from it and wrapped in a
    ``PFRGasTemperatureProfileNet``.

    Parameters
    ----------
    x_m, T_K
        Profile breakpoints: axial positions [m] and gas temperatures [K].
    total_length
        Reactor length [m]; sets the reactor volume and is passed to the net.

    Returns
    -------
    PFRGasTemperatureProfileNet
        The network, ready to be marched.
    """
    profile = make_T_gas_fn(x_m, T_K)
    mechanism = get_mechanism_path("gri30.yaml")
    gas = ct.Solution(mechanism)
    inlet_T = profile(0.0)
    gas.TPX = inlet_T, P_INLET, COMPOSITION

    # NOTE(review): clone=False presumably hands `gas` itself to the reactor
    # rather than a copy -- confirm against bloc.
    reactor = PFRGasTemperatureProfile(
        gas,
        clone=False,
        diameter=DIAMETER,
        mass_flow_rate=MASS_FLOW_RATE,
    )
    # Volume of a circular tube: (pi/4) * D^2 * L.
    cross_section = np.pi / 4.0 * DIAMETER**2
    reactor.volume = cross_section * total_length

    network = PFRGasTemperatureProfileNet(
        [reactor],
        T_gas_fn=profile,
        total_length=total_length,
    )
    return network


def _print_summary(net: PFRGasTemperatureProfileNet) -> None:
    """Print residence time, outlet temperature, C2H4 conversion and outlet H2."""
    print(f"Residence time     : {net.time:.3f} s")
    print(f"Outlet temperature : {net.reactor.phase.T:.1f} K (imposed)")
    i_c2h4 = net.species_names.index("C2H4")
    i_h2 = net.species_names.index("H2")
    # Conversion from the outlet/inlet mole-fraction ratio.  NOTE(review): this
    # ignores the change in total moles; for a dilute feed the difference is
    # small, but a molar-flow basis would be exact.
    conversion = 1 - net.X_arr[-1, i_c2h4] / net.X_arr[0, i_c2h4]
    print(f"C2H4 conversion    : {conversion:.1%}")
    print(f"Outlet X_H2        : {net.X_arr[-1, i_h2]:.3f}")


def _plot_profiles(
    net: PFRGasTemperatureProfileNet, x_m: np.ndarray, T_K: np.ndarray
) -> None:
    """Draw imposed T_gas(x) vs. the measured points, and species mole fractions."""
    fig, (ax_T, ax_X) = plt.subplots(2, 1, figsize=(8, 7), sharex=True)
    x_cm = net.x_arr * 100.0

    # Top panel: the imposed profile as marched, overlaid on the raw data.
    ax_T.plot(x_cm, net.T_gas_arr - 273.15, color="firebrick", label="imposed")
    ax_T.scatter(
        x_m * 100.0, T_K - 273.15, s=12, color="black", alpha=0.5, label="data"
    )
    ax_T.set_ylabel("Imposed gas T [\u00b0C]")
    ax_T.set_title("Measured gas temperature profile")
    ax_T.legend(fontsize=8)

    # Bottom panel: predicted species.  Species missing from the mechanism
    # (e.g. C6H6 is not checked to exist) are skipped.
    for sp in ("C2H4", "H2", "C2H2", "C6H6"):
        if sp in net.species_names:
            i = net.species_names.index(sp)
            ax_X.plot(x_cm, net.X_arr[:, i] * 100.0, label=sp, linewidth=1.5)
    ax_X.set_xlabel("Position [cm]")
    ax_X.set_ylabel("Mole fraction [%]")
    ax_X.set_title("Predicted species under the imposed temperature")
    ax_X.legend(ncol=2, fontsize=8)
    ax_X.set_xlim(0, TOTAL_LENGTH * 100)
    fig.tight_layout()


def main() -> None:
    """Run the march, print a short summary and plot T_gas(x) + species profiles."""
    # Tabular profile: validation cases load measured data into (x_m, T_K) arrays.
    x_m, T_K = load_temperature_profile_csv(
        _DATA_DIR / "ethylene_temperature_profile.csv"
    )
    # Alternatively, define the breakpoints directly:
    # x_m = np.array([0.0, 0.32, 0.64])
    # T_K = np.array([463.4, 1685.0, 682.1])

    net = build_network(x_m, T_K)
    net.advance_to_steady_state()

    _print_summary(net)
    _plot_profiles(net, x_m, T_K)
    plt.show()


# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()
