# Copyright 2026 Philippe Billet assisted by LLMs in free mode: chatGPT, Qwen, Deepseek, Gemini, Claude, le chat Mistral.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
caustics.py — Catastrophe classification, ray caustic detection, and uniform asymptotic corrections
===========================================================================================================
Overview
--------
The `caustics` module provides a comprehensive toolkit for working with caustics in semiclassical analysis and catastrophe theory. It consolidates previously scattered functionality into a single, well‑tested module with a clean API. The main components are:
1. **Arnold classification** (algebraic) – Symbolic classification of critical points of a function `H(ξ)` (1D) or `H(ξ,η)` (2D) into elementary catastrophes: Morse, fold (A2), cusp (A3), swallowtail (A4), butterfly (A5), and umbilics (D4±). Uses successive derivatives and Hessian analysis.
2. **Catastrophe detection** (adaptive numerical solver) – Robustly finds all critical points of `H` by combining symbolic solving, coarse grid scanning, Newton refinement, and clustering. Then classifies them using the algebraic routines.
3. **Ray‑based caustic detection** – Operates on a bundle of rays (output of `wkb.py`) that contain the integrated **stability matrix** `J = ∂x(t)/∂q`. Caustics are detected when `det(J)` crosses zero. This corrects the common mistake of using velocity zeros or momentum minima as caustic indicators. Maslov indices and phase shifts are computed automatically.
4. **Uniform special functions** – Implements the Airy function (fold), the Pearcey integral (cusp), and a swallowtail integral (A4), together with the correct prefactors for uniform WKB approximations near caustics.
5. **Visualisation helpers** – Functions to plot catastrophe points on 1D curves or 2D surfaces, and to overlay caustic events on ray bundles.
The module is designed to be **independent** of the other packages (it does not import `wkb.py`, `psiop.py`, etc.), but it is used by `wkb.py` for caustic detection and correction.
Mathematical background
-----------------------
**Caustics** are envelopes of families of rays (bicharacteristics) where the ray density becomes infinite. In the semiclassical approximation, the amplitude diverges at a caustic, and the standard WKB ansatz breaks down. Uniform approximations replace the oscillatory exponential by special functions.
For a Hamiltonian system with `n` degrees of freedom, a family of rays is parameterised by an initial‑condition parameter `q ∈ ℝⁿ`. The **stability matrix**
J(t) = ∂x(t)/∂q
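(together with `K(t) = ∂ξ(t)/∂q`) satisfies the linearised Hamilton equations `dJ/dt = H_{ξx}·J + H_{ξξ}·K`, `dK/dt = −H_{xx}·J − H_{xξ}·K` along each ray. The `H_{ξξ}·K` term cannot be dropped: `dJ/dt = H_{ξx}·J` alone gives `det J = exp(∫ tr H_{ξx} dt) > 0`, which never vanishes. A caustic occurs when `det(J(t)) = 0`. The **Maslov index** `μ` counts such zero crossings, each weighted by the orientation of the crossing; every crossing counts `+1` when `H_{ξξ}` is positive definite, e.g. `H = |ξ|²/2 + V(x)`. Each unit of `μ` contributes a phase shift of `π/2` to the wave function.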
**Arnold’s classification of catastrophes** (singularities of gradient maps) provides a taxonomy of caustics:
* **A₂ (fold)**: normal form `ξ³` – the simplest caustic, described by the Airy function.
* **A₃ (cusp)**: normal form `ξ⁴` – described by the Pearcey integral.
* **A₄ (swallowtail)**: normal form `ξ⁵` – a three‑parameter integral.
* **D₄± (umbilics)**: normal forms `ξ³ + η³` (hyperbolic) and `ξ³ – 3ξη²` (elliptic) – degenerate cases with Hessian rank 0.
The module classifies a critical point (where `∇H = 0`) by examining its Hessian rank and higher derivatives, using algebraic invariants.
References
----------
.. [1] Arnold, V. I. *Catastrophe Theory*, Springer‑Verlag, 1986.
.. [2] Duistermaat, J. J. “Oscillatory integrals, Lagrange immersions and unfolding of singularities”, *Comm. Pure Appl. Math.* **27**, 207–281, 1974.
.. [3] Maslov, V. P. & Fedoriuk, M. V. *Semi‑Classical Approximation in Quantum Mechanics*, Reidel, 1981.
.. [4] Kravtsov, Yu. A. & Orlov, Yu. I. *Caustics, Catastrophes and Wave Fields*, Springer, 1999.
.. [5] Berry, M. V. & Howls, C. J. “High orders of the Weyl expansion for quantum billiards”, *Phys. Rev. E* **50**(5), 3577–3595, 1994.
.. [6] Connor, J. N. L. “Practical methods for the uniform asymptotic evaluation of oscillatory integrals”, *Mol. Phys.* **31**(1), 33–55, 1976.
"""
from __future__ import annotations
import itertools
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import sympy as sp
try:
from scipy.integrate import solve_ivp
from scipy.optimize import fsolve
from scipy.special import airy
_HAS_SCIPY = True
except ImportError:
_HAS_SCIPY = False
warnings.warn("scipy not found — numeric fallbacks disabled.", ImportWarning)
try:
import matplotlib.pyplot as _plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401
_HAS_MPL = True
except ImportError:
_HAS_MPL = False
# ══════════════════════════════════════════════════════════════════════════════
# SECTION 1 — Arnold classification (algebraic)
# ══════════════════════════════════════════════════════════════════════════════
[docs]
def classify_arnold_1d(f: sp.Expr,
                       xi: sp.Symbol,
                       point: float,
                       max_order: int = 8,
                       tol: float = 1e-8) -> dict:
    """
    Classify the one-variable singularity of ``f(xi)`` at ``xi = point``.

    The derivatives f^(k)(point), k = 1..max_order, are evaluated
    numerically.  If f'(point) is not (numerically) zero the point is
    regular.  Otherwise the order k of the first non-vanishing derivative
    fixes the Arnold type A_{k-1} (k = 2 → Morse, 3 → fold, 4 → cusp, ...).

    Parameters
    ----------
    f : sympy.Expr
        Function of ``xi``.
    xi : sympy.Symbol
        Variable of ``f``.
    point : float
        Candidate critical point (expected to satisfy f'(point) ≈ 0).
    max_order : int
        Highest derivative order examined (default 8, i.e. up to A7).
    tol : float
        A derivative counts as non-zero when its absolute value exceeds
        ``tol``.

    Returns
    -------
    dict
        ``'type'``        – label such as ``"A2 (Fold)"``;
        ``'order'``       – k of the first non-zero derivative (1 for a
                            regular point, None if none was found);
        ``'derivatives'`` – {k: f^(k)(point)} for k = 1..max_order;
        ``'point'``       – the input ``point``.
    """
    derivs = {
        order: float(sp.N(sp.diff(f, (xi, order)).subs({xi: point})))
        for order in range(1, max_order + 1)
    }

    def _report(label, order):
        return {"type": label, "order": order,
                "derivatives": derivs, "point": point}

    # A non-zero first derivative means this is not a critical point at all.
    if abs(derivs.get(1, 0.0)) > tol:
        return _report("regular (not critical)", 1)

    names = {
        2: "A1 (Morse/non-degenerate)",
        3: "A2 (Fold)",
        4: "A3 (Cusp)",
        5: "A4 (Swallowtail)",
        6: "A5 (Butterfly)",
        7: "A6",
        8: "A7",
    }
    leading = next(
        (order for order in range(2, max_order + 1)
         if abs(derivs[order]) > tol),
        None,
    )
    if leading is None:
        return _report(
            f"flat (all derivatives vanish up to order {max_order})", None)
    return _report(names.get(leading, f"A{leading - 1} (higher)"), leading)
[docs]
def _cubic_form_discriminant(a: float, b: float, c: float, d: float) -> float:
    """
    Discriminant of the binary cubic  a·v³ + 3b·v²w + 3c·vw² + d·w³.

    With a = H_xxx, b = H_xxy, c = H_xyy, d = H_yyy this is six times the
    cubic Taylor term of H at a rank-0 critical point.  The standard
    binary-cubic discriminant

        Δ = 18ABCD − 4B³D + B²C² − 4AC³ − 27A²D²

    is applied to the actual coefficients A = a, B = 3b, C = 3c, D = d.
    Feeding b and c in without the binomial factor 3 gives wrong signs,
    e.g. for (v−w)(v−2w)(v−3w).

    Δ > 0 → three distinct real root lines, Δ < 0 → one real root line,
    Δ = 0 → a repeated root line.
    """
    A, B, C, D = a, 3.0 * b, 3.0 * c, d
    return (18.0 * A * B * C * D - 4.0 * B**3 * D + B**2 * C**2
            - 4.0 * A * C**3 - 27.0 * A**2 * D**2)


def _reduced_null_derivatives(P: np.ndarray) -> np.ndarray:
    """
    Derivatives of H restricted to its critical curve along the null direction.

    Implements the splitting lemma numerically.  Let G(t, s) = H(p + t·v + s·u),
    where v spans the Hessian kernel and u is the orthogonal non-degenerate
    direction.  Solve ∂G/∂s(t, s(t)) = 0 for a power series s(t) with
    s(0) = 0, and return the derivatives of g(t) = G(t, s(t)) at t = 0.
    The first non-vanishing g^(k)(0), k ≥ 3, gives the A_{k-1} type.  Raw
    directional derivatives D_v^k H are only correct for k = 3; from k = 4
    on, the coupling with the u direction contributes.

    Parameters
    ----------
    P : (N+1, N+1) array, N >= 2
        P[j, k] = (D_v^j D_u^k H)(p) / (j! k!), the Taylor coefficient of
        t^j s^k in G.  Only entries with j + k <= N are read.

    Returns
    -------
    ndarray of length N+1 : g^(m)(0) for m = 0..N.

    Raises
    ------
    ValueError
        If ∂²G/∂s²(0) = 2·P[0, 2] is zero (u is not non-degenerate).
    """
    from math import factorial

    P = np.asarray(P, dtype=float)
    N = P.shape[0] - 1
    lam = 2.0 * P[0, 2]                     # ∂²G/∂s² at the critical point
    if lam == 0.0:
        raise ValueError("direction u is degenerate: P[0, 2] must be non-zero")

    def mul(p, q):                          # polynomial product, truncated to degree N
        return np.convolve(p, q)[:N + 1]

    def shift(p, j):                        # multiply by t**j, truncated to degree N
        out = np.zeros(N + 1)
        out[j:] = p[:N + 1 - j]
        return out

    def powers(s):                          # [1, s, s², ..., s^N]
        one = np.zeros(N + 1)
        one[0] = 1.0
        pw = [one]
        for _ in range(N):
            pw.append(mul(pw[-1], s))
        return pw

    # Fixed-point iteration s ← −(∂G/∂s − lam·s)/lam; each pass fixes at
    # least one more Taylor coefficient of s(t).
    s = np.zeros(N + 1)
    for _ in range(N + 1):
        s_pow = powers(s)
        resid = np.zeros(N + 1)
        for j in range(N + 1):
            for k in range(1, N - j + 1):
                if (j, k) == (0, 2):
                    continue
                resid += k * P[j, k] * shift(s_pow[k - 1], j)
        s = -resid / lam
        # p is a critical point: drop the (numerically tiny) gradient
        # term so that s(0) = 0 exactly.
        s[0] = 0.0

    s_pow = powers(s)
    g = np.zeros(N + 1)
    for j in range(N + 1):
        for k in range(N - j + 1):
            g += P[j, k] * shift(s_pow[k], j)
    return np.array([factorial(m) * g[m] for m in range(N + 1)])


def classify_arnold_2d(H: sp.Expr,
                       xi: sp.Symbol,
                       eta: sp.Symbol,
                       point: dict,
                       tol: float = 1e-8) -> dict:
    """
    Classify a 2D catastrophe at a critical point of H(xi, eta).

    Follows Arnold's classification by Hessian rank:
      - rank 2 → Morse (non-degenerate)
      - rank 1 → A_k series: fold (A2), cusp (A3), swallowtail (A4),
        butterfly (A5), A6+.  Decided from the reduced function along the
        Hessian null direction (splitting lemma), not from raw directional
        derivatives.
      - rank 0 → D4 umbilics, from the discriminant Δ of the cubic Taylor
        form:
          * Δ < 0 → D4+ (hyperbolic umbilic)
          * Δ > 0 → D4- (elliptic umbilic)
          * Δ ≈ 0 → higher degeneracy (E6 or beyond)

    Parameters
    ----------
    H : sympy.Expr
        Function of (xi, eta).
    xi, eta : sympy.Symbol
    point : dict
        Numeric values, keyed either by the symbols ``{xi: ..., eta: ...}``
        or by the strings ``{'xi': ..., 'eta': ...}``.  Extra keys (e.g.
        space coordinates) are ignored.
    tol : float
        Threshold used for the Hessian rank and for "non-zero" derivatives
        and discriminant.

    Returns
    -------
    dict with keys:
        - 'type'                    : str
        - 'hessian'                 : 2×2 list
        - 'third_order_tensor'      : dict of H_xxx, H_xxy, H_xyy, H_yyy
        - 'directional_derivatives' : (rank 1) D3…D6 = derivatives of the
                                      reduced function along the null
                                      direction, plus 'null_direction'
        - 'cubic_invariant_I'       : (rank 0) discriminant Δ of the cubic
                                      form H_xxx v³ + 3H_xxy v²w
                                      + 3H_xyy vw² + H_yyy w³

    Raises
    ------
    KeyError
        If ``point`` holds neither the symbol keys nor 'xi'/'eta'.
    """
    from math import factorial

    if xi in point and eta in point:
        xi_val, eta_val = float(point[xi]), float(point[eta])
    elif "xi" in point and "eta" in point:
        xi_val, eta_val = float(point["xi"]), float(point["eta"])
    else:
        raise KeyError(f"Point dict must contain keys {xi}/{eta} or 'xi'/'eta'")
    subs = {xi: xi_val, eta: eta_val}

    def _at(expr):
        return float(sp.N(expr.subs(subs)))

    # ── Hessian ──────────────────────────────────────────────────
    H_xx = _at(sp.diff(H, (xi, 2)))
    H_yy = _at(sp.diff(H, (eta, 2)))
    H_xy = _at(sp.diff(H, xi, eta))
    Hess = np.array([[H_xx, H_xy], [H_xy, H_yy]], dtype=float)
    rank = int(np.linalg.matrix_rank(Hess, tol=tol))
    eigvals, eigvecs = np.linalg.eigh(Hess)    # eigh: real symmetric

    # ── Third-order tensor ────────────────────────────────────────
    H_xxx = _at(sp.diff(H, (xi, 3)))
    H_xxy = _at(sp.diff(H, xi, xi, eta))
    H_xyy = _at(sp.diff(H, xi, eta, eta))
    H_yyy = _at(sp.diff(H, (eta, 3)))
    third = {"H_xxx": H_xxx, "H_xxy": H_xxy, "H_xyy": H_xyy, "H_yyy": H_yyy}
    base = {"hessian": Hess.tolist(), "third_order_tensor": third}

    # ── Case 1 : Morse (rank 2) ───────────────────────────────────
    if rank == 2:
        return {**base, "type": "Morse (non-degenerate)"}

    # ── Case 2 : rank 1 → A_k series ─────────────────────────────
    if rank == 1:
        null_idx = int(np.argmin(np.abs(eigvals)))
        null_dir = eigvecs[:, null_idx]          # kernel direction v
        non_null = eigvecs[:, 1 - null_idx]      # non-degenerate direction u
        vx, vy = float(null_dir[0]), float(null_dir[1])
        ux, uy = float(non_null[0]), float(non_null[1])

        def _Dv(expr):
            return vx * sp.diff(expr, xi) + vy * sp.diff(expr, eta)

        def _Du(expr):
            return ux * sp.diff(expr, xi) + uy * sp.diff(expr, eta)

        # P[j, k] = (D_v^j D_u^k H)(p) / (j! k!) for j + k <= 6
        order = 6
        P = np.zeros((order + 1, order + 1))
        du_k = H                                 # D_u^k H
        for k in range(order + 1):
            term = du_k                          # D_v^j D_u^k H, from j = 0
            for j in range(order + 1 - k):
                P[j, k] = _at(term) / (factorial(j) * factorial(k))
                if j < order - k:
                    term = _Dv(term)
            if k < order:
                du_k = _Du(du_k)

        reduced = _reduced_null_derivatives(P)
        D3, D4, D5, D6 = (float(val) for val in reduced[3:7])
        directional = {"D3": D3, "D4": D4, "D5": D5, "D6": D6,
                       "null_direction": null_dir.tolist()}
        for val, lbl in ((D3, "A2 (Fold)"), (D4, "A3 (Cusp)"),
                         (D5, "A4 (Swallowtail)"), (D6, "A5 (Butterfly)")):
            if abs(val) > tol:
                return {**base, "type": lbl,
                        "directional_derivatives": directional}
        return {**base, "type": "A6+ or higher degeneracy",
                "directional_derivatives": directional}

    # ── Case 3 : rank 0 → D4 umbilics ────────────────────────────
    disc = _cubic_form_discriminant(H_xxx, H_xxy, H_xyy, H_yyy)
    if abs(disc) < tol:
        return {**base, "type": "D4 degenerate (I≈0) — possible E6 or higher",
                "cubic_invariant_I": float(disc)}
    if disc < 0:
        # One real root line → D4+ (hyperbolic umbilic), normal form x³ + y³.
        return {**base, "type": "D4+ (Hyperbolic umbilic)",
                "cubic_invariant_I": float(disc)}
    # Three distinct real root lines → D4- (elliptic umbilic),
    # normal form x³ − 3xy².
    return {**base, "type": "D4- (Elliptic umbilic)",
            "cubic_invariant_I": float(disc)}
# ══════════════════════════════════════════════════════════════════════════════
# SECTION 2 — Catastrophe detection with adaptive solver
# ══════════════════════════════════════════════════════════════════════════════
[docs]
def find_critical_points_numerical(
    grad_func,
    initial_guesses: List[np.ndarray],
    tolerance: float = 1e-6,
    domain: Optional[List[Tuple[float, float]]] = None,
    cluster_tol: float = 1e-4,
) -> List[np.ndarray]:
    """
    Locate points where ∇φ = 0 by minimising |∇φ|² from several starts.

    Every initial guess is fed to L-BFGS-B on the objective |∇φ(x)|².  A
    run is kept when the optimiser reports success, the final objective
    is below ``tolerance`` and (if ``domain`` is given) the point lies
    inside the box.  Kept points are then merged greedily: a point joins
    the first cluster whose first member is closer than ``cluster_tol``,
    otherwise it starts a new cluster.  Each cluster is represented by
    its mean.  Guesses whose optimisation raises are silently skipped.

    Shared kernel used by ``asymptotic.Analyzer.find_critical_points``
    and ``fio_bridge._BoundAnalyzer.find_critical_points``; it can also
    be called directly.

    Parameters
    ----------
    grad_func : callable
        ``grad_func(*x) -> array-like`` returning ∇φ at the unpacked
        coordinates ``x``.
    initial_guesses : list of np.ndarray
        Starting points.
    tolerance : float
        Acceptance threshold on |∇φ|² (also passed as the optimiser's
        ``tol``).
    domain : list of (lo, hi) tuples, optional
        Box constraint applied to accepted points (``lo <= x_i <= hi``).
    cluster_tol : float
        Merge distance for deduplication.

    Returns
    -------
    list of np.ndarray
        One representative point per cluster.
    """
    if not _HAS_SCIPY:
        raise RuntimeError("scipy is required for find_critical_points_numerical.")
    from scipy.optimize import minimize as _minimize

    def _squared_gradient(x):
        grad = np.asarray(grad_func(*x), dtype=float)
        return float(np.dot(grad, grad))

    def _inside(candidate):
        if domain is None:
            return True
        return all(lo <= comp <= hi for comp, (lo, hi) in zip(candidate, domain))

    accepted = []
    for start in initial_guesses:
        try:
            outcome = _minimize(_squared_gradient, start, tol=tolerance,
                                method='L-BFGS-B')
            if not (outcome.success and outcome.fun < tolerance):
                continue
            if not _inside(outcome.x):
                continue
            accepted.append(outcome.x)
        except Exception:
            pass

    # Greedy DBSCAN-style deduplication against each cluster's first member.
    groups: List[List[np.ndarray]] = []
    for candidate in accepted:
        home = next(
            (grp for grp in groups
             if np.linalg.norm(candidate - grp[0]) < cluster_tol),
            None,
        )
        if home is None:
            groups.append([candidate])
        else:
            home.append(candidate)
    return [np.mean(grp, axis=0) for grp in groups]
[docs]
class AdaptiveCriticalPointSolver:
    """
    Find critical points of H(xi_vars) robustly.

    Strategy
    --------
    1. Symbolic solve (sp.solve) — exact when it succeeds (e.g. polynomials).
    2. Coarse grid scan — evaluate |∇H|² on a grid and keep the smallest
       values as seeds.
    3. Newton refinement — scipy.optimize.fsolve when scipy is importable.
       If that does not reach ``grad_tol``, a finite-difference Newton
       iteration with least-squares steps is used.  Least squares keeps
       making progress where the Hessian of H is singular, i.e. at exactly
       the degenerate critical points (caustics) this module looks for.
    4. Greedy DBSCAN-style deduplication — merge solutions closer than
       ``cluster_tol``.

    Parameters
    ----------
    H_expr : sp.Expr
    xi_vars : sequence of sp.Symbol
        Frequency variables (the unknowns for ∇H = 0).
    coords : sequence of sp.Symbol, optional
        Additional parameter symbols (treated as unknowns too).
    bounds : dict {symbol: (lo, hi)}, optional
        Grid-scan bounds per symbol.  Defaults to (-5, 5) for all.  Refined
        and symbolic solutions are not filtered by these bounds.
    coarse_n : int
        Points per axis for the coarse scan (default 30).
    cluster_tol : float
        Distance threshold for deduplication (default 1e-4).
    grad_tol : float
        max|∇H| threshold to accept a point as critical (default 1e-8).
    """

    _METHODS = ("symbolic", "adaptive", "auto")

    def __init__(self,
                 H_expr: sp.Expr,
                 xi_vars: Sequence[sp.Symbol],
                 coords: Optional[Sequence[sp.Symbol]] = None,
                 bounds: Optional[Dict] = None,
                 coarse_n: int = 30,
                 cluster_tol: float = 1e-4,
                 grad_tol: float = 1e-8):
        self.H_expr = H_expr
        self.xi_vars = list(xi_vars)
        self.coords = list(coords) if coords else []
        self.unknowns = self.xi_vars + self.coords
        self.dim = len(self.unknowns)
        self.coarse_n = coarse_n
        self.cluster_tol = cluster_tol
        self.grad_tol = grad_tol
        # Default bounds, overridden per symbol by the caller.
        _default = (-5.0, 5.0)
        self.bounds = {v: _default for v in self.unknowns}
        if bounds:
            self.bounds.update(bounds)
        # Pre-compile the gradient (and H) as numpy functions.
        self._grad_exprs = [sp.diff(H_expr, v) for v in self.unknowns]
        self._grad_func = sp.lambdify(self.unknowns, self._grad_exprs, "numpy")
        self._H_func = sp.lambdify(self.unknowns, H_expr, "numpy")

    # ── Private helpers ──────────────────────────────────────────
    def _eval_grad(self, pt: np.ndarray) -> np.ndarray:
        """Evaluate ∇H at a point as a flat float array (all inf on failure)."""
        try:
            g = self._grad_func(*pt)
            return np.array(g, dtype=float).ravel()
        except Exception:
            return np.full(self.dim, np.inf)

    def _is_root(self, pt: np.ndarray) -> bool:
        """True when ``pt`` is finite and max|∇H(pt)| < grad_tol."""
        pt = np.asarray(pt, dtype=float)
        if not np.all(np.isfinite(pt)):
            return False
        return bool(np.max(np.abs(self._eval_grad(pt))) < self.grad_tol)

    def _coarse_seeds(self) -> List[np.ndarray]:
        """Coarse grid scan: return grid points where |∇H|² is small."""
        axes = [np.linspace(self.bounds[v][0], self.bounds[v][1], self.coarse_n)
                for v in self.unknowns]
        grid_pts = list(itertools.product(*axes))
        grad_norms = []
        for pt in grid_pts:
            g = self._eval_grad(np.array(pt))
            grad_norms.append(np.dot(g, g))
        grad_norms = np.array(grad_norms)
        finite_norms = grad_norms[np.isfinite(grad_norms)]
        if len(finite_norms) == 0:
            return []
        # Keep points whose |∇H|² is at or below the SMALLER of the 10th
        # percentile and (10 % of the median + 1e-6).
        thresh = min(np.percentile(finite_norms, 10),
                     0.1 * float(np.median(finite_norms)) + 1e-6)
        return [np.array(grid_pts[i])
                for i, v in enumerate(grad_norms)
                if v <= thresh and np.isfinite(v)]

    def _newton_refine(self, seed: np.ndarray) -> Optional[np.ndarray]:
        """
        Refine ``seed`` to a zero of ∇H; return None on failure.

        scipy.optimize.fsolve is tried first.  fsolve has no ``ftol``
        argument; passing one raises TypeError, which previously made this
        path always fail.  Its result is accepted whenever the residual is
        below ``grad_tol``, whatever the ``ier`` code: at degenerate roots
        fsolve converges slowly and often reports "not making good progress"
        even though the residual is already tiny.  Otherwise the pure-numpy
        Newton iteration is used.
        """
        seed = np.asarray(seed, dtype=float)
        try:
            from scipy.optimize import fsolve
        except ImportError:
            fsolve = None
        if fsolve is not None:
            try:
                # full_output=True also silences fsolve's RuntimeWarnings.
                sol, _info, _ier, _msg = fsolve(
                    self._eval_grad, seed, full_output=True, xtol=1e-12)
            except (TypeError, ValueError, ArithmeticError):
                sol = None
            if sol is not None and self._is_root(sol):
                return np.asarray(sol, dtype=float)
        return self._numpy_newton(seed)

    def _numpy_newton(self, seed: np.ndarray, max_iter: int = 50,
                      eps_fd: float = 1e-6) -> Optional[np.ndarray]:
        """
        Finite-difference Newton iteration on ∇H = 0 starting at ``seed``.

        Steps use least squares instead of a plain linear solve because the
        Jacobian of ∇H (the Hessian of H) is singular at degenerate critical
        points.  Returns the converged point, or None.
        """
        pt = np.array(seed, dtype=float)
        n = pt.size
        for _ in range(max_iter):
            g = self._eval_grad(pt)
            if not np.all(np.isfinite(g)):
                return None
            if np.max(np.abs(g)) < self.grad_tol:
                return pt
            jac = np.empty((n, n))
            for j in range(n):
                step = np.zeros(n)
                step[j] = eps_fd
                jac[:, j] = (self._eval_grad(pt + step) - g) / eps_fd
            if not np.all(np.isfinite(jac)):
                return None
            try:
                delta = np.linalg.lstsq(jac, -g, rcond=None)[0]
            except np.linalg.LinAlgError:
                break
            pt = pt + delta
        return pt if self._is_root(pt) else None

    def _deduplicate(self, solutions: List[np.ndarray]) -> List[np.ndarray]:
        """
        Merge solutions closer than cluster_tol (greedy DBSCAN-style).
        Returns a list of representative points (cluster means).
        """
        if not solutions:
            return []
        clusters: List[List[np.ndarray]] = []
        for sol in solutions:
            for cluster in clusters:
                if np.linalg.norm(sol - cluster[0]) < self.cluster_tol:
                    cluster.append(sol)
                    break
            else:
                clusters.append([sol])
        return [np.mean(c, axis=0) for c in clusters]

    def _symbolic_solve(self) -> Optional[List[np.ndarray]]:
        """
        Attempt a symbolic solution with sp.solve.

        Solutions that are not fully numeric and real (free symbols, complex
        values) are dropped.  Returns None when nothing usable is found or
        sympy fails.
        """
        eqs = [sp.Eq(g, 0) for g in self._grad_exprs]
        try:
            sol_list = sp.solve(eqs, self.unknowns, dict=True)
            if not sol_list:
                return None
            result = []
            for sol in sol_list:
                try:
                    pt = np.array([float(sp.N(sol.get(v, v)))
                                   for v in self.unknowns], dtype=float)
                    if np.all(np.isfinite(pt)):
                        result.append(pt)
                except Exception:
                    continue
            return result if result else None
        except Exception:
            return None

    # ── Public API ───────────────────────────────────────────────
    def solve(self, method: str = "auto") -> List[np.ndarray]:
        """
        Find all critical points.

        Parameters
        ----------
        method : "symbolic" | "adaptive" | "auto"
            - "symbolic" : sp.solve only (fast for polynomials)
            - "adaptive" : coarse grid + Newton (always numerical)
            - "auto"     : try symbolic first, fall back to adaptive

        Returns
        -------
        list of np.ndarray, each of shape (dim,)
            Values of (xi_vars + coords) at each critical point.

        Raises
        ------
        ValueError
            If ``method`` is not one of the three names above.
        """
        if method not in self._METHODS:
            raise ValueError(
                f"Unknown method {method!r}; expected one of {self._METHODS}.")
        solutions: List[np.ndarray] = []
        if method in ("symbolic", "auto"):
            solutions.extend(self._symbolic_solve() or [])
            if method == "symbolic":
                return self._deduplicate(solutions)
        if not solutions:   # "adaptive", or "auto" without symbolic hits
            for seed in self._coarse_seeds():
                pt = self._newton_refine(seed)
                if pt is not None:
                    solutions.append(pt)
        return self._deduplicate(solutions)

    def to_dict_list(self, solutions: Optional[List[np.ndarray]] = None
                     ) -> List[Dict]:
        """
        Convert solutions to list of dicts {symbol: value}.
        If solutions is None, calls self.solve() first.
        """
        if solutions is None:
            solutions = self.solve()
        return [{v: float(pt[i]) for i, v in enumerate(self.unknowns)}
                for pt in solutions]
[docs]
def detect_catastrophes(H_expr: sp.Expr,
                        xi_vars: Sequence[sp.Symbol],
                        coords: Optional[Sequence[sp.Symbol]] = None,
                        method: str = "auto",
                        bounds: Optional[Dict] = None,
                        coarse_n: int = 30,
                        max_order: int = 6,
                        tol: float = 1e-8) -> List[Dict]:
    """
    Detect and classify catastrophes of H_expr(xi_vars).

    Combines AdaptiveCriticalPointSolver (search for ∇H = 0) with
    classify_arnold_1d / classify_arnold_2d (typing).

    Parameters
    ----------
    H_expr : sp.Expr
    xi_vars : tuple of sp.Symbol — (xi,) in 1D, (xi, eta) in 2D
    coords : tuple of sp.Symbol, optional
        Additional parameter symbols (e.g. spatial coordinates x, y),
        solved for together with xi_vars.  Their values at each solution
        are substituted into H before classification.
    method : "symbolic" | "adaptive" | "auto"
        Passed to ``AdaptiveCriticalPointSolver.solve``.
    bounds : dict {symbol: (lo, hi)}, optional
    coarse_n : int — coarse grid resolution per axis
    max_order : int — max derivative order for 1D classification
    tol : float — gradient tolerance and classification threshold

    Returns
    -------
    list of dict, each with:
        - 'point'   : dict {symbol: value}
        - 'type'    : str Arnold type
        - 'details' : dict hessian / derivatives / invariants

    Raises
    ------
    NotImplementedError
        If ``len(xi_vars)`` is not 1 or 2.
    """
    dim = len(xi_vars)
    if dim not in (1, 2):
        raise NotImplementedError("detect_catastrophes supports dim 1 or 2.")

    solver = AdaptiveCriticalPointSolver(
        H_expr, xi_vars, coords=coords,
        bounds=bounds, coarse_n=coarse_n, grad_tol=tol
    )
    # Honour the caller's `method`: to_dict_list() alone always ran "auto".
    raw_points = solver.to_dict_list(solver.solve(method=method))
    coord_syms = list(coords) if coords else []

    results = []
    for pt_dict in raw_points:
        if any(v not in pt_dict for v in xi_vars):
            continue
        # Freeze the coordinate symbols at this solution; otherwise the
        # classifiers cannot convert derivatives that still contain them
        # to floats.
        H_here = (H_expr.subs({c: pt_dict[c] for c in coord_syms})
                  if coord_syms else H_expr)
        if dim == 1:
            xi = xi_vars[0]
            classification = classify_arnold_1d(
                H_here, xi, pt_dict[xi], max_order=max_order, tol=tol
            )
        else:
            classification = classify_arnold_2d(
                H_here, xi_vars[0], xi_vars[1], pt_dict, tol=tol
            )
        results.append({
            "point"  : pt_dict,
            "type"   : classification["type"],
            "details": classification,
        })
    return results
# ══════════════════════════════════════════════════════════════════════════════
# SECTION 3 — Ray-based caustic detection (geometric, corrected)
# ══════════════════════════════════════════════════════════════════════════════
[docs]
@dataclass
class CausticEvent:
    """
    One caustic crossing found along a single ray.

    Records where, in time, space and momentum, det(J) of one ray
    changes sign.  It also holds the bookkeeping used for the Maslov
    index and, optionally, an Arnold classification label.
    """

    # Position of the ray inside the bundle that was scanned.
    ray_idx: int
    # Crossing time t* at which det(J) = 0 (interpolated between samples).
    time: float
    # Index of the time sample associated with the crossing.
    time_idx: int
    # Spatial position x(t*), shape (dim,).
    position: np.ndarray
    # Momentum ξ(t*), shape (dim,).
    momentum: np.ndarray
    # Sampled value of det(J) near the crossing (close to 0).
    det_J: float
    # Direction of the sign change of det(J): +1 or -1; 0 means "not set".
    sign_change: int = field(default=0)
    # Contribution of this crossing to the Maslov index (+1 ≙ a π/2 phase shift).
    maslov_contribution: int = field(default=1)
    # Arnold classification label, or "unknown" when none was computed.
    arnold_type: str = field(default="unknown")
[docs]
class RayCausticDetector:
"""
Detect caustics in a ray bundle by tracking the stability matrix J.
This class implements the correct geometric condition for caustics:
a zero of the determinant of the stability matrix J = ∂x/∂q, where q
are the initial ray parameters. It replaces older heuristics such as
velocity zeros or momentum minima, which are incorrect for caustic
detection.
The detector optionally classifies each caustic event into an Arnold
type (fold, cusp, etc.) if a symbolic Hamiltonian is supplied. It
also computes the Maslov index (signed number of caustic crossings)
and the associated phase shift exp(i μ π/2).
Mathematical background
-----------------------
For a Hamiltonian H(x,ξ) and a family of rays parameterised by an
initial‑condition vector q ∈ ℝⁿ, the stability matrix
J(t) = ∂x(t)/∂q
satisfies the variational equation
dJ/dt = H_{px}(x(t), ξ(t)) · J , J(0) = Iₙ ,
where H_{px} = ∂²H/∂ξ∂x is the mixed Hessian (n×n matrix in n
dimensions). A caustic occurs at time t* where det(J(t*)) = 0.
The Maslov index μ is the signed count of such zero crossings; each
crossing contributes a phase shift of π/2 to the semiclassical
wavefunction.
Important
---------
The ray bundle passed to the constructor **must** contain the
integrated stability matrix. In 1D, the ray dict must have the key
`'J'` (a 1‑D array of length equal to `'t'`). In 2D, the keys
`'J11'`, `'J12'`, `'J21'`, `'J22'` are required. The helper methods
`stability_matrix_ode_1d` and `stability_matrix_ode_2d` can be used
to incorporate the J ODE into the ray integrator.
Parameters
----------
ray_bundle : list of dict
List of integrated rays. Each dict contains at least the fields:
- 't' : ndarray, shape (n_steps,) – integration times.
- 'x', 'xi' : ndarray (1D) or 'x','y','xi','eta' (2D) – position
and momentum.
- Stability matrix entries:
* 1D: 'J' : ndarray, shape (n_steps,)
* 2D: 'J11','J12','J21','J22' : ndarray, shape (n_steps,)
If any of these keys is missing, the detector will raise a KeyError.
dimension : int
1 or 2.
det_threshold : float, default 0.05
A zero crossing of det(J) is registered when |det(J)| falls below
`det_threshold * |det(J(0))|`. Increase this value if caustics
are missed; decrease it to avoid spurious detections.
H_expr : sympy.Expr, optional
Symbolic Hamiltonian. If provided, the Arnold type of each
caustic is determined by `classify_arnold_1d` (1D) or
`classify_arnold_2d` (2D) using the momentum variables.
xi_syms : tuple of sympy.Symbol, optional
Momentum symbols (e.g., (xi,) in 1D, (xi, eta) in 2D). Required
if `H_expr` is given.
x_syms : tuple of sympy.Symbol, optional
Spatial symbols (e.g., (x,) in 1D, (x, y) in 2D). Currently
unused, but kept for future extensions.
Attributes
----------
events : list of CausticEvent
Detected caustic events after calling `detect()`. Each event
contains the time, position, momentum, sign change, Maslov
contribution, and (if possible) the Arnold type.
"""
def __init__(self,
ray_bundle: List[Dict],
dimension: int,
det_threshold: float = 0.05,
H_expr: Optional[sp.Expr] = None,
xi_syms: Optional[Tuple] = None,
x_syms : Optional[Tuple] = None):
self.rays = ray_bundle
self.dimension = dimension
self.det_threshold = det_threshold
self.H_expr = H_expr
self.xi_syms = xi_syms
self.x_syms = x_syms
self._events: List[CausticEvent] = []
# ── Stability matrix ─────────────────────────────────────────
def _det_J_1d(self, ray: Dict) -> np.ndarray:
"""
Extract or estimate det(J) = J for a 1D ray.
If 'J' key is present: use directly.
Else: estimate by finite difference with a neighbouring ray
if the bundle contains it, otherwise raise.
"""
if 'J' in ray:
return np.asarray(ray['J'], dtype=float)
raise KeyError(
"Ray dict is missing key 'J' (stability matrix). "
"Integrate J alongside the ray ODE: dJ/dt = (∂²H/∂ξ∂x)·J, J(0)=1."
)
def _det_J_2d(self, ray: Dict) -> np.ndarray:
"""
Extract or compute det(J) for a 2D ray.
Expects keys 'J11','J12','J21','J22' in the ray dict.
"""
if all(k in ray for k in ('J11', 'J12', 'J21', 'J22')):
J11 = np.asarray(ray['J11'], dtype=float)
J12 = np.asarray(ray['J12'], dtype=float)
J21 = np.asarray(ray['J21'], dtype=float)
J22 = np.asarray(ray['J22'], dtype=float)
return J11 * J22 - J12 * J21
raise KeyError(
"Ray dict is missing stability matrix keys 'J11'..'J22'. "
"Integrate the 2×2 matrix ODE dJ/dt = H_px · J alongside the ray."
)
[docs]
@staticmethod
def stability_matrix_ode_1d(H_px_func):
"""
Return the RHS for the 1D stability ODE to be integrated with the ray.
Usage in your ray integrator
----------------------------
Add a state component J (scalar, initially 1) to the ray ODE.
The extra equation is:
dJ/dt = H_px(x(t), ξ(t)) · J
where H_px = ∂²H/∂ξ∂x (scalar in 1D).
Parameters
----------
H_px_func : callable (x, xi) → float
Lambdified ∂²H/∂ξ∂x.
Returns
-------
callable (t, J_val, x_val, xi_val) → dJ/dt
"""
def rhs(t, J_val, x_val, xi_val):
return H_px_func(x_val, xi_val) * J_val
return rhs
[docs]
@staticmethod
def stability_matrix_ode_2d(H_px_func):
"""
Return the RHS for the 2D stability ODE.
The 2×2 matrix J satisfies dJ/dt = H_{px} · J
where H_{px} = [[∂²H/∂ξ∂x, ∂²H/∂ξ∂y],[∂²H/∂η∂x, ∂²H/∂η∂y]].
Pack J as a flat vector [J11, J12, J21, J22] in the state.
Parameters
----------
H_px_func : callable (x, y, xi, eta) → 2×2 array
Lambdified H_{px} matrix.
Returns
-------
callable (t, J_flat, x, y, xi, eta) → dJ_flat/dt (4-vector)
"""
def rhs(t, J_flat, x, y, xi, eta):
J = J_flat.reshape(2, 2)
Hpx = np.asarray(H_px_func(x, y, xi, eta), dtype=float).reshape(2, 2)
dJ = Hpx @ J
return dJ.ravel()
return rhs
# ── Detection ────────────────────────────────────────────────
def _find_sign_changes(self, arr: np.ndarray,
t_arr: np.ndarray,
threshold: float) -> List[Tuple[int, float]]:
"""
Find indices where |arr| crosses below threshold
and changes sign, indicating det(J) = 0.
Returns list of (index, interpolated_time).
"""
events = []
norm0 = abs(arr[0]) if abs(arr[0]) > 1e-12 else 1.0
scaled = arr / norm0
for i in range(len(scaled) - 1):
a, b = scaled[i], scaled[i + 1]
if abs(a) < threshold or abs(b) < threshold:
continue
if a * b < 0: # sign change
# Linear interpolation for zero crossing time
t_cross = t_arr[i] + (t_arr[i+1] - t_arr[i]) * abs(a) / (abs(a) + abs(b))
events.append((i, t_cross))
return events
[docs]
def detect(self) -> List[CausticEvent]:
"""
Detect all caustic events in the ray bundle.
Returns
-------
list of CausticEvent
"""
self._events = []
for ray_idx, ray in enumerate(self.rays):
t = np.asarray(ray['t'], dtype=float)
try:
if self.dimension == 1:
det_J = self._det_J_1d(ray)
x_arr = np.asarray(ray['x'], dtype=float)
xi_arr = np.asarray(ray['xi'], dtype=float)
crossings = self._find_sign_changes(det_J, t, self.det_threshold)
for idx, t_cross in crossings:
sign = int(np.sign(det_J[idx + 1] - det_J[idx]))
ev = CausticEvent(
ray_idx = ray_idx,
time = t_cross,
time_idx = idx,
position = np.array([x_arr[idx]]),
momentum = np.array([xi_arr[idx]]),
det_J = float(det_J[idx]),
sign_change = sign,
maslov_contribution = 1,
)
if self.H_expr is not None and self.xi_syms:
ev.arnold_type = self._classify_at_event_1d(ev)
self._events.append(ev)
else: # dim == 2
det_J = self._det_J_2d(ray)
x_arr = np.asarray(ray['x'], dtype=float)
y_arr = np.asarray(ray['y'], dtype=float)
xi_arr = np.asarray(ray['xi'], dtype=float)
eta_arr= np.asarray(ray['eta'], dtype=float)
crossings = self._find_sign_changes(det_J, t, self.det_threshold)
for idx, t_cross in crossings:
sign = int(np.sign(det_J[idx + 1] - det_J[idx]))
ev = CausticEvent(
ray_idx = ray_idx,
time = t_cross,
time_idx = idx,
position = np.array([x_arr[idx], y_arr[idx]]),
momentum = np.array([xi_arr[idx], eta_arr[idx]]),
det_J = float(det_J[idx]),
sign_change = sign,
maslov_contribution = 1,
)
if self.H_expr is not None and self.xi_syms:
ev.arnold_type = self._classify_at_event_2d(ev)
self._events.append(ev)
except KeyError as e:
warnings.warn(f"Ray {ray_idx}: {e}", UserWarning)
continue
return self._events
def _classify_at_event_1d(self, ev: CausticEvent) -> str:
"""Call classify_arnold_1d at the caustic position."""
if self.H_expr is None or not self.xi_syms:
return "unknown"
xi_sym = self.xi_syms[0]
xi_val = float(ev.momentum[0])
try:
res = classify_arnold_1d(self.H_expr, xi_sym, xi_val)
return res["type"]
except Exception:
return "classification failed"
def _classify_at_event_2d(self, ev: CausticEvent) -> str:
"""Call classify_arnold_2d at the caustic position."""
if self.H_expr is None or not self.xi_syms or len(self.xi_syms) < 2:
return "unknown"
xi_sym, eta_sym = self.xi_syms[0], self.xi_syms[1]
pt = {"xi": float(ev.momentum[0]), "eta": float(ev.momentum[1])}
try:
res = classify_arnold_2d(self.H_expr, xi_sym, eta_sym, pt)
return res["type"]
except Exception:
return "classification failed"
# ── Maslov index ─────────────────────────────────────────────
[docs]
def maslov_index(self, ray_idx: int) -> int:
"""
Return the Maslov index for ray ray_idx.
The Maslov index is the signed count of caustic crossings:
μ = Σ_k sign(d det(J)/dt at t_k*)
Each crossing contributes a phase shift of π/2 to the WKB amplitude.
Parameters
----------
ray_idx : int
Returns
-------
int (total signed count; use abs() for the number of crossings)
"""
if not self._events:
self.detect()
return sum(ev.maslov_contribution * ev.sign_change
for ev in self._events if ev.ray_idx == ray_idx)
[docs]
def maslov_phase(self, ray_idx: int) -> float:
    """
    Total Maslov phase of ray ``ray_idx`` in radians: μ · π/2.

    μ is obtained from ``self.maslov_index(ray_idx)``.
    """
    quarter_turn = np.pi / 2
    mu = self.maslov_index(ray_idx)
    return mu * quarter_turn
[docs]
def summary(self) -> Dict:
    """
    Return a summary dict of detected caustic events.

    Runs ``detect()`` first when no events are cached.

    Returns
    -------
    dict with keys
        "n_events"       : total number of events
        "by_ray"         : {ray_idx: [events on that ray, in detection order]}
        "by_type"        : {arnold_type: [events of that type, in detection order]}
        "maslov_indices" : {ray_idx: self.maslov_index(ray_idx)} for every
                           ray that has at least one event
    """
    if not self._events:
        self.detect()
    events = self._events
    # Group in a single pass rather than rescanning all events once per ray.
    by_ray = {}
    by_type = {}
    for ev in events:
        by_ray.setdefault(ev.ray_idx, []).append(ev)
        by_type.setdefault(ev.arnold_type, []).append(ev)
    return {
        "n_events": len(events),
        "by_ray": by_ray,
        "by_type": by_type,
        "maslov_indices": {i: self.maslov_index(i) for i in by_ray},
    }
# ══════════════════════════════════════════════════════════════════════════════
# SECTION 4 — Special functions for uniform corrections
# ══════════════════════════════════════════════════════════════════════════════
[docs]
class CausticFunctions:
"""
Special functions for uniform asymptotic corrections near caustics.
All functions include the correct normalisation prefactors for the
uniform WKB approximation (Maslov-Fedoriuk / Duistermaat).
.. list-table:: Special Functions for Uniform Asymptotic Corrections
:widths: 25 75
:header-rows: 1
* - Caustic Type
- Function Description
* - Fold (A2)
- Airy function Ai
* - Cusp (A3)
- Pearcey integral P(x,y)
* - Swallowtail (A4)
- SW integral (3-parameter oscillatory)
"""
# ── Fold : Airy ───────────────────────────────────────────────
[docs]
@staticmethod
def airy_Ai(z: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
"""
Airy function Ai(z).
Fold caustic uniform approximation:
u(x) ≈ 2√π · ε^{1/6} · |∂_s J|^{-1/2} · Ai(-ε^{-2/3} ζ(x)) · e^{iS_c/ε}
where ζ(x) is the local coordinate measuring distance to the caustic.
"""
if not _HAS_SCIPY:
raise RuntimeError("scipy required for Airy function.")
ai, _, _, _ = airy(z)
return ai
[docs]
@staticmethod
def airy_Ai_prime(z: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
"""Derivative Ai'(z) of the Airy function."""
if not _HAS_SCIPY:
raise RuntimeError("scipy required.")
_, aip, _, _ = airy(z)
return aip
[docs]
@staticmethod
def airy_Bi(z: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
"""Second Airy function Bi(z) (exponentially growing branch)."""
if not _HAS_SCIPY:
raise RuntimeError("scipy required.")
_, _, bi, _ = airy(z)
return bi
# ── Cusp : Pearcey ───────────────────────────────────────────
[docs]
@staticmethod
def pearcey(x: float, y: float,
t_range: float = 6.0,
n_pts: int = 500) -> complex:
"""
Pearcey integral for cusp caustic (A3):
P(x, y) = ∫_{-∞}^{∞} exp(i(t⁴ + x t² + y t)) dt
Evaluated by adaptive quadrature on [-t_range, t_range].
Parameters
----------
x, y : control parameters
t_range : integration half-width (default 6.0; increase for large |x|,|y|)
n_pts : number of quadrature points
Returns
-------
complex
"""
t = np.linspace(-t_range, t_range, n_pts)
dt = t[1] - t[0]
phase = t**4 + x * t**2 + y * t
return complex(np.trapezoid(np.exp(1j * phase), dx=dt))
[docs]
@staticmethod
def pearcey_grid(X: np.ndarray, Y: np.ndarray,
t_range: float = 6.0,
n_pts: int = 500) -> np.ndarray:
"""
Vectorised Pearcey integral over 2D grids X, Y.
Returns complex array of same shape as X.
"""
shape = X.shape
result = np.empty(shape, dtype=complex)
for idx in np.ndindex(shape):
result[idx] = CausticFunctions.pearcey(
float(X[idx]), float(Y[idx]), t_range=t_range, n_pts=n_pts
)
return result
# ── Swallowtail (A4) ─────────────────────────────────────────
[docs]
@staticmethod
def swallowtail(x: float, y: float, z: float,
t_range: float = 5.0,
n_pts: int = 400) -> complex:
"""
Swallowtail integral for A4 caustic:
SW(x,y,z) = ∫_{-∞}^{∞} exp(i(t⁵ + x t³ + y t² + z t)) dt
Parameters
----------
x, y, z : control parameters
t_range, n_pts : quadrature parameters
Returns
-------
complex
"""
t = np.linspace(-t_range, t_range, n_pts)
dt = t[1] - t[0]
phase = t**5 + x * t**3 + y * t**2 + z * t
return complex(np.trapezoid(np.exp(1j * phase), dx=dt))
# ── Phase shift from Maslov index ─────────────────────────────
[docs]
@staticmethod
def maslov_phase_shift(mu: int) -> complex:
"""
WKB phase correction factor from Maslov index μ.
Each caustic crossing (each zero of det J) contributes a phase
shift of π/2. The total correction factor is exp(i μ π/2).
Parameters
----------
mu : int Maslov index (number of signed caustic crossings).
Returns
-------
complex exp(i μ π/2)
"""
return np.exp(1j * mu * np.pi / 2)
# ══════════════════════════════════════════════════════════════════════════════
# SECTION 5 — Visualisation
# ══════════════════════════════════════════════════════════════════════════════
[docs]
def plot_catastrophe(H: sp.Expr,
                     xi_vars: Sequence[sp.Symbol],
                     points: List[Dict],
                     xi_bounds: Tuple[float, float] = (-3.0, 3.0),
                     eta_bounds: Tuple[float, float] = (-3.0, 3.0),
                     n: int = 300,
                     title: Optional[str] = None) -> None:
    """
    Plot H and mark classified catastrophe points.

    1D: H(ξ) is drawn as a curve sampled at ``n`` points.
    2D: H(ξ, η) is drawn as a 3-D surface on an ``(n // 2) x (n // 2)`` grid.
    Every entry of ``points`` is marked at its coordinates and annotated
    with its type label. Known Arnold types get a fixed colour; any other
    label is drawn in grey.

    Parameters
    ----------
    H : sympy.Expr
        May be constant; its values are broadcast onto the plotting grid.
    xi_vars : (xi,) or (xi, eta)
    points : list of dicts as returned by detect_catastrophes()
             Each must have keys 'point' and 'type'. A coordinate is looked
             up in 'point' by the sympy symbol first, then by the symbol's
             name. A missing coordinate defaults to 0.0.
    xi_bounds, eta_bounds : axis ranges
    n : grid resolution
    title : optional plot title

    Raises
    ------
    RuntimeError
        If matplotlib is not available.
    NotImplementedError
        If ``len(xi_vars)`` is not 1 or 2.
    """
    if not _HAS_MPL:
        raise RuntimeError("matplotlib not available.")
    dim = len(xi_vars)
    if dim not in (1, 2):
        raise NotImplementedError("plot_catastrophe supports only dim 1 or 2.")
    _TYPE_COLORS = {
        "A2 (Fold)"               : "red",
        "A3 (Cusp)"               : "orange",
        "A4 (Swallowtail)"        : "purple",
        "A5 (Butterfly)"          : "blue",
        "D4+ (Hyperbolic umbilic)": "green",
        "D4- (Elliptic umbilic)"  : "cyan",
    }

    def _coord(pt, sym):
        """Coordinate of ``sym`` in a point dict (symbol key, then name, else 0.0)."""
        if sym in pt:
            return float(pt[sym])
        return float(pt.get(str(sym), 0.0))

    def _on_grid(values, shape):
        """Broadcast lambdify output to ``shape``; a constant H yields a scalar."""
        return np.broadcast_to(np.asarray(values, dtype=float), shape).copy()

    if dim == 1:
        xi = xi_vars[0]
        X = np.linspace(xi_bounds[0], xi_bounds[1], n)
        Hf = sp.lambdify(xi, H, "numpy")
        Y = _on_grid(Hf(X), X.shape)
        fig, ax = _plt.subplots(figsize=(9, 5))
        ax.plot(X, Y, "k-", lw=1.5, label="H(ξ)")
        for p in points:
            typ = p["type"]
            xv = _coord(p["point"], xi)
            yv = float(Hf(xv))
            col = _TYPE_COLORS.get(typ, "gray")
            ax.scatter([xv], [yv], color=col, s=80, zorder=5)
            ax.annotate(typ, (xv, yv), textcoords="offset points",
                        xytext=(5, 5), fontsize=8, color=col)
        ax.set_xlabel("ξ")
        ax.set_ylabel("H(ξ)")
        ax.set_title(title or "Catastrophe plot (1D)")
        ax.grid(True, alpha=0.3)
        ax.legend()
        _plt.tight_layout()
        _plt.show()
    else:
        xi, eta = xi_vars
        Xv = np.linspace(xi_bounds[0], xi_bounds[1], n // 2)
        Yv = np.linspace(eta_bounds[0], eta_bounds[1], n // 2)
        XX, YY = np.meshgrid(Xv, Yv)
        Hf = sp.lambdify((xi, eta), H, "numpy")
        ZZ = _on_grid(Hf(XX, YY), XX.shape)
        fig = _plt.figure(figsize=(10, 7))
        ax = fig.add_subplot(111, projection="3d")
        ax.plot_surface(XX, YY, ZZ, alpha=0.55, rstride=3, cstride=3,
                        cmap="viridis")
        for p in points:
            typ = p["type"]
            pt = p["point"]
            xv = _coord(pt, xi)
            yv = _coord(pt, eta)
            zv = float(Hf(xv, yv))
            col = _TYPE_COLORS.get(typ, "gray")
            ax.scatter([xv], [yv], [zv], color=col, s=80, zorder=5)
            ax.text(xv, yv, zv + 0.1, typ, fontsize=7, color=col)
        ax.set_xlabel("ξ")
        ax.set_ylabel("η")
        ax.set_zlabel("H")
        ax.set_title(title or "Catastrophe surface (2D)")
        _plt.tight_layout()
        _plt.show()
[docs]
def plot_caustic_events(ray_bundle: List[Dict],
                        events: List[CausticEvent],
                        dimension: int,
                        n_rays_plot: int = 20,
                        title: Optional[str] = None) -> None:
    """
    Overlay caustic events on the ray bundle plot.

    1D: rays are drawn as x(t) curves, and events at (ev.time, ev.position[0]).
    Any other dimension: rays are drawn in the (x, y) plane, and events at
    ev.position[:2]. Events are coloured by ``ev.arnold_type``; unknown types
    are drawn in red.

    Parameters
    ----------
    ray_bundle : list of ray dicts (with 'x', 't' etc.; 'y' in 2D)
    events     : list of CausticEvent from RayCausticDetector.detect()
    dimension  : 1 or 2
    n_rays_plot: max number of rays to draw (evenly strided through the
                 bundle; a value <= 0 draws no rays)
    title      : optional plot title

    Raises
    ------
    RuntimeError
        If matplotlib is not available.
    """
    if not _HAS_MPL:
        raise RuntimeError("matplotlib not available.")
    fig, ax = _plt.subplots(figsize=(10, 6))
    n_total = len(ray_bundle)
    if n_rays_plot <= 0:
        indices = range(0)
    else:
        # Ceiling division guarantees at most n_rays_plot rays are drawn.
        # Floor division could draw nearly twice as many (e.g. 39 rays -> step 1).
        step = max(1, -(-n_total // n_rays_plot))
        indices = range(0, n_total, step)
    _TYPE_COLORS = {
        "A2 (Fold)"       : "red",
        "A3 (Cusp)"       : "orange",
        "A4 (Swallowtail)": "purple",
        "A5 (Butterfly)"  : "blue",
    }
    if dimension == 1:
        for i in indices:
            ray = ray_bundle[i]
            ax.plot(ray['t'], ray['x'], 'b-', alpha=0.3, lw=0.8)
        for ev in events:
            col = _TYPE_COLORS.get(ev.arnold_type, "red")
            ax.scatter([ev.time], [ev.position[0]],
                       color=col, s=60, zorder=5,
                       label=ev.arnold_type)
        ax.set_xlabel("t")
        ax.set_ylabel("x(t)")
    else:  # 2D
        for i in indices:
            ray = ray_bundle[i]
            ax.plot(ray['x'], ray['y'], 'b-', alpha=0.3, lw=0.8)
        for ev in events:
            col = _TYPE_COLORS.get(ev.arnold_type, "red")
            ax.scatter([ev.position[0]], [ev.position[1]],
                       color=col, s=60, zorder=5,
                       label=ev.arnold_type)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_aspect("equal")
    # Deduplicate legend entries, keeping the first handle for each label.
    handles, labels = ax.get_legend_handles_labels()
    seen = {}
    for h, l in zip(handles, labels):
        seen.setdefault(l, h)
    if seen:
        # Skip the legend entirely when there are no labelled events.
        ax.legend(list(seen.values()), list(seen.keys()), fontsize=8)
    ax.set_title(title or f"Ray bundle with caustic events (dim={dimension})")
    ax.grid(True, alpha=0.3)
    _plt.tight_layout()
    _plt.show()