
# Copyright 2026 Philippe Billet assisted by LLMs in free mode: chatGPT, Qwen, Deepseek, Gemini, Claude, le chat Mistral.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
wkb.py — Multidimensional WKB approximation with caustic corrections
==========================================================================

Overview
--------
The `wkb` module provides a comprehensive implementation of the Wentzel–Kramers–Brillouin (WKB) method for constructing asymptotic solutions to linear partial differential equations of the form

    P(x, –iε∇) u(x) = 0,  ε → 0,

where `P(x,ξ)` is the (pseudo‑differential) symbol of the operator.  The solution is sought as a sum over rays (bicharacteristics):

    u(x) ≈ ∑_j A_j(x) e^{i S_j(x)/ε},

with each phase `S_j` satisfying the eikonal equation `P(x,∇S_j)=0` and each amplitude `A_j` determined by a transport equation along the corresponding family of rays.
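The sum over rays matters wherever several branches reach the same point. A minimal sketch, assuming the hypothetical 1D symbol P = ξ² - 1, whose eikonal equation S'² = 1 has the two branches S₁ = x and S₂ = -x. With unit amplitudes the branch sum is the standing wave 2 cos(x/ε), whose interference nodes neither branch shows on its own.

```python
import numpy as np

eps = 0.05
x = np.linspace(-1.0, 1.0, 401)

# Two ray branches reach every point: S_1 = +x and S_2 = -x, unit amplitudes.
branches = [(x, np.ones_like(x)), (-x, np.ones_like(x))]

# u(x) = sum_j A_j exp(i S_j / eps)
u = sum(A * np.exp(1j * S / eps) for S, A in branches)

# The branch sum equals the standing wave 2 cos(x / eps).
print(np.max(np.abs(u - 2.0 * np.cos(x / eps))))
```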

Key features:

* **Automatic dimension detection** – works for both 1D and 2D problems without changing the calling interface.
* **Ray tracing** – integrates Hamilton’s equations for the bicharacteristics, including the variational equations for the stability matrix `J` (used to detect caustics).
* **Multi‑order amplitude transport** – computes amplitudes up to arbitrary order (0,1,2,3,…) by solving ODEs along each ray.
* **Caustic detection and correction** – monitors `det(J)` to locate caustics (folds, cusps) and applies:
    * Maslov phase shifts (π/2 per caustic),
    * Uniform Airy (fold) or Pearcey (cusp) corrections near caustics.
* **Interpolation onto regular grids** – uses `scipy.interpolate` (`interp1d` in 1D, linear `griddata` in 2D) to map the ray‑based solution to a uniform spatial grid.
* **Rich visualisation suite** – phase portraits, amplitude decomposition, caustic analysis, ray plots, comparison of WKB orders.
* **Utilities for generating initial data** – line segments, circles, point sources.
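In 2D, mapping ray samples onto a regular grid is scattered-data interpolation. A minimal sketch with `scipy.interpolate.griddata`, assuming hypothetical random sample points and a linear test phase S(x, y) = x + 2y. Piecewise-linear interpolation reproduces a linear function exactly inside the convex hull of the samples, so the printed error should be at rounding level. Grid nodes outside the hull receive `fill_value`, which is why the interpolated solution is zero away from the rays.

```python
import numpy as np
from scipy.interpolate import griddata

# Hypothetical scattered samples of a linear phase S(x, y) = x + 2y,
# standing in for (x, y, S) values collected along traced rays.
rng = np.random.default_rng(0)
pts = rng.uniform(-1.0, 1.0, size=(400, 2))
S_samples = pts[:, 0] + 2.0 * pts[:, 1]

# Regular output grid with 'ij' indexing, well inside the sample hull.
X, Y = np.meshgrid(np.linspace(-0.5, 0.5, 11),
                   np.linspace(-0.5, 0.5, 11), indexing='ij')
S_grid = griddata(pts, S_samples, (X, Y), method='linear', fill_value=0.0)

# Linear interpolation of a linear field is exact up to rounding.
print(np.max(np.abs(S_grid - (X + 2.0 * Y))))
```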

Mathematical background
-----------------------
The WKB ansatz `u = e^{iS/ε} (a₀ + ε a₁ + ε² a₂ + …)` is inserted into the equation `P(x,–iε∇)u = 0`.  Expanding in powers of ε yields:

* **Order ε⁰ (eikonal equation):**  `P(x, ∇S) = 0`.  This is a Hamilton–Jacobi equation solved by the method of characteristics (rays).  The rays satisfy

        dx/dt = ∂P/∂ξ,  dξ/dt = –∂P/∂x,  dS/dt = ξ·∂P/∂ξ – P.

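* **Order ε¹ (transport equation for a₀):**  `(∂P/∂ξ)·∇a₀ + ½ ∇·[(∂P/∂ξ)(x,∇S(x))] a₀ = 0`, where the divergence is taken after substituting ξ = ∇S, so it contains both the mixed term Σᵢ ∂²P/∂xᵢ∂ξᵢ and the curvature term Σᵢⱼ (∂²P/∂ξᵢ∂ξⱼ)(∂²S/∂xᵢ∂xⱼ) (Weyl ordering of the operator assumed).  Along a ray, where d/dt = (∂P/∂ξ)·∇, this becomes the ODE `da₀/dt = –½ ∇·[∂P/∂ξ] a₀`.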

* **Higher orders:**  Similar ODEs for `a₁, a₂, …` involving derivatives of `P` up to order three; the order‑3 transport term is implemented with simplified expressions.
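The ray equations and the leading-order transport equation can be integrated together along one ray. A minimal sketch, assuming a hypothetical 1D symbol P(x, ξ) = ξ² - n(x)² with index profile n(x) = 1 + ½ tanh x, started on the branch ξ = S'(x) = n(x). It reads the ½ term as the divergence of ∂P/∂ξ(x, S'(x)) = 2n(x); because ∂²P/∂x∂ξ = 0 for this symbol, operator ordering plays no role. The final checks use three exact properties: P stays zero along the ray, S equals ∫ξ dx, and a₀² n(x) is conserved.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Hypothetical smooth index profile n(x) > 0 and its derivative.
def n(x):
    return 1.0 + 0.5 * np.tanh(x)

def dn(x):
    return 0.5 / np.cosh(x) ** 2

def rhs(t, z):
    x, xi, S, a0 = z
    P = xi ** 2 - n(x) ** 2
    dP_dxi = 2.0 * xi
    dP_dx = -2.0 * n(x) * dn(x)
    # On the branch xi = S'(x) = n(x) the velocity field is X(x) = 2 n(x),
    # so div X = 2 n'(x) and da0/dt = -0.5 * div X * a0.
    div_X = 2.0 * dn(x)
    # dx/dt = dP/dxi, dxi/dt = -dP/dx, dS/dt = xi * dP/dxi - P
    return [dP_dxi, -dP_dx, xi * dP_dxi - P, -0.5 * div_X * a0]

x0 = -1.0
z0 = [x0, n(x0), 0.0, 1.0]   # start on the characteristic set P = 0
sol = solve_ivp(rhs, (0.0, 1.0), z0, rtol=1e-10, atol=1e-12)
x_end, xi_end, S_end, a_end = sol.y[:, -1]

def n_antiderivative(x):
    # Antiderivative of n(x) = 1 + 0.5 tanh(x)
    return x + 0.5 * np.log(np.cosh(x))

print(abs(xi_end ** 2 - n(x_end) ** 2))                                # P stays ~0
print(abs(S_end - (n_antiderivative(x_end) - n_antiderivative(x0))))  # S = integral of xi dx
print(abs(a_end ** 2 * n(x_end) - n(x0)))                             # a0**2 * n(x) conserved
```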

The **stability matrix** `J = ∂x(t)/∂q` (where `q` parametrises the initial data) measures the focusing of nearby rays.  Its determinant vanishes at **caustics**.  Near a fold caustic the standard WKB amplitude blows up; the correct uniform approximation involves the Airy function.  Near a cusp, the Pearcey integral is required.  The module implements both corrections using the companion `caustics` module.
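The variational equations behind `J` can be checked on a case with a known answer. A minimal sketch, assuming the free 1D symbol P = ξ² and a hypothetical focusing initial curve x₀ = q, ξ₀ = -q, so that analytically x(t) = q(1 - 2t) and J(t) = 1 - 2t vanishes at t = 0.5. The count of one π/2 shift per caustic follows the convention stated in the feature list; the linear refinement between samples is an illustrative choice, not necessarily what `RayCausticDetector` does.

```python
import numpy as np
from scipy.integrate import solve_ivp

# Free symbol P(x, xi) = xi**2: P_xixi = 2, P_xx = P_xxi = 0.
def variational_rhs(t, z):
    J, K = z                  # J = dx/dq, K = dxi/dq
    dJdt = 2.0 * K            # P_xixi * K + P_xix * J
    dKdt = 0.0                # -P_xx * J - P_xxi * K
    return [dJdt, dKdt]

# Focusing initial curve x0 = q, xi0 = -q gives J(0) = 1, K(0) = -1.
t = np.linspace(0.0, 1.0, 200)
sol = solve_ivp(variational_rhs, (0.0, 1.0), [1.0, -1.0], t_eval=t,
                rtol=1e-10, atol=1e-12)
J = sol.y[0]

# Caustics are zeros of J: find sign changes, refine by linear interpolation.
crossings = np.where(np.sign(J[:-1]) * np.sign(J[1:]) < 0)[0]
t_caustics = [t[i] - J[i] * (t[i + 1] - t[i]) / (J[i + 1] - J[i])
              for i in crossings]

# One pi/2 phase shift per caustic crossed (the convention used above).
maslov_phase = len(t_caustics) * np.pi / 2
print(t_caustics, maslov_phase)
```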


References
----------
.. [1] Maslov, V. P. & Fedoriuk, M. V.  *Semi‑Classical Approximation in Quantum Mechanics*, Reidel, 1981.
.. [2] Duistermaat, J. J.  “Oscillatory integrals, Lagrange immersions and unfolding of singularities”, *Comm. Pure Appl. Math.* 27, 207–281, 1974.
.. [3] Berry, M. V. & Howls, C. J.  “High orders of the Weyl expansion for quantum billiards”, *Phys. Rev. E* 50(5), 3577–3595, 1994.
.. [4] Ludwig, D.  “Uniform asymptotic expansions at a caustic”, *Comm. Pure Appl. Math.* 19, 215–250, 1966.
.. [5] Kravtsov, Yu. A. & Orlov, Yu. I.  *Caustics, Catastrophes and Wave Fields*, Springer, 1999.
"""

from imports import *
from caustics import * 
# ==================================================================
# ENHANCED WKB WITH CAUSTIC CORRECTIONS
# ==================================================================

def wkb_approximation(symbol, initial_phase, order=1, domain=None, resolution=50,
                      epsilon=0.1, dimension=None, caustic_correction='auto',
                      caustic_threshold=1e-3):
    """
    Compute multidimensional WKB approximation (1D or 2D).

        u(x) ≈ exp(iS/ε) · [a₀ + ε·a₁ + ε²·a₂ + ...]

    Automatically detects dimension from initial_phase or uses the dimension
    parameter. After the base solution is computed, caustics are detected
    along the rays and corrected according to caustic_correction.

    Parameters
    ----------
    symbol : sympy expression
        Principal symbol p(x, ξ) for 1D or p(x, y, ξ, η) for 2D.
    initial_phase : dict
        Initial data on a curve/point:
        1D: Keys 'x', 'S', 'p_x', optionally 'a' (dict or array)
        2D: Keys 'x', 'y', 'S', 'p_x', 'p_y', optionally 'a' (dict or array)
    order : int
        WKB order (0, 1, 2, or 3).
    domain : tuple or None
        1D: (x_min, x_max)
        2D: ((x_min, x_max), (y_min, y_max))
        If None, inferred from the extent of the traced rays (10% margin).
    resolution : int or tuple
        Grid resolution (single int or (nx, ny) for 2D).
    epsilon : float
        Small parameter for asymptotic expansion.
    dimension : int or None
        Force dimension (1 or 2). If None, auto-detect.
    caustic_correction : {'auto', 'none', 'maslov', 'airy'}
        Caustic treatment. In 1D, 'maslov' adds a π/2 phase shift past each
        caustic, 'airy' replaces the solution within 5·ε^(2/3) of each caustic
        by an Airy profile, and 'auto' applies both.
    caustic_threshold : float
        Threshold on det(J) passed to RayCausticDetector as det_threshold.

    Returns
    -------
    dict
        WKB solution with keys adapted to dimension (see _compute_base_wkb),
        plus 'caustics' (detected caustics) and 'caustic_correction' (the
        mode applied, 'none' if no caustic was found).

    Examples
    --------
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from sympy import symbols, sqrt
    >>>
    >>> # 1D harmonic oscillator
    >>> x, xi = symbols('x xi', real=True)
    >>> symbol = xi**2 + x**2  # p(x,ξ) = ξ² + x²
    >>>
    >>> # Initial conditions: Gaussian wave packet
    >>> n_rays = 20
    >>> x0 = np.linspace(-2, 2, n_rays)
    >>> initial = {
    ...     'x': x0,
    ...     'p_x': np.ones(n_rays),   # momentum = 1
    ...     'S': 0.5 * x0**2,         # phase
    ...     'a': np.exp(-x0**2)       # Gaussian amplitude
    ... }
    >>>
    >>> result = wkb_approximation(
    ...     symbol, initial, order=2, epsilon=0.05, resolution=100
    ... )
    >>>
    >>> # Extract solution
    >>> x_grid = result['x']
    >>> u = result['u']
    >>> plt.plot(x_grid, np.abs(u))

    >>> # 2D wave equation
    >>> x, y, xi, eta = symbols('x y xi eta', real=True)
    >>> c = 1.0  # wave speed
    >>> symbol = xi**2 + eta**2 - c**2
    >>>
    >>> # Point source initial conditions
    >>> theta = np.linspace(0, 2*np.pi, 30, endpoint=False)
    >>> r0 = 0.5
    >>> initial = {
    ...     'x': r0 * np.cos(theta),
    ...     'y': r0 * np.sin(theta),
    ...     'p_x': np.cos(theta),
    ...     'p_y': np.sin(theta),
    ...     'S': np.zeros(30),
    ...     'a': np.ones(30)
    ... }
    >>>
    >>> result = wkb_approximation(
    ...     symbol, initial, order=1, epsilon=0.1, resolution=(80, 80)
    ... )
    >>>
    >>> # Visualize 2D solution
    >>> X, Y = result['x'], result['y']
    >>> U = result['u']
    >>> plt.pcolormesh(X, Y, np.abs(U))
    """
    base_solution = _compute_base_wkb(symbol, initial_phase, order, domain,
                                      resolution, epsilon, dimension)

    # Detect caustics
    detector = RayCausticDetector(base_solution['rays'],
                                  base_solution['dimension'],
                                  det_threshold=caustic_threshold)
    caustics = detector.detect()

    if len(caustics) == 0:
        print("No caustics detected - using standard WKB")
        base_solution['caustic_correction'] = 'none'
        base_solution['caustics'] = []
        return base_solution

    print(f"\nApplying caustic corrections (mode: {caustic_correction})...")

    # Apply corrections based on caustic type
    if base_solution['dimension'] == 1:
        corrected_solution = _apply_1d_caustic_corrections(
            base_solution, caustics, epsilon, caustic_correction
        )
    else:
        corrected_solution = _apply_2d_caustic_corrections(
            base_solution, caustics, epsilon, caustic_correction
        )

    corrected_solution['caustics'] = caustics
    corrected_solution['caustic_correction'] = caustic_correction

    print("Caustic corrections applied successfully")
    return corrected_solution
def _compute_base_wkb(symbol, initial_phase, order=1, domain=None, resolution=50,
                      epsilon=0.1, dimension=None):
    """
    Compute multidimensional WKB (Wentzel-Kramers-Brillouin) approximation for wave propagation.

    Constructs an asymptotic solution to a pseudo-differential equation using the
    WKB method, which represents the solution as an oscillatory phase multiplied
    by a slowly varying amplitude:

        u(x, ε) = exp(iS(x)/ε) · Σₖ₌₀ⁿ εᵏ aₖ(x)

    where S(x) is the eikonal (phase function), aₖ(x) are transport amplitudes,
    and ε is a small parameter proportional to the wavelength (the phase S/ε
    oscillates on the scale ε).

    The method traces bicharacteristic rays through phase space using Hamilton's
    equations and solves transport equations along these rays to determine the
    amplitude evolution. This implementation supports both 1D and 2D spatial
    domains and computes multi-order corrections (explicit transport equations
    for orders 0 through 3) to improve accuracy beyond the leading
    semiclassical term.

    The stability matrices J = ∂x(t)/∂q and K = ∂ξ(t)/∂q (q parametrises the
    initial curve) are co-integrated alongside each ray via the variational
    equations of Hamilton's flow:

        dJ/dt =  p_ξx · J + p_ξξ · K
        dK/dt = -p_xx · J - p_xξ · K

    with J(0) = I and K(0) estimated from finite differences of the initial
    momenta along the initial curve. Here p_ξx[i,j] = ∂²p/∂ξᵢ∂xⱼ, p_ξξ and
    p_xx are the momentum and spatial Hessians of p. det(J) → 0 signals a
    caustic; each ray dict exposes J11 (and J12, J21, J22 in 2D) plus 'J'
    (J11 in 1D, det J in 2D) for downstream use by RayCausticDetector.

    Mathematical Framework
    ----------------------
    Given a symbol p(x, ξ) defining the pseudo-differential operator, the WKB
    method solves:

    1. **Eikonal equation** (determines phase):
           p(x, ∇S(x)) = 0

    2. **Transport equations** (determine amplitudes):
           ∇ξp·∇a₀ + (1/2)(∇ξ·∇ₓp) a₀ = 0   (order 0)
           [Higher-order corrections for k ≥ 1]

    The rays are computed via Hamilton's equations:

        dx/dt = ∂p/∂ξ,   dξ/dt = -∂p/∂x
        dS/dt = ξ·∂p/∂ξ - p

    Parameters
    ----------
    symbol : sympy.Expr
        Symbolic expression for the principal symbol p(x, ξ) in 1D or
        p(x, y, ξ, η) in 2D, representing the dispersion relation of the wave
        equation. It must use SymPy symbols named x, xi (and y, eta in 2D)
        declared with real=True, because this function re-creates those
        symbols internally for differentiation and lambdification.
        Common examples:
        - Schrödinger: ξ² + V(x) - E
        - Wave equation: ξ² + η² - ω²
        - Helmholtz: ξ² - k²(x)
    initial_phase : dict
        Initial conditions for ray tracing at t=0. Must contain:

        **Required keys (1D)**:
        - 'x' : array_like, shape (n_rays,)
            Initial spatial positions
        - 'p_x' : array_like, shape (n_rays,)
            Initial momenta (ξ values)
        - 'S' : array_like, shape (n_rays,)
            Initial phase values

        **Required keys (2D)**:
        - 'x', 'y' : array_like, shape (n_rays,)
            Initial spatial positions
        - 'p_x', 'p_y' : array_like, shape (n_rays,)
            Initial momenta (ξ, η values)
        - 'S' : array_like, shape (n_rays,)
            Initial phase values

        **Optional key**:
        - 'a' : array_like or dict
            Initial amplitude values. Can be:
            - Single array (n_rays,) → interpreted as a₀
            - Dict {0: a₀, 1: a₁, ...} → multi-order amplitudes
            If omitted, a₀ defaults to ones; missing higher orders default
            to zeros.
    order : int, default=1
        Maximum order of the asymptotic expansion. Higher orders include
        progressively smaller corrections proportional to εᵏ:
        - 0: Leading-order WKB (geometric optics)
        - 1: First correction (includes caustic pre-factors)
        - 2: Second correction (third-derivative terms)
        - 3: Third correction (simplified expressions); orders above 3 have
          no dedicated transport equation
    domain : tuple or None, default=None
        Spatial domain for the output grid:
        - 1D: (x_min, x_max)
        - 2D: ((x_min, x_max), (y_min, y_max))
        If None, automatically determined from ray extent with 10% margin.
    resolution : int or tuple, default=50
        Grid resolution for interpolated output:
        - int: Same resolution in all dimensions
        - tuple: (nx,) for 1D or (nx, ny) for 2D
    epsilon : float, default=0.1
        Small (semiclassical) parameter ε, proportional to the wavelength.
        The expansion is asymptotic as ε → 0. Typical values: 0.01 to 0.5.
    dimension : int or None, default=None
        Spatial dimension (1 or 2).
        If None, automatically detected from initial_phase:
        - Presence of 'y' and 'p_y' keys → dimension = 2
        - Otherwise → dimension = 1

    Returns
    -------
    result : dict
        Dictionary containing the computed WKB solution and diagnostic information:

        **Solution fields**:
        - 'u' : ndarray, shape (nx,) or (nx, ny), complex
            Complete WKB approximation: u = exp(iS/ε) · Σₖ εᵏaₖ
        - 'S' : ndarray, shape (nx,) or (nx, ny), float
            Interpolated phase function (eikonal)
        - 'a' : dict of ndarrays
            Individual amplitude orders: {0: a₀, 1: a₁, ..., order: aₙ}
        - 'a_total' : ndarray, complex
            Total amplitude: Σₖ εᵏaₖ (without phase factor)

        **Grid information**:
        - 'x' : ndarray
            Spatial grid in x direction (1D: shape (nx,), 2D: shape (nx, ny))
        - 'y' : ndarray (2D only)
            Spatial grid in y direction, shape (nx, ny)

        **Ray tracing data**:
        - 'rays' : list of dict
            Each dict contains traced ray data with keys: 't', 'x', ['y'],
            'xi', ['eta'], 'S', 'a0', 'a1', ..., 'J', 'J11' (all dims),
            'J12', 'J21', 'J22' (2D only).
        - 'n_rays' : int
            Number of traced rays

        **Metadata**:
        - 'dimension' : int
            Spatial dimension used (1 or 2)
        - 'order' : int
            Order of asymptotic expansion
        - 'epsilon' : float
            Small parameter value
        - 'domain' : tuple
            Spatial domain used for output grid

    Raises
    ------
    ValueError
        - If dimension is not 1 or 2
        - If required keys are missing from initial_phase
        - If x and y have different lengths in 2D

    Notes
    -----
    **Validity and Limitations**:

    1. **Small ε assumption**: The WKB method is asymptotic in ε → 0. Results
       become inaccurate for ε > 1 or near caustics where rays focus.

    2. **Caustics**: At caustics (where neighbouring rays cross), det(J) → 0
       and the amplitude formally diverges. The stability matrix J is tracked
       per ray for caustic detection by RayCausticDetector.

    3. **Turning points**: Classical turning points (where ∂p/∂ξ = 0, e.g.
       ξ = 0 for p = ξ² + V(x) - E) require connection formulas that this
       function does not apply. The method works best away from such points.

    4. **Interpolation artifacts**: The solution is interpolated from discrete
       rays onto a regular grid. Ensure sufficient ray density (n_rays) and
       resolution to avoid aliasing or gaps.

    **Numerical Implementation**:
    - All rays are integrated in a single batched scipy.integrate.solve_ivp
      call with RK45 (4th/5th order Runge-Kutta)
    - Fixed tolerances: rtol=1e-6, atol=1e-9
    - Integration time: t ∈ [0, 5] with 100 evaluation points per ray
    - If the integrator reports failure, a warning is printed and the samples
      computed so far are used
    - 1D interpolation: linear (interp1d)
    - 2D interpolation: linear scattered data (griddata)

    **Computational Complexity**:
    - Ray tracing: O(n_rays × n_steps × n_derivatives)
    - Grid interpolation: O(n_rays × n_steps × resolution^dimension)
    - Memory: O(resolution^dimension) for output arrays

    References
    ----------
    .. [1] Maslov, V.P. and Fedoriuk, M.V. (1981). "Semi-Classical Approximation
           in Quantum Mechanics". Reidel.
    .. [2] Ralston, J. (1982). "Gaussian beams and the propagation of
           singularities". Studies in PDE, MAA Studies in Mathematics, 23, 206-248.
    .. [3] Hörmander, L. (1985). "The Analysis of Linear Partial Differential
           Operators III". Springer-Verlag.
    See Also
    --------
    scipy.integrate.solve_ivp : ODE solver used for ray tracing
    scipy.interpolate.griddata : 2D interpolation method
    numpy.fft : For comparison with spectral methods
    """
    from scipy.integrate import solve_ivp
    from scipy.interpolate import griddata, interp1d

    # ==================================================================
    # DETECT DIMENSION
    # ==================================================================
    if dimension is None:
        # Auto-detect from initial_phase
        has_y = 'y' in initial_phase and 'p_y' in initial_phase
        dimension = 2 if has_y else 1

    if dimension not in [1, 2]:
        raise ValueError(f"Dimension must be 1 or 2, got {dimension}")

    print(f"WKB approximation in {dimension}D (order {order})")

    # ==================================================================
    # SETUP SYMBOLIC VARIABLES
    # ==================================================================
    if dimension == 1:
        x = symbols('x', real=True)
        xi = symbols('xi', real=True)
        spatial_vars = [x]
        momentum_vars = [xi]
        spatial_symbols = (x,)
        momentum_symbols = (xi,)
        all_vars = (x, xi)
    else:  # dimension == 2
        x, y = symbols('x y', real=True)
        xi, eta = symbols('xi eta', real=True)
        spatial_vars = [x, y]
        momentum_vars = [xi, eta]
        spatial_symbols = (x, y)
        momentum_symbols = (xi, eta)
        all_vars = (x, y, xi, eta)

    # ==================================================================
    # VALIDATE AND EXTRACT INITIAL DATA
    # ==================================================================
    required_keys_1d = ['x', 'S', 'p_x']
    required_keys_2d = ['x', 'y', 'S', 'p_x', 'p_y']
    required_keys = required_keys_2d if dimension == 2 else required_keys_1d

    if not all(k in initial_phase for k in required_keys):
        raise ValueError(f"initial_phase must contain: {required_keys}")

    # Extract spatial coordinates
    x_init = np.asarray(initial_phase['x'])
    n_rays = len(x_init)

    if dimension == 2:
        y_init = np.asarray(initial_phase['y'])
        if len(y_init) != n_rays:
            raise ValueError("x and y must have same length")

    # Extract phase and momentum
    S_init = np.asarray(initial_phase['S'])
    px_init = np.asarray(initial_phase['p_x'])
    if dimension == 2:
        py_init = np.asarray(initial_phase['p_y'])

    # Extract amplitudes for each order
    a_init = {}
    if 'a' in initial_phase:
        if isinstance(initial_phase['a'], dict):
            for k, v in initial_phase['a'].items():
                a_init[k] = np.asarray(v)
        else:
            a_init[0] = np.asarray(initial_phase['a'])
    else:
        a_init[0] = np.ones(n_rays)

    # Initialize missing orders to zero
    for k in range(order + 1):
        if k not in a_init:
            a_init[k] = np.zeros(n_rays)

    # ==================================================================
    # COMPUTE SYMBOLIC DERIVATIVES
    # ==================================================================
    print("Computing symbolic derivatives...")
    derivatives = {}

    # First derivatives (Hamilton equations)
    for i, mom_var in enumerate(momentum_vars):
        derivatives[f'dp_d{mom_var.name}'] = diff(symbol, mom_var)
    for i, space_var in enumerate(spatial_vars):
        derivatives[f'dp_d{space_var.name}'] = diff(symbol, space_var)

    # Second derivatives (transport equations)
    for mom_var in momentum_vars:
        derivatives[f'd2p_d{mom_var.name}2'] = diff(symbol, mom_var, 2)
    for space_var in spatial_vars:
        for mom_var in momentum_vars:
            derivatives[f'd2p_d{mom_var.name}d{space_var.name}'] = \
                diff(diff(symbol, mom_var), space_var)
    if len(momentum_vars) == 2:
        derivatives['d2p_dxideta'] = diff(diff(symbol, momentum_vars[0]),
                                          momentum_vars[1])

    # ----------------------------------------------------------------
    # STABILITY MATRIX: variational equations of the Hamiltonian flow
    #
    # J = dx(t)/dq, where q is the parameter along the initial curve.
    # Its ODE comes from differentiating Hamilton's equations w.r.t. q:
    #
    # 1D:
    #   d/dt(dx/dq)  =  (d2p/dxi2)*(dxi/dq) + (d2p/dxi/dx)*(dx/dq)
    #   d/dt(dxi/dq) = -(d2p/dx2) *(dx/dq)  - (d2p/dx/dxi)*(dxi/dq)
    #
    # J = dx/dq is the quantity used for caustic detection; it is
    # initialised from the geometry of the initial data curve (see the
    # batch initialisation of Z0 below).
    # The coupled variable K = dxi/dq is also integrated so that J
    # can evolve correctly through the (d2p/dxi2)*K term.
    #
    # Required second derivatives:
    #   d2p/dxi2      (already present as d2p_dxi2)
    #   d2p/dxi/dx    (already present as d2p_dxidx)
    #   d2p/dx2       NEW
    # For 2D, the full 2x2 block equivalents are needed.
    # ----------------------------------------------------------------

    # Spatial second derivatives (variational equations)
    for space_var in spatial_vars:
        derivatives[f'd2p_d{space_var.name}2'] = diff(symbol, space_var, 2)

    if dimension == 2:
        derivatives['d2p_dxidy'] = diff(diff(symbol, xi), y)    # d2p/dxi/dy
        derivatives['d2p_detadx'] = diff(diff(symbol, eta), x)  # d2p/deta/dx
        derivatives['d2p_dxy'] = diff(diff(symbol, x), y)       # d2p/dx/dy
        # d2p_dxidx  -> d2p/dxi/dx  (already present)
        # d2p_detady -> d2p/deta/dy (already present)

    # Third derivatives (higher-order corrections)
    if order >= 2:
        for mom_var in momentum_vars:
            derivatives[f'd3p_d{mom_var.name}3'] = diff(symbol, mom_var, 3)
        if dimension == 2:
            derivatives['d3p_dxi2deta'] = diff(diff(symbol, xi, 2), eta)
            derivatives['d3p_dxideta2'] = diff(diff(symbol, xi), eta, 2)
            derivatives['d3p_dxi2dx'] = diff(diff(symbol, xi, 2), x)
            derivatives['d3p_deta2dy'] = diff(diff(symbol, eta, 2), y)

    # Lambdify all derivatives
    print(f"Lambdifying {len(derivatives)} derivatives...")
    funcs = {}
    for name, expr in derivatives.items():
        funcs[name] = lambdify(all_vars, expr, 'numpy')

    # Principal symbol
    funcs['p'] = lambdify(all_vars, symbol, 'numpy')

    # ==================================================================
    # HELPER FUNCTIONS FOR DERIVATIVES EVALUATION
    # ==================================================================
    def eval_func(name, *args):
        """Evaluate a lambdified derivative; return 0.0 if it was not built."""
        if name in funcs:
            return funcs[name](*args)
        return 0.0

    def compute_geometric_spreading(*args):
        """
        Compute the trace of the momentum Hessian of p.

        1D: d²p/dξ²
        2D: d²p/dξ² + d²p/dη²
        """
        if dimension == 1:
            return eval_func('d2p_dxi2', *args)
        else:
            return (eval_func('d2p_dxi2', *args) +
                    eval_func('d2p_deta2', *args))

    def compute_spatial_momentum_coupling(*args):
        """
        Compute ∇_x · ∇_ξ p

        1D: ∂²p/∂x∂ξ
        2D: ∂²p/∂x∂ξ + ∂²p/∂y∂η
        """
        if dimension == 1:
            return eval_func('d2p_dxidx', *args)
        else:
            return (eval_func('d2p_dxidx', *args) +
                    eval_func('d2p_detady', *args))

    # ==================================================================
    # STATE VECTOR LAYOUT
    # ==================================================================
    #
    # 1D: [x, xi, S, J11, K11, a0, a1, ...]
    #      0   1  2   3    4   5   6
    #      J11 = dx/dq, K11 = dxi/dq
    #
    # 2D: [x, y, xi, eta, S, J11, J12, J21, J22, K11, K12, K21, K22, a0, a1, ...]
    #      0  1   2    3  4    5    6    7    8    9   10   11   12  13  14
    #      Jij = d x_i(t) / d q_j,  Kij = d xi_i(t) / d q_j
    #      (x_i: i-th spatial coordinate, xi_i: i-th momentum component)
    # ==================================================================
    if dimension == 1:
        idx_x, idx_xi, idx_S = 0, 1, 2
        idx_J11 = 3
        idx_K11 = 4
        idx_a_start = 5
    else:
        idx_x, idx_y, idx_xi, idx_eta, idx_S = 0, 1, 2, 3, 4
        idx_J11, idx_J12, idx_J21, idx_J22 = 5, 6, 7, 8
        idx_K11, idx_K12, idx_K21, idx_K22 = 9, 10, 11, 12
        idx_a_start = 13

    idx_a = {k: idx_a_start + k for k in range(order + 1)}

    # ==================================================================
    # RAY TRACING: vectorised batch integration
    # ==================================================================
    # Instead of n_rays separate solve_ivp calls, we pack ALL rays into a
    # single flat state vector and call solve_ivp ONCE. The lambdified
    # functions already accept numpy arrays, so every derivative evaluation
    # operates on all rays simultaneously (numpy broadcasting).
    #
    # State: Z shape (n_rays, ns) where ns = components per ray.
    # Flat: z_flat = Z.ravel(), reconstructed as Z = z_flat.reshape(n_rays, ns).
    # On exit we slice Z back into per-ray dicts for the rest of the pipeline.
    # ==================================================================
    print(f"Ray tracing {n_rays} rays (vectorised batch)...")

    tmax = 5.0
    n_steps_per_ray = 100

    # ------------------------------------------------------------------
    # Build the initial batch state matrix Z0 (n_rays × ns)
    # ------------------------------------------------------------------
    if dimension == 1:
        ns = idx_a_start + order + 1
        Z0 = np.zeros((n_rays, ns))
        Z0[:, idx_x] = x_init
        Z0[:, idx_xi] = px_init
        Z0[:, idx_S] = S_init
        Z0[:, idx_J11] = 1.0
        if n_rays > 1:
            dpx = np.empty(n_rays); dx_ = np.empty(n_rays)
            dpx[1:-1] = px_init[2:] - px_init[:-2]
            dx_[1:-1] = x_init[2:] - x_init[:-2]
            dpx[0] = px_init[1] - px_init[0]; dx_[0] = x_init[1] - x_init[0]
            dpx[-1] = px_init[-1] - px_init[-2]; dx_[-1] = x_init[-1] - x_init[-2]
            with np.errstate(invalid='ignore', divide='ignore'):
                Z0[:, idx_K11] = np.where(dx_ != 0, dpx / dx_, 0.0)
        for k in range(order + 1):
            Z0[:, idx_a[k]] = a_init[k]
    else:
        ns = idx_a_start + order + 1
        Z0 = np.zeros((n_rays, ns))
        Z0[:, idx_x] = x_init; Z0[:, idx_y] = y_init
        Z0[:, idx_xi] = px_init; Z0[:, idx_eta] = py_init
        Z0[:, idx_S] = S_init
        Z0[:, idx_J11] = 1.0; Z0[:, idx_J22] = 1.0
        if n_rays > 1:
            dpx0 = np.empty(n_rays); dpy0 = np.empty(n_rays)
            dx0 = np.empty(n_rays); dy0 = np.empty(n_rays)
            for arr_in, arr_out in [(px_init, dpx0), (py_init, dpy0),
                                    (x_init, dx0), (y_init, dy0)]:
                arr_out[1:-1] = arr_in[2:] - arr_in[:-2]
                arr_out[0] = arr_in[1] - arr_in[0]
                arr_out[-1] = arr_in[-1] - arr_in[-2]
            arc = np.hypot(dx0, dy0)
            m = arc > 1e-14
            Z0[m, idx_K11] = dpx0[m] / arc[m]
            Z0[m, idx_K21] = dpy0[m] / arc[m]
        for k in range(order + 1):
            Z0[:, idx_a[k]] = a_init[k]

    # ------------------------------------------------------------------
    # Vectorised ODE right-hand side
    # ------------------------------------------------------------------
    def ray_ode_batch(t, z_flat):
        Z = z_flat.reshape(n_rays, ns)

        if dimension == 1:
            xv = Z[:, idx_x]; xiv = Z[:, idx_xi]
            args_ = (xv, xiv)
            dxdt = eval_func('dp_dxi', *args_)
            dxidt = -eval_func('dp_dx', *args_)
            dSdt = xiv * dxdt - eval_func('p', *args_)
            pxixi = eval_func('d2p_dxi2', *args_)
            pxxi = eval_func('d2p_dxidx', *args_)
            pxx = eval_func('d2p_dx2', *args_)
            Jv = Z[:, idx_J11]; Kv = Z[:, idx_K11]
            dJdt = pxixi*Kv + pxxi*Jv
            dKdt = -pxx*Jv - pxxi*Kv
            geom = pxixi; coupling = pxxi
            # zeros_like (not empty_like): amplitude slots without a transport
            # equation (order > 3) must get a zero derivative, not garbage.
            dZ = np.zeros_like(Z)
            dZ[:, idx_x] = dxdt; dZ[:, idx_xi] = dxidt
            dZ[:, idx_S] = dSdt
            dZ[:, idx_J11] = dJdt; dZ[:, idx_K11] = dKdt
            a0v = Z[:, idx_a[0]]
            dZ[:, idx_a[0]] = -0.5*a0v*geom
            if order >= 1:
                a1v = Z[:, idx_a[1]]
                dZ[:, idx_a[1]] = -0.5*a1v*geom - 0.5*a0v*coupling
            if order >= 2:
                a2v = Z[:, idx_a[2]]
                d3 = eval_func('d3p_dxi3', *args_)
                dZ[:, idx_a[2]] = (-0.5*a2v*geom - 0.125*a0v*d3*dxidt
                                   - 0.25*Z[:, idx_a[1]]*coupling)
            if order >= 3:
                a3v = Z[:, idx_a[3]]
                d3 = eval_func('d3p_dxi3', *args_)
                dZ[:, idx_a[3]] = -0.5*a3v*geom - 0.1*Z[:, idx_a[1]]*d3*dxidt
        else:  # dimension == 2
            xv = Z[:, idx_x]; yv = Z[:, idx_y]
            xiv = Z[:, idx_xi]; etav = Z[:, idx_eta]
            args_ = (xv, yv, xiv, etav)
            dxdt = eval_func('dp_dxi', *args_)
            dydt = eval_func('dp_deta', *args_)
            dxidt = -eval_func('dp_dx', *args_)
            detadt = -eval_func('dp_dy', *args_)
            dSdt = xiv*dxdt + etav*dydt - eval_func('p', *args_)
            pxx = eval_func('d2p_dx2', *args_)
            pyy = eval_func('d2p_dy2', *args_)
            pxy = eval_func('d2p_dxy', *args_)
            pxix = eval_func('d2p_dxidx', *args_)
            pxiy = eval_func('d2p_dxidy', *args_)
            petax = eval_func('d2p_detadx', *args_)
            petay = eval_func('d2p_detady', *args_)
            pxixi = eval_func('d2p_dxi2', *args_)
            petapeta = eval_func('d2p_deta2', *args_)
            pxipeta = eval_func('d2p_dxideta', *args_)
            J11 = Z[:, idx_J11]; J12 = Z[:, idx_J12]
            J21 = Z[:, idx_J21]; J22 = Z[:, idx_J22]
            K11 = Z[:, idx_K11]; K12 = Z[:, idx_K12]
            K21 = Z[:, idx_K21]; K22 = Z[:, idx_K22]
            # zeros_like (not empty_like): see the 1D branch above.
            dZ = np.zeros_like(Z)
            dZ[:, idx_x] = dxdt; dZ[:, idx_y] = dydt
            dZ[:, idx_xi] = dxidt; dZ[:, idx_eta] = detadt; dZ[:, idx_S] = dSdt
            # dJ/dt = p_xixi K + p_xix J
            dZ[:, idx_J11] = pxixi*K11 + pxipeta*K21 + pxix*J11 + pxiy*J21
            dZ[:, idx_J12] = pxixi*K12 + pxipeta*K22 + pxix*J12 + pxiy*J22
            dZ[:, idx_J21] = pxipeta*K11 + petapeta*K21 + petax*J11 + petay*J21
            dZ[:, idx_J22] = pxipeta*K12 + petapeta*K22 + petax*J12 + petay*J22
            # dK/dt = -p_xx J - p_xxi K
            dZ[:, idx_K11] = -(pxx*J11 + pxy*J21) - (pxix*K11 + petax*K21)
            dZ[:, idx_K12] = -(pxx*J12 + pxy*J22) - (pxix*K12 + petax*K22)
            dZ[:, idx_K21] = -(pxy*J11 + pyy*J21) - (pxiy*K11 + petay*K21)
            dZ[:, idx_K22] = -(pxy*J12 + pyy*J22) - (pxiy*K12 + petay*K22)
            geom = pxixi + petapeta; coupling = pxix + petay
            a0v = Z[:, idx_a[0]]
            dZ[:, idx_a[0]] = -0.5*a0v*geom
            if order >= 1:
                a1v = Z[:, idx_a[1]]
                cross = eval_func('d2p_dxideta', *args_)
                dZ[:, idx_a[1]] = (-0.5*a1v*geom - 0.5*a0v*coupling
                                   - 0.25*a0v*cross*(dxidt + detadt))
            if order >= 2:
                a2v = Z[:, idx_a[2]]
                d3xi = eval_func('d3p_dxi3', *args_)
                d3eta = eval_func('d3p_deta3', *args_)
                d3mix1 = eval_func('d3p_dxi2deta', *args_)
                d3mix2 = eval_func('d3p_dxideta2', *args_)
                correction = (d3xi*dxidt + d3eta*detadt
                              + d3mix1*(dxidt + detadt) + d3mix2*(dxidt + detadt))
                dZ[:, idx_a[2]] = (-0.5*a2v*geom - 0.125*a0v*correction
                                   - 0.25*Z[:, idx_a[1]]*coupling)
            if order >= 3:
                a3v = Z[:, idx_a[3]]
                d3t = eval_func('d3p_dxi3', *args_) + eval_func('d3p_deta3', *args_)
                dZ[:, idx_a[3]] = (-0.5*a3v*geom
                                   - 0.1*Z[:, idx_a[1]]*d3t*(dxidt + detadt))
        return dZ.ravel()

    # ------------------------------------------------------------------
    # Single solve_ivp integrating all rays at once
    # ------------------------------------------------------------------
    sol = solve_ivp(
        ray_ode_batch, (0, tmax), Z0.ravel(),
        method='RK45',
        t_eval=np.linspace(0, tmax, n_steps_per_ray),
        rtol=1e-6, atol=1e-9
    )
    if not sol.success:
        print(f"Warning: batch ray integration: {sol.message}")

    # Unpack: sol.y has shape (n_rays*ns, n_steps)
    Y = sol.y.reshape(n_rays, ns, -1)  # (n_rays, ns, n_steps)

    rays = []
    for i in range(n_rays):
        rd = {'t': sol.t}
        if dimension == 1:
            rd['x'] = Y[i, idx_x, :]
            rd['xi'] = Y[i, idx_xi, :]
            rd['S'] = Y[i, idx_S, :]
            J11 = Y[i, idx_J11, :]
            rd['J11'] = J11; rd['J'] = J11
        else:
            rd['x'] = Y[i, idx_x, :]; rd['y'] = Y[i, idx_y, :]
            rd['xi'] = Y[i, idx_xi, :]; rd['eta'] = Y[i, idx_eta, :]
            rd['S'] = Y[i, idx_S, :]
            J11 = Y[i, idx_J11, :]; J12 = Y[i, idx_J12, :]
            J21 = Y[i, idx_J21, :]; J22 = Y[i, idx_J22, :]
            rd['J11'] = J11; rd['J12'] = J12; rd['J21'] = J21; rd['J22'] = J22
            rd['J'] = J11*J22 - J12*J21
        for k in range(order + 1):
            rd[f'a{k}'] = Y[i, idx_a[k], :]
        rays.append(rd)

    print(f"Successfully traced {len(rays)} rays")

    # ==================================================================
    # INTERPOLATION ONTO REGULAR GRID
    # ==================================================================
    print("Interpolating solution onto grid...")

    # Determine domain
    if domain is None:
        x_all = np.concatenate([ray['x'] for ray in rays])
        x_min, x_max = x_all.min(), x_all.max()
        margin = 0.1 * (x_max - x_min)
        if dimension == 1:
            domain = (x_min - margin, x_max + margin)
        else:
            y_all = np.concatenate([ray['y'] for ray in rays])
            y_min, y_max = y_all.min(), y_all.max()
            margin_y = 0.1 * (y_max - y_min)
            domain = ((x_min - margin, x_max + margin),
                      (y_min - margin_y, y_max + margin_y))

    # Create grid
    if dimension == 1:
        if isinstance(resolution, tuple):
            resolution = resolution[0]
        x_grid = np.linspace(domain[0], domain[1], resolution)

        # Collect ray data
        x_points = np.concatenate([ray['x'] for ray in rays])
        S_points = np.concatenate([ray['S'] for ray in rays])
        a_points = {k: np.concatenate([ray[f'a{k}'] for ray in rays])
                    for k in range(order + 1)}

        # Sort for interpolation
        sort_idx = np.argsort(x_points)
        x_points = x_points[sort_idx]
        S_points = S_points[sort_idx]
        for k in range(order + 1):
            a_points[k] = a_points[k][sort_idx]

        # Interpolate
        S_grid = interp1d(x_points, S_points, kind='linear',
                          bounds_error=False, fill_value=0.0)(x_grid)

        a_grids = {}
        for k in range(order + 1):
            a_grids[k] = interp1d(x_points, a_points[k], kind='linear',
                                  bounds_error=False, fill_value=0.0)(x_grid)

        grid_coords = {'x': x_grid}

    else:  # dimension == 2
        if isinstance(resolution, int):
            nx = ny = resolution
        else:
            nx, ny = resolution

        (x_min, x_max), (y_min, y_max) = domain
        x_grid = np.linspace(x_min, x_max, nx)
        y_grid = np.linspace(y_min, y_max, ny)
        X_grid, Y_grid = np.meshgrid(x_grid, y_grid, indexing='ij')

        # Collect ray data
        x_points = []
        y_points = []
        S_points = []
        a_points = {k: [] for k in range(order + 1)}
        for ray in rays:
            x_points.extend(ray['x'])
            y_points.extend(ray['y'])
            S_points.extend(ray['S'])
            for k in range(order + 1):
                a_points[k].extend(ray[f'a{k}'])

        points = np.column_stack([x_points, y_points])
        degenerate = np.std(y_points) < 1e-12

        # Interpolate the phase
        if degenerate:
            # Degenerate case: all ray points share one y value, so no Delaunay
            # triangulation exists. Interpolate along x only; np.interp needs
            # increasing abscissae, hence the sort.
            sort_idx = np.argsort(points[:, 0])
            x_sorted = points[sort_idx, 0]

            def _interp_line(values):
                line = np.interp(x_grid, x_sorted, np.asarray(values)[sort_idx])
                return np.tile(line[:, None], (1, Y_grid.shape[1]))

            S_grid = _interp_line(S_points)
        else:
            S_grid = griddata(points, S_points, (X_grid, Y_grid), method='linear',
                              fill_value=0.0, rescale=True)
        S_grid = np.nan_to_num(S_grid, nan=0.0)

        # Amplitude interpolations are independent, so run them in parallel with
        # threads. (griddata spends most time in C code for Delaunay
        # triangulation, so threads get real concurrency despite the GIL.)
        from concurrent.futures import ThreadPoolExecutor as _TPool
        a_grids = {}

        def _interp_order(k):
            if degenerate:
                return k, _interp_line(a_points[k])
            # Same rescaling as for S, so phase and amplitudes are interpolated
            # on the same triangulation.
            arr = griddata(points, a_points[k], (X_grid, Y_grid), method='linear',
                           fill_value=0.0, rescale=True)
            return k, np.nan_to_num(arr, nan=0.0)

        with _TPool() as _pool:
            for k, arr in _pool.map(_interp_order, range(order + 1)):
                a_grids[k] = arr

        grid_coords = {'x': X_grid, 'y': Y_grid}

    # ==================================================================
    # CONSTRUCT WKB SOLUTION
    # ==================================================================
    phase_factor = np.exp(1j * S_grid / epsilon)

    # Sum asymptotic series: a_total = sum_k epsilon**k * a_k
    a_total = np.zeros_like(a_grids[0], dtype=complex)
    epsilon_power = 1.0
    for k in range(order + 1):
        a_total += epsilon_power * a_grids[k]
        epsilon_power *= epsilon
        print(f" Order {k}: max|a_{k}| = {np.max(np.abs(a_grids[k])):.6f}")

    u_grid = phase_factor * a_total

    print(f"\nWKB solution computed (order {order}, dim={dimension})")
    print(f"Max |u| = {np.max(np.abs(u_grid)):.6f}")

    # ==================================================================
    # RETURN RESULTS
    # ==================================================================
    result = {
        'dimension': dimension,
        'order': order,
        'epsilon': epsilon,
        'domain': domain,
        'S': S_grid,
        'a': a_grids,
        'a_total': a_total,
        'u': u_grid,
        'rays': rays,
        'n_rays': len(rays)
    }
    result.update(grid_coords)

    return result


def _apply_1d_caustic_corrections(base_solution, caustics, epsilon, mode):
    """
    Apply caustic corrections in 1D using Airy functions and Maslov index.
    """
    x = base_solution['x']
    S = base_solution['S']
    a = base_solution['a'][0]

    # Initialize u_corrected with the standard solution
    u_corrected = np.copy(base_solution['u'])

    # Compute Maslov index for each point
    maslov_phases = np.zeros_like(x)
    for caustic in caustics:
        x_c = caustic.position[0]
        # Add π/2 phase shift past each caustic
        maslov_phases[x > x_c] += np.pi / 2

    # Apply corrections based on mode
    if mode == 'none':
        # No correction, just return the standard solution
        print("No caustic correction applied (mode: none)")
    elif mode == 'maslov' or (mode == 'auto' and len(caustics) > 0):
        # Apply the Maslov phase to the full asymptotic sum
        # u = a_total * exp(iS/ε), not only to the leading amplitude a_0
        u_corrected = base_solution['u'] * np.exp(1j * maslov_phases)
        print(f"Applied Maslov correction: {len(caustics)} caustics found")

    if mode == 'airy' or (mode == 'auto' and len(caustics) > 0):
        # Apply Airy function near caustics, starting from the current
        # u_corrected (which may already include the Maslov phase)
        from scipy.special import airy as _airy
        for caustic in caustics:
            x_c = caustic.position[0]

            # Region of Airy correction
            airy_width = 5 * epsilon**(2/3)
            mask = np.abs(x - x_c) < airy_width

            if np.any(mask):
                # Scaled coordinate. Ai(z) decays for x > x_c here; which side
                # of a fold is the shadow region depends on the symbol.
                z = (x[mask] - x_c) / epsilon**(2/3)
                Ai, _, _, _ = _airy(z)

                # Amplitude and phase at the caustic
                idx_c = np.argmin(np.abs(x - x_c))
                a_c = a[idx_c]
                S_c = S[idx_c]

                # Replace with uniform approximation
                u_corrected[mask] = a_c * np.pi * Ai * np.exp(1j * S_c / epsilon)

        print(f"Applied Airy corrections near {len(caustics)} fold caustics")

    result = base_solution.copy()
    result['u'] = u_corrected
    result['u_standard'] = base_solution['u']  # Keep original for comparison
    result['maslov_phases'] = maslov_phases

    return result


def _apply_2d_caustic_corrections(base_solution, caustics, epsilon, mode):
    """
    Apply caustic corrections in 2D using Airy/Pearcey functions.
    """
    X = base_solution['x']
    Y = base_solution['y']
    S = base_solution['S']
    a = base_solution['a'][0]

    u_corrected = np.copy(base_solution['u'])

    if mode == 'none':
        print("No caustic correction applied (mode: none)")
        result = base_solution.copy()
        result['u_standard'] = base_solution['u']
        return result

    from scipy.special import airy as _airy

    # Classify and correct each caustic
    for caustic in caustics:
        x_c, y_c = caustic.position[0], caustic.position[1]

        # Amplitude and phase at the grid node nearest to the caustic
        # (X and Y come from meshgrid with indexing='ij')
        idx_x = np.argmin(np.abs(X[:, 0] - x_c))
        idx_y = np.argmin(np.abs(Y[0, :] - y_c))
        a_c = a[idx_x, idx_y]
        S_c = S[idx_x, idx_y]

        # Distance to caustic
        dist = np.sqrt((X - x_c)**2 + (Y - y_c)**2)

        if caustic.arnold_type == 'A2':
            # Fold caustic - use Airy
            correction_width = 5 * epsilon**(2/3)
            mask = dist < correction_width

            if np.any(mask):
                # Simplified: the radial distance to the caustic point replaces
                # the signed coordinate normal to the caustic (which should be
                # computed from the ray geometry), so z >= 0 and only the
                # decaying side of Ai is sampled.
                z = dist[mask] / epsilon**(2/3)
                Ai, _, _, _ = _airy(z)
                u_corrected[mask] = a_c * np.pi * Ai * np.exp(1j * S_c / epsilon)

        elif caustic.arnold_type == 'A3':
            # Cusp caustic - use Pearcey
            correction_width = 5 * epsilon**(1/2)
            mask = dist < correction_width

            if np.any(mask):
                # Scaled coordinates (cusp axes assumed aligned with x and y)
                x_scaled = (X[mask] - x_c) / epsilon**(1/2)
                y_scaled = (Y[mask] - y_c) / epsilon**(1/4)

                # Pearcey integral (expensive: one evaluation per masked point)
                P_vals = np.array([CausticFunctions.pearcey_approx(xs, ys)
                                   for xs, ys in zip(x_scaled, y_scaled)])
                u_corrected[mask] = a_c * P_vals * np.exp(1j * S_c / epsilon)

    print(f"Applied corrections to {len(caustics)} caustics")
    print(f" Fold (Airy): {sum(1 for c in caustics if c.arnold_type == 'A2')}")
    print(f" Cusp (Pearcey): {sum(1 for c in caustics if c.arnold_type == 'A3')}")

    result = base_solution.copy()
    result['u'] = u_corrected
    result['u_standard'] = base_solution['u']

    return result
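
# ------------------------------------------------------------------
# Sketch (illustrative only, not part of the module API): why the fold
# correction above rescales distances by epsilon**(2/3).
# Assumption: the model symbol P(x, xi) = xi**2 - x from Test 3 of
# _run_tests. Its operator P(x, -i*eps*d/dx) = -eps**2 d^2/dx^2 - x gives
# eps**2 u'' + x u = 0, and u(x) = Ai(-x / eps**(2/3)) solves it exactly,
# because the substitution z = -x / eps**(2/3) yields Ai''(z) = z Ai(z).
# The hypothetical helper below checks the residual by finite differences.
# sign=+1 (the orientation used in _apply_1d_caustic_corrections) does not
# solve this particular equation, which shows that the lit/shadow side of
# a fold depends on the symbol.
# ------------------------------------------------------------------
import numpy as np
from scipy.special import airy


def _sketch_airy_fold_residual(eps=0.1, sign=-1.0, n=2001):
    """Max |eps**2 u'' + x u| for u(x) = Ai(sign * x / eps**(2/3)) on [-1, 1]."""
    x = np.linspace(-1.0, 1.0, n)
    h = x[1] - x[0]
    u = airy(sign * x / eps**(2.0 / 3.0))[0]         # airy returns (Ai, Ai', Bi, Bi')
    u_xx = (u[2:] - 2.0 * u[1:-1] + u[:-2]) / h**2   # central second difference
    residual = eps**2 * u_xx + x[1:-1] * u[1:-1]
    return np.max(np.abs(residual))


# Usage: _sketch_airy_fold_residual() is small (finite-difference error
# only), while _sketch_airy_fold_residual(sign=+1.0) is not.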
def compare_orders(symbol, initial_phase, max_order=3, **kwargs):
    """
    Compare WKB approximations at different orders.

    Works for both 1D and 2D automatically.
    """
    import matplotlib.pyplot as plt

    solutions = {}

    for order in range(max_order + 1):
        print(f"\n{'='*60}")
        print(f"Computing order {order}")
        print(f"{'='*60}")
        sol = wkb_approximation(symbol, initial_phase, order=order, **kwargs)
        solutions[order] = sol

    # Plot comparison
    dim = solutions[0]['dimension']
    n_orders = max_order + 1

    if dim == 1:
        fig, axes = plt.subplots(n_orders, 1, figsize=(12, 3*n_orders))
        if n_orders == 1:
            axes = [axes]

        for order, ax in enumerate(axes):
            sol = solutions[order]
            x = sol['x']
            u = sol['u']

            ax.plot(x, np.real(u), 'b-', label='Re(u)', linewidth=2)
            ax.plot(x, np.imag(u), 'r--', label='Im(u)', linewidth=2)
            ax.plot(x, np.abs(u), 'g:', label='|u|', linewidth=2)

            ax.set_xlabel('x', fontsize=11)
            ax.set_ylabel('u', fontsize=11)
            ax.set_title(f'Order {order} (ε={sol["epsilon"]:.3f})', fontsize=12)
            ax.grid(True, alpha=0.3)
            ax.legend(loc='best')

            # Add text with max amplitude
            max_amp = np.max(np.abs(u))
            ax.text(0.02, 0.98, f'max|u| = {max_amp:.4f}',
                    transform=ax.transAxes, va='top',
                    bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    else:  # dim == 2
        fig, axes = plt.subplots(2, n_orders, figsize=(5*n_orders, 9))
        if n_orders == 1:
            axes = axes.reshape(2, 1)

        for order in range(n_orders):
            sol = solutions[order]
            X, Y = sol['x'], sol['y']
            u = sol['u']

            # Top row: |u|
            im1 = axes[0, order].contourf(X, Y, np.abs(u), levels=30, cmap='viridis')
            axes[0, order].set_title(f'Order {order}: |u|', fontsize=11)
            axes[0, order].set_xlabel('x')
            axes[0, order].set_ylabel('y')
            axes[0, order].set_aspect('equal')
            plt.colorbar(im1, ax=axes[0, order])

            # Bottom row: Re(u)
            im2 = axes[1, order].contourf(X, Y, np.real(u), levels=30, cmap='RdBu_r')
            axes[1, order].set_title(f'Order {order}: Re(u)', fontsize=11)
            axes[1, order].set_xlabel('x')
            axes[1, order].set_ylabel('y')
            axes[1, order].set_aspect('equal')
            plt.colorbar(im2, ax=axes[1, order])

            # Overlay some rays
            if 'rays' in sol:
                for ray in sol['rays'][::max(1, len(sol['rays'])//15)]:
                    axes[0, order].plot(ray['x'], ray['y'], 'k-', alpha=0.2, linewidth=0.5)
                    axes[1, order].plot(ray['x'], ray['y'], 'k-', alpha=0.2, linewidth=0.5)

    plt.tight_layout()

    # Print convergence info
    print("\n" + "="*60)
    print("CONVERGENCE ANALYSIS")
    print("="*60)

    if dim == 1:
        # Sample at center
        idx_center = len(solutions[0]['x']) // 2
        print(f"\nAt x = {solutions[0]['x'][idx_center]:.3f}:")
        for order in range(n_orders):
            u_val = solutions[order]['u'][idx_center]
            print(f" Order {order}: u = {u_val:.6f}, |u| = {np.abs(u_val):.6f}")
    else:
        # Sample at center
        nx, ny = solutions[0]['x'].shape
        idx_x, idx_y = nx//2, ny//2
        print(f"\nAt (x,y) = ({solutions[0]['x'][idx_x, idx_y]:.3f}, "
              f"{solutions[0]['y'][idx_x, idx_y]:.3f}):")
        for order in range(n_orders):
            u_val = solutions[order]['u'][idx_x, idx_y]
            print(f" Order {order}: u = {u_val:.6f}, |u| = {np.abs(u_val):.6f}")

    # Compute differences between consecutive orders
    print("\nRelative differences between orders:")
    for order in range(1, n_orders):
        u_prev = solutions[order-1]['u']
        u_curr = solutions[order]['u']

        # L2 relative difference
        diff = np.linalg.norm(u_curr - u_prev) / (np.linalg.norm(u_prev) + 1e-10)
        print(f" ||u_{order} - u_{order-1}|| / ||u_{order-1}|| = {diff:.6e}")

    return solutions, fig
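
# ------------------------------------------------------------------
# Sketch (hypothetical helper, illustrative only): the relative L2
# difference ||u_k - u_{k-1}|| / ||u_{k-1}|| that compare_orders prints.
# For a truncated series u_k = sum_{j<=k} eps**j * a_j, the k-th difference
# is the norm of the added term eps**k * a_k relative to u_{k-1}, so it
# shrinks with k only while the asymptotic series is still improving.
# ------------------------------------------------------------------
import numpy as np


def _sketch_order_differences(u_by_order, tiny=1e-10):
    """Relative L2 differences between consecutive entries of u_by_order."""
    diffs = []
    for k in range(1, len(u_by_order)):
        prev, curr = np.asarray(u_by_order[k - 1]), np.asarray(u_by_order[k])
        diffs.append(np.linalg.norm(curr - prev) / (np.linalg.norm(prev) + tiny))
    return diffs


# Usage: with eps = 0.1 and a_0 = a_1 = a_2 = 1, the partial sums 1, 1.1 and
# 1.11 give differences 0.1 and 0.01 / 1.1.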
def plot_phase_space(solution, time_slice=None):
    """
    Plot phase space (position-momentum) trajectories.

    Parameters
    ----------
    solution : dict
        Output from wkb_approximation
    time_slice : float or None
        Time at which to sample (None = final time)
    """
    import matplotlib.pyplot as plt

    dim = solution['dimension']
    rays = solution['rays']

    if time_slice is None:
        time_idx = -1
    else:
        time_idx = np.argmin(np.abs(rays[0]['t'] - time_slice))

    if dim == 1:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

        # Position-momentum plot
        for ray in rays:
            x = ray['x']
            xi = ray['xi']
            ax1.plot(x, xi, 'b-', alpha=0.5, linewidth=1)
            ax1.plot(x[0], xi[0], 'go', markersize=6)
            ax1.plot(x[time_idx], xi[time_idx], 'ro', markersize=4)

        ax1.set_xlabel('x (position)', fontsize=12)
        ax1.set_ylabel('ξ (momentum)', fontsize=12)
        ax1.set_title('Phase Space Trajectories', fontsize=13)
        ax1.grid(True, alpha=0.3)

        # Phase evolution
        for i, ray in enumerate(rays[::max(1, len(rays)//10)]):
            ax2.plot(ray['t'], ray['S'], alpha=0.7, label=f'Ray {i}')

        ax2.set_xlabel('t (time)', fontsize=12)
        ax2.set_ylabel('S (phase)', fontsize=12)
        ax2.set_title('Phase Evolution', fontsize=13)
        ax2.grid(True, alpha=0.3)
        ax2.legend(loc='best', fontsize=8)

    else:  # dim == 2
        fig = plt.figure(figsize=(16, 5))

        # 3D phase space (x, y, |p|)
        ax1 = fig.add_subplot(131, projection='3d')
        for ray in rays[::max(1, len(rays)//20)]:
            x = ray['x']
            y = ray['y']
            p_mag = np.sqrt(ray['xi']**2 + ray['eta']**2)
            ax1.plot(x, y, p_mag, alpha=0.6, linewidth=1)

        ax1.set_xlabel('x')
        ax1.set_ylabel('y')
        ax1.set_zlabel('|p|')
        ax1.set_title('Phase Space (x, y, |p|)')

        # Momentum plane (ξ, η)
        ax2 = fig.add_subplot(132)
        for ray in rays:
            xi = ray['xi']
            eta = ray['eta']
            ax2.plot(xi, eta, 'b-', alpha=0.4, linewidth=0.8)
            ax2.plot(xi[0], eta[0], 'go', markersize=4)
            ax2.plot(xi[time_idx], eta[time_idx], 'ro', markersize=3)

        ax2.set_xlabel('ξ', fontsize=12)
        ax2.set_ylabel('η', fontsize=12)
        ax2.set_title('Momentum Space', fontsize=13)
        ax2.set_aspect('equal')
        ax2.grid(True, alpha=0.3)

        # Phase evolution
        ax3 = fig.add_subplot(133)
        for i, ray in enumerate(rays[::max(1, len(rays)//15)]):
            ax3.plot(ray['t'], ray['S'], alpha=0.6)

        ax3.set_xlabel('t (time)', fontsize=12)
        ax3.set_ylabel('S (phase)', fontsize=12)
        ax3.set_title('Phase Evolution', fontsize=13)
        ax3.grid(True, alpha=0.3)

    plt.tight_layout()
    return fig
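
# ------------------------------------------------------------------
# Sketch (hypothetical helper, illustrative only): the curves drawn by
# plot_phase_space are bicharacteristics. Assuming Hamilton's equations in
# the form dx/dt = dP/dxi, dxi/dt = -dP/dx, for the Test 1 symbol
# P(x, xi) = xi**2 + x**2 they read dx/dt = 2 xi, dxi/dt = -2 x, whose exact
# solution is x(t) = x0 cos 2t + xi0 sin 2t. P is conserved along each ray,
# so every trajectory lies on a level set of P (a circle here).
# ------------------------------------------------------------------
import numpy as np
from scipy.integrate import solve_ivp


def _sketch_harmonic_ray(x0, xi0, t_end=5.0, n_t=201):
    """Integrate one ray of P = xi**2 + x**2; return t, x, xi and P along it."""
    def rhs(t, z):
        x, xi = z
        return [2.0 * xi, -2.0 * x]       # [dP/dxi, -dP/dx]

    t_eval = np.linspace(0.0, t_end, n_t)
    sol = solve_ivp(rhs, (0.0, t_end), [x0, xi0], t_eval=t_eval,
                    rtol=1e-10, atol=1e-12)
    x, xi = sol.y
    return sol.t, x, xi, xi**2 + x**2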
# ==================================================================
# ADVANCED: Amplitude decomposition
# ==================================================================
def plot_amplitude_decomposition(solution):
    """
    Plot individual amplitude orders aₖ and their contributions.
    """
    import matplotlib.pyplot as plt

    dim = solution['dimension']
    order = solution['order']
    eps = solution['epsilon']

    if dim == 1:
        # order + 2 >= 2 rows, so plt.subplots always returns an array of axes
        fig, axes = plt.subplots(order + 2, 1, figsize=(12, 3*(order+2)))
        axes = np.ravel(axes)

        x = solution['x']

        # Plot each amplitude order
        for k in range(order + 1):
            ak = solution['a'][k]
            weight = eps**k

            axes[k].plot(x, ak, 'b-', linewidth=2, label=f'$a_{k}$')
            axes[k].plot(x, weight * ak, 'r--', linewidth=2,
                         label=f'$\\varepsilon^{k} a_{k}$')
            axes[k].set_xlabel('x')
            axes[k].set_ylabel(f'$a_{k}$')
            axes[k].set_title(f'Amplitude order {k} (weight = ε^{k} = {weight:.4f})')
            axes[k].grid(True, alpha=0.3)
            axes[k].legend()

        # Plot total amplitude
        axes[order + 1].plot(x, np.real(solution['a_total']), 'b-',
                             linewidth=2, label='Re($a_{total}$)')
        axes[order + 1].plot(x, np.imag(solution['a_total']), 'r--',
                             linewidth=2, label='Im($a_{total}$)')
        axes[order + 1].plot(x, np.abs(solution['a_total']), 'g:',
                             linewidth=2, label='$|a_{total}|$')
        axes[order + 1].set_xlabel('x')
        axes[order + 1].set_ylabel('Total amplitude')
        axes[order + 1].set_title('Total Amplitude (sum of all orders)')
        axes[order + 1].grid(True, alpha=0.3)
        axes[order + 1].legend()

    else:  # dim == 2
        # A 2 x (order + 2) grid is always a 2D array of axes; no reshape needed
        fig, axes = plt.subplots(2, order + 2, figsize=(5*(order+2), 9))

        X, Y = solution['x'], solution['y']

        # Plot each amplitude order
        for k in range(order + 1):
            ak = solution['a'][k]
            weight = eps**k

            # Top: ak
            im1 = axes[0, k].contourf(X, Y, ak, levels=30, cmap='viridis')
            axes[0, k].set_title(f'$a_{k}$')
            axes[0, k].set_xlabel('x')
            axes[0, k].set_ylabel('y')
            axes[0, k].set_aspect('equal')
            plt.colorbar(im1, ax=axes[0, k])

            # Bottom: weighted
            im2 = axes[1, k].contourf(X, Y, weight * ak, levels=30, cmap='viridis')
            axes[1, k].set_title(f'$\\varepsilon^{k} a_{k}$ (ε={eps:.3f})')
            axes[1, k].set_xlabel('x')
            axes[1, k].set_ylabel('y')
            axes[1, k].set_aspect('equal')
            plt.colorbar(im2, ax=axes[1, k])

        # Plot total amplitude
        a_total_abs = np.abs(solution['a_total'])
        a_total_real = np.real(solution['a_total'])

        im3 = axes[0, order+1].contourf(X, Y, a_total_abs, levels=30, cmap='viridis')
        axes[0, order+1].set_title('$|a_{total}|$')
        axes[0, order+1].set_xlabel('x')
        axes[0, order+1].set_ylabel('y')
        axes[0, order+1].set_aspect('equal')
        plt.colorbar(im3, ax=axes[0, order+1])

        im4 = axes[1, order+1].contourf(X, Y, a_total_real, levels=30, cmap='RdBu_r')
        axes[1, order+1].set_title('Re($a_{total}$)')
        axes[1, order+1].set_xlabel('x')
        axes[1, order+1].set_ylabel('y')
        axes[1, order+1].set_aspect('equal')
        plt.colorbar(im4, ax=axes[1, order+1])

    plt.tight_layout()
    return fig
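
# ------------------------------------------------------------------
# Sketch (hypothetical helper, illustrative only): a_total is the truncated
# asymptotic sum a_0 + eps*a_1 + ... + eps**N * a_N, i.e. the sum of the
# weighted panels plotted above. The same sum evaluated with Horner's rule,
# assuming the dict keys are exactly 0..N (wkb_approximation accumulates the
# powers of eps in a loop instead; both agree up to rounding):
# ------------------------------------------------------------------
import numpy as np


def _sketch_total_amplitude(a_by_order, eps):
    """Return sum_k eps**k * a_by_order[k] for keys k = 0..N (Horner's rule)."""
    n_max = len(a_by_order) - 1
    total = np.zeros_like(np.asarray(a_by_order[n_max]), dtype=complex)
    for k in range(n_max, -1, -1):
        total = total * eps + a_by_order[k]
    return total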
# ==================================================================
# VISUALIZATION WITH CAUSTIC HIGHLIGHTING
# ==================================================================
def plot_with_caustics(solution, component='abs', highlight_caustics=True):
    """
    Plot WKB solution with caustics highlighted.

    Parameters
    ----------
    solution : dict
        Output of wkb_approximation()
    component : {'abs','real','imag','phase'}
        Which component of u to visualize.
    highlight_caustics : bool
        Whether to mark caustic locations.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    # ----------------------------------------------------------
    # Helper: select component to plot
    # ----------------------------------------------------------
    def _select_component(u, component):
        if component == 'real':
            return np.real(u), 'RdBu_r'
        elif component == 'imag':
            return np.imag(u), 'RdBu_r'
        elif component == 'abs':
            return np.abs(u), 'viridis'
        elif component == 'phase':
            return np.angle(u), 'twilight'
        else:
            raise ValueError("component must be one of: real, imag, abs, phase")

    # ----------------------------------------------------------
    # Helper: plot 1D caustics (single legend entry)
    # ----------------------------------------------------------
    def _plot_caustics_1d(ax, caustics, x, data):
        if not highlight_caustics or len(caustics) == 0:
            return

        added = False
        for c in caustics:
            xc = c.position[0]  # CausticEvent uses attribute access, not dict
            ax.axvline(xc, color='red', linestyle='--', linewidth=2, alpha=0.7,
                       label='Caustic' if not added else None)
            ax.plot(xc, data[np.argmin(np.abs(x - xc))], 'ro', markersize=8)
            added = True

    # ----------------------------------------------------------
    # Helper: plot 2D caustics (only two legend entries: A2, A3)
    # ----------------------------------------------------------
    def _plot_caustics_2d(ax, caustics):
        if not highlight_caustics or len(caustics) == 0:
            return

        type_seen = set()
        for c in caustics:
            x_c, y_c = c.position[0], c.position[1]
            t = getattr(c, 'arnold_type', 'A2')

            marker = 'o' if t == 'A2' else 's'
            color = 'red' if t == 'A2' else 'orange'

            label = None
            if t not in type_seen:
                label = f"{'Fold' if t == 'A2' else 'Cusp'} ({t})"
                type_seen.add(t)

            ax.plot(x_c, y_c, marker, color=color, markersize=10,
                    markeredgecolor='white', markeredgewidth=1.5, label=label)

    # ----------------------------------------------------------
    # Retrieve data & select component
    # ----------------------------------------------------------
    dim = solution['dimension']
    u = solution['u']
    caustics = solution.get('caustics', [])

    data, cmap = _select_component(u, component)

    # ----------------------------------------------------------
    # 1D plotting
    # ----------------------------------------------------------
    if dim == 1:
        fig, axes = plt.subplots(2 if 'u_standard' in solution else 1, 1, figsize=(12, 6))
        if not isinstance(axes, np.ndarray):
            axes = [axes]

        x = solution['x']

        # Main panel
        ax = axes[0]
        ax.plot(x, data, 'b-', linewidth=2, label=component)
        ax.set_xlabel('x')
        ax.set_ylabel(f'{component}(u)')
        ax.set_title(f'WKB with Caustic Corrections ({solution.get("caustic_correction", "none")})')
        ax.grid(True, alpha=0.3)

        # Caustics
        _plot_caustics_1d(ax, caustics, x, data)
        ax.legend()

        # Comparison panel
        if 'u_standard' in solution:
            ax2 = axes[1]
            data_std, _ = _select_component(solution['u_standard'], component)

            ax2.plot(x, data_std, 'r--', linewidth=2, alpha=0.7, label='Standard WKB')
            ax2.plot(x, data, 'b-', linewidth=2, label='Corrected')
            ax2.set_xlabel('x')
            ax2.set_ylabel(f'{component}(u)')
            ax2.set_title('Comparison: Standard vs Corrected')
            ax2.grid(True, alpha=0.3)

            # Caustics on the comparison plot
            _plot_caustics_1d(ax2, caustics, x, data)
            ax2.legend()

        plt.tight_layout()
        return fig

    # ----------------------------------------------------------
    # 2D plotting
    # ----------------------------------------------------------
    else:
        fig, axes = plt.subplots(1, 2 if 'u_standard' in solution else 1, figsize=(16, 6))
        if not isinstance(axes, np.ndarray):
            axes = [axes]

        X = solution['x']
        Y = solution['y']

        idx = 0

        # Standard WKB
        if 'u_standard' in solution:
            data_std, _ = _select_component(solution['u_standard'], component)
            im = axes[0].contourf(X, Y, data_std, 30, cmap=cmap)
            axes[0].set_title("Standard WKB")
            axes[0].set_aspect('equal')
            axes[0].set_xlabel('x')
            axes[0].set_ylabel('y')
            fig.colorbar(im, ax=axes[0])
            idx = 1

        # Corrected WKB
        im2 = axes[idx].contourf(X, Y, data, 30, cmap=cmap)
        axes[idx].set_title(f"With Caustic Corrections ({solution.get('caustic_correction', 'none')})")
        axes[idx].set_aspect('equal')
        axes[idx].set_xlabel('x')
        axes[idx].set_ylabel('y')
        fig.colorbar(im2, ax=axes[idx])

        # Caustics (single legend)
        _plot_caustics_2d(axes[idx], caustics)

        # Rays (optional)
        if 'rays' in solution:
            rays = solution['rays']
            step = max(1, len(rays)//20)
            for r in rays[::step]:
                axes[idx].plot(r['x'], r['y'], 'k-', linewidth=0.6, alpha=0.25)

        # Draw a legend only if a caustic marker carries a label
        if axes[idx].get_legend_handles_labels()[1]:
            axes[idx].legend(loc='upper right')

        plt.tight_layout()
        return fig
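
# ------------------------------------------------------------------
# Sketch (hypothetical helper, not the module's own detector): the caustic
# positions highlighted above come from monitoring det(J) along the rays,
# as described in the module overview. One simple way to locate the zeros
# of a sampled det(J) is to find strict sign changes and interpolate
# linearly inside each bracketing interval. Samples where det(J) is exactly
# zero are not reported by this sketch.
# ------------------------------------------------------------------
import numpy as np


def _sketch_detJ_zeros(t, detJ):
    """Linearly interpolated parameter values where detJ changes sign."""
    t = np.asarray(t, dtype=float)
    d = np.asarray(detJ, dtype=float)
    i = np.nonzero(d[:-1] * d[1:] < 0)[0]     # strict sign changes
    return t[i] - d[i] * (t[i + 1] - t[i]) / (d[i + 1] - d[i])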
def plot_caustic_analysis(solution):
    """
    Detailed analysis plot of caustics.
    """
    import matplotlib.pyplot as plt

    caustics = solution.get('caustics', [])

    if len(caustics) == 0:
        print("No caustics to analyze")
        return None

    dim = solution['dimension']

    if dim == 1:
        fig, axes = plt.subplots(3, 1, figsize=(12, 10))
        x = solution['x']

        # 1. Solution amplitude
        axes[0].plot(x, np.abs(solution['u']), 'b-', linewidth=2, label='|u| corrected')
        if 'u_standard' in solution:
            axes[0].plot(x, np.abs(solution['u_standard']), 'r--', linewidth=2,
                         alpha=0.7, label='|u| standard')

        for caustic in caustics:
            x_c = caustic.position[0]
            axes[0].axvline(x_c, color='red', linestyle=':', alpha=0.5)
            axes[0].text(x_c, axes[0].get_ylim()[1]*0.9, f"Caustic\n{caustic.arnold_type}",
                         ha='center', fontsize=9,
                         bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))

        axes[0].set_ylabel('|u|', fontsize=12)
        axes[0].set_title('Amplitude with Caustic Locations')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # 2. Phase
        axes[1].plot(x, solution['S'], 'g-', linewidth=2, label='Phase S')
        if 'maslov_phases' in solution:
            axes[1].plot(x, solution['maslov_phases'], 'orange', linewidth=2,
                         linestyle='--', label='Maslov correction')

        for caustic in caustics:
            axes[1].axvline(caustic.position[0], color='red', linestyle=':', alpha=0.5)

        axes[1].set_ylabel('Phase', fontsize=12)
        axes[1].set_title('Phase and Maslov Index')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        # 3. Error between standard and corrected
        if 'u_standard' in solution:
            error = np.abs(solution['u'] - solution['u_standard'])
            axes[2].plot(x, error, 'purple', linewidth=2)
            axes[2].set_xlabel('x', fontsize=12)
            axes[2].set_ylabel('|u_corrected - u_standard|', fontsize=12)
            axes[2].set_title('Correction Magnitude')
            axes[2].set_yscale('log')
            axes[2].grid(True, alpha=0.3)

            for caustic in caustics:
                axes[2].axvline(caustic.position[0], color='red', linestyle=':', alpha=0.5)

    else:  # 2D
        n_caustics = len(caustics)
        fig = plt.figure(figsize=(16, 10))
        gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

        X, Y = solution['x'], solution['y']

        # Main plot: solution with caustics
        ax_main = fig.add_subplot(gs[0:2, 0:2])
        im = ax_main.contourf(X, Y, np.abs(solution['u']), levels=30, cmap='viridis')
        plt.colorbar(im, ax=ax_main, label='|u|')

        # Plot rays
        if 'rays' in solution:
            for ray in solution['rays'][::max(1, len(solution['rays'])//30)]:
                ax_main.plot(ray['x'], ray['y'], 'k-', alpha=0.15, linewidth=0.5)

        # Mark caustics
        fold_caustics = []
        cusp_caustics = []

        for caustic in caustics:
            x_c, y_c = caustic.position[0], caustic.position[1]

            if caustic.arnold_type == 'A2':
                fold_caustics.append((x_c, y_c))
                ax_main.plot(x_c, y_c, 'ro', markersize=12,
                             markeredgecolor='white', markeredgewidth=2)
            else:
                cusp_caustics.append((x_c, y_c))
                ax_main.plot(x_c, y_c, 'ys', markersize=12,
                             markeredgecolor='white', markeredgewidth=2)

        ax_main.set_xlabel('x', fontsize=11)
        ax_main.set_ylabel('y', fontsize=11)
        ax_main.set_title(f'Solution with {n_caustics} Caustics', fontsize=13)
        ax_main.set_aspect('equal')

        # Legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='red', label=f'Fold (A2): {len(fold_caustics)}'),
            Patch(facecolor='yellow', label=f'Cusp (A3): {len(cusp_caustics)}')
        ]
        ax_main.legend(handles=legend_elements, loc='upper right')

        # Phase plot
        ax_phase = fig.add_subplot(gs[0, 2])
        im_phase = ax_phase.contourf(X, Y, solution['S'], levels=30, cmap='twilight')
        plt.colorbar(im_phase, ax=ax_phase, label='Phase S')
        ax_phase.set_title('Phase')
        ax_phase.set_aspect('equal')

        # Error plot
        if 'u_standard' in solution:
            ax_error = fig.add_subplot(gs[1, 2])
            error = np.abs(solution['u'] - solution['u_standard'])
            im_error = ax_error.contourf(X, Y, np.log10(error + 1e-10),
                                         levels=30, cmap='hot')
            plt.colorbar(im_error, ax=ax_error, label='log10(error)')
            ax_error.set_title('Correction Effect')
            ax_error.set_aspect('equal')

        # Caustic statistics
        ax_stats = fig.add_subplot(gs[2, :])
        ax_stats.axis('off')

        stats_text = "Caustic Statistics:\n"
        stats_text += f" Total caustics: {n_caustics}\n"
        stats_text += f" Fold caustics (A2): {len(fold_caustics)}\n"
        stats_text += f" Cusp caustics (A3): {len(cusp_caustics)}\n"
        stats_text += f" Correction method: {solution.get('caustic_correction', 'none')}\n"
        stats_text += f" Epsilon: {solution['epsilon']:.4f}\n"

        if 'u_standard' in solution:
            max_error = np.max(np.abs(solution['u'] - solution['u_standard']))
            mean_error = np.mean(np.abs(solution['u'] - solution['u_standard']))
            stats_text += f" Max correction: {max_error:.4e}\n"
            stats_text += f" Mean correction: {mean_error:.4e}\n"

        ax_stats.text(0.1, 0.5, stats_text, fontsize=11, verticalalignment='center',
                      fontfamily='monospace',
                      bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    return fig
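
# ------------------------------------------------------------------
# Sketch (hypothetical helper, illustrative only): the 'maslov_phases' curve
# shown in plot_caustic_analysis is built in _apply_1d_caustic_corrections
# by adding pi/2 at every grid point past each caustic (x > x_c). An
# equivalent vectorised form counts, for each x, the caustics strictly to
# its left with np.searchsorted.
# ------------------------------------------------------------------
import numpy as np


def _sketch_maslov_phases(x, caustic_positions):
    """pi/2 times the number of caustic positions x_c with x_c < x."""
    xc = np.sort(np.asarray(caustic_positions, dtype=float))
    n_crossed = np.searchsorted(xc, np.asarray(x, dtype=float), side='left')
    return n_crossed * (np.pi / 2)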
def create_initial_data_line(x_range, n_points=20, direction=(1, 0), y_intercept=0.0):
    """
    Create initial data for WKB on a line segment.

    Parameters
    ----------
    x_range : tuple
        Range (x_min, x_max) for the line segment.
    n_points : int
        Number of points on the line.
    direction : tuple
        Direction of rays (ξ₀, η₀).
    y_intercept : float
        y-coordinate of the line.

    Returns
    -------
    dict
        Initial data for wkb_approximation.

    Examples
    --------
    >>> # Horizontal line with rays going upward
    >>> ic = create_initial_data_line((-1, 1), n_points=20,
    ...                               direction=(0, 1), y_intercept=0)
    """
    x_init = np.linspace(x_range[0], x_range[1], n_points)
    y_init = np.full(n_points, y_intercept)
    S_init = np.zeros(n_points)

    # Normalize direction to unit length (|p| = 1)
    dir_norm = np.sqrt(direction[0]**2 + direction[1]**2)
    px_init = np.full(n_points, direction[0] / dir_norm)
    py_init = np.full(n_points, direction[1] / dir_norm)

    return {
        'x': x_init,
        'y': y_init,
        'S': S_init,
        'p_x': px_init,
        'p_y': py_init
    }
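
# ------------------------------------------------------------------
# Sketch (hypothetical helper, illustrative only): create_initial_data_line
# normalises (xi0, eta0) to unit length, which satisfies the eikonal
# condition P(x, grad S) = 0 for a symbol such as xi**2 + eta**2 - 1 but not
# for xi**2 + eta**2 - c**2 with c != 1 (the form of the Test 2 symbol in
# _run_tests). Assuming a constant speed c, the initial momenta can be
# rescaled onto |p| = c:
# ------------------------------------------------------------------
import numpy as np


def _sketch_scale_to_eikonal(initial_data, c):
    """Return a copy of initial_data with (p_x, p_y) rescaled to length c."""
    px = np.asarray(initial_data['p_x'], dtype=float)
    py = np.asarray(initial_data['p_y'], dtype=float)
    norm = np.hypot(px, py)
    scaled = dict(initial_data)
    scaled['p_x'] = c * px / norm
    scaled['p_y'] = c * py / norm
    return scaled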
def create_initial_data_circle(radius=1.0, n_points=30, outward=True):
    """
    Create initial data for WKB on a circle.

    Parameters
    ----------
    radius : float
        Radius of the circle.
    n_points : int
        Number of points on the circle.
    outward : bool
        If True, rays point outward; if False, inward.

    Returns
    -------
    dict
        Initial data for wkb_approximation.

    Examples
    --------
    >>> # Circle with outward rays
    >>> ic = create_initial_data_circle(radius=1.0, n_points=30, outward=True)
    """
    theta = np.linspace(0, 2*np.pi, n_points, endpoint=False)
    x_init = radius * np.cos(theta)
    y_init = radius * np.sin(theta)
    S_init = np.zeros(n_points)

    # Rays perpendicular to circle
    if outward:
        px_init = np.cos(theta)
        py_init = np.sin(theta)
    else:
        px_init = -np.cos(theta)
        py_init = -np.sin(theta)

    return {
        'x': x_init,
        'y': y_init,
        'S': S_init,
        'p_x': px_init,
        'p_y': py_init
    }
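
# ------------------------------------------------------------------
# Sketch (hypothetical helper, illustrative only): with S = 0 on a circle
# of radius R and outward unit normals, as produced by
# create_initial_data_circle(outward=True), the symbol xi**2 + eta**2 - 1
# has the exact outgoing phase S(x, y) = sqrt(x**2 + y**2) - R. Along each
# straight ray parametrised by arclength s, S equals s, and |grad S| = 1
# away from the origin. Both facts are checked numerically below.
# ------------------------------------------------------------------
import numpy as np


def _sketch_circle_phase_check(R=1.0, n_rays=30, s=0.75):
    theta = np.linspace(0, 2*np.pi, n_rays, endpoint=False)
    px, py = np.cos(theta), np.sin(theta)            # outward unit normals
    xs, ys = R*px + s*px, R*py + s*py                # points at arclength s
    S_on_rays = np.hypot(xs, ys) - R                 # exact phase S = r - R

    g = np.linspace(1.5, 2.5, 101)                   # patch away from r = 0
    X, Y = np.meshgrid(g, g, indexing='ij')
    S = np.hypot(X, Y) - R
    dS_dx, dS_dy = np.gradient(S, g, g, edge_order=2)
    eikonal_residual = np.max(np.abs(dS_dx**2 + dS_dy**2 - 1.0))
    return S_on_rays, eikonal_residual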
def create_initial_data_point_source(x0=0.0, y0=0.0, n_rays=20):
    """
    Create initial data for WKB from a point source.

    Parameters
    ----------
    x0, y0 : float
        Source location.
    n_rays : int
        Number of rays emanating from source.

    Returns
    -------
    dict
        Initial data for wkb_approximation.

    Examples
    --------
    >>> # Point source at origin
    >>> ic = create_initial_data_point_source(0, 0, n_rays=24)
    """
    theta = np.linspace(0, 2*np.pi, n_rays, endpoint=False)
    x_init = np.full(n_rays, x0)
    y_init = np.full(n_rays, y0)
    S_init = np.zeros(n_rays)

    # Rays in all directions
    px_init = np.cos(theta)
    py_init = np.sin(theta)

    return {
        'x': x_init,
        'y': y_init,
        'S': S_init,
        'p_x': px_init,
        'p_y': py_init
    }
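
# ------------------------------------------------------------------
# Sketch (hypothetical helper, illustrative only): rays from
# create_initial_data_point_source are straight lines
# x = x0 + t cos(theta), y = y0 + t sin(theta) for the symbol
# xi**2 + eta**2 - 1 (t = arclength). The ray-tube Jacobian
# J = det d(x, y)/d(t, theta) then equals t, and for this constant-
# coefficient symbol the leading-order transport equation gives
# a_0 proportional to |J|**(-1/2), i.e. the 1/sqrt(r) decay of 2D
# cylindrical waves.
# ------------------------------------------------------------------
import numpy as np


def _sketch_point_source_spreading(t=None, n_rays=24):
    """Return arclengths T, Jacobians J and amplitudes a_0 normalised at t[0]."""
    if t is None:
        t = np.linspace(0.1, 2.0, 20)
    theta = np.linspace(0, 2*np.pi, n_rays, endpoint=False)
    T, TH = np.meshgrid(t, theta, indexing='ij')
    dx_dt, dy_dt = np.cos(TH), np.sin(TH)              # d(x, y)/dt
    dx_dth, dy_dth = -T*np.sin(TH), T*np.cos(TH)       # d(x, y)/dtheta
    J = dx_dt*dy_dth - dx_dth*dy_dt                    # t cos^2 + t sin^2 = t
    a0 = np.sqrt(np.abs(J[0]) / np.abs(J))             # a_0 ~ |J|**(-1/2), a_0(t[0]) = 1
    return T, J, a0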
def visualize_wkb_rays(wkb_result, plot_type='phase', n_rays_plot=None):
    """
    Visualize a 2D WKB solution with rays.

    Parameters
    ----------
    wkb_result : dict
        Output from wkb_approximation (2D results only).
    plot_type : str
        What to visualize: 'phase', 'amplitude', 'real', 'rays'.
    n_rays_plot : int, optional
        Number of rays to plot (if None, at most 20 are plotted).

    Returns
    -------
    matplotlib.figure.Figure

    Examples
    --------
    >>> wkb = wkb_approximation(...)
    >>> visualize_wkb_rays(wkb, plot_type='phase')
    """
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(10, 8))

    X = wkb_result['x']
    Y = wkb_result['y']

    if plot_type == 'phase':
        # Plot phase
        S = wkb_result['S']
        im = ax.contourf(X, Y, S, levels=30, cmap='twilight')
        plt.colorbar(im, ax=ax, label='Phase S(x,y)')
        ax.set_title('WKB Phase Function')

    elif plot_type == 'amplitude':
        # Plot the modulus of the total amplitude; wkb_result['a'] is a dict
        # of per-order arrays, not a single array
        a = np.abs(wkb_result['a_total'])
        im = ax.contourf(X, Y, a, levels=30, cmap='viridis')
        plt.colorbar(im, ax=ax, label='|a_total(x,y)|')
        ax.set_title('WKB Amplitude')

    elif plot_type == 'real':
        # Plot real part
        u = wkb_result['u']
        im = ax.contourf(X, Y, np.real(u), levels=30, cmap='RdBu')
        plt.colorbar(im, ax=ax, label='Re(u)')
        ax.set_title('WKB Solution - Real Part')

    elif plot_type == 'rays':
        # Plot phase contours with rays
        S = wkb_result['S']
        ax.contour(X, Y, S, levels=20, colors='gray', alpha=0.3)
        ax.set_title('WKB Rays')

    # Overlay rays
    if 'rays' in wkb_result and plot_type in ['phase', 'amplitude', 'rays']:
        rays = wkb_result['rays']
        n_total = len(rays)

        if n_rays_plot is None:
            n_rays_plot = min(n_total, 20)  # Limit for clarity

        # Select evenly spaced rays
        ray_indices = np.linspace(0, n_total-1, n_rays_plot, dtype=int)

        for idx in ray_indices:
            ray = rays[idx]
            ax.plot(ray['x'], ray['y'], 'r-', alpha=0.5, linewidth=1)

            # Mark start
            ax.plot(ray['x'][0], ray['y'][0], 'go', markersize=4)

    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_aspect('equal')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    return fig
def _run_tests(verbose=True):
    """
    Run a series of tests for the WKB module, generating plots for verification.

    If verbose, also prints per-test diagnostics. All figures are shown at
    the end in every case.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from sympy import symbols

    print("\n" + "="*60)
    print(" wkb.py: self-test with visualisation")
    print("="*60)

    # ------------------------------------------------------------------
    # Test 1: 1D harmonic oscillator (no caustics)
    # ------------------------------------------------------------------
    print("\nTest 1: 1D harmonic oscillator (no caustics expected)")
    x, xi = symbols('x xi', real=True)
    symbol_1d = xi**2 + x**2

    n_rays = 30
    x0 = np.linspace(-2, 2, n_rays)
    initial_1d = {
        'x': x0,
        'p_x': np.ones(n_rays),
        'S': 0.5 * x0**2,
        'a': np.exp(-x0**2)
    }

    result_1d = wkb_approximation(
        symbol_1d, initial_1d,
        order=2,
        epsilon=0.1,
        resolution=200,
        domain=(-3, 3)
    )

    if verbose:
        print(f" → {result_1d['n_rays']} rays traced")
        print(f" → max |u| = {np.max(np.abs(result_1d['u'])):.6f}")

    # Plot 1D result
    fig1, ax1 = plt.subplots(figsize=(10, 4))
    x_plot = result_1d['x']
    ax1.plot(x_plot, np.real(result_1d['u']), 'b-', label='Re(u)')
    ax1.plot(x_plot, np.imag(result_1d['u']), 'r--', label='Im(u)')
    ax1.plot(x_plot, np.abs(result_1d['u']), 'g:', label='|u|')
    ax1.set_xlabel('x')
    ax1.set_title('Test 1: 1D harmonic oscillator (order 2, ε=0.1)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # ------------------------------------------------------------------
    # Test 2: 2D point source (circular waves)
    # ------------------------------------------------------------------
    print("\nTest 2: 2D point source (circular waves)")
    x, y, xi, eta = symbols('x y xi eta', real=True)
    c = 1.0
    symbol_2d = xi**2 + eta**2 - c**2

    n_rays = 36
    theta = np.linspace(0, 2*np.pi, n_rays, endpoint=False)
    r0 = 0.5
    initial_2d = {
        'x': r0 * np.cos(theta),
        'y': r0 * np.sin(theta),
        'p_x': np.cos(theta),
        'p_y': np.sin(theta),
        'S': np.zeros(n_rays),
        'a': np.ones(n_rays)
    }

    result_2d = wkb_approximation(
        symbol_2d, initial_2d,
        order=1,
        epsilon=0.2,
        resolution=(80, 80),
        domain=((-2, 2), (-2, 2))
    )

    if verbose:
        print(f" → {result_2d['n_rays']} rays traced")
        print(f" → max |u| = {np.max(np.abs(result_2d['u'])):.6f}")

    fig2, axes = plt.subplots(1, 2, figsize=(12, 5))
    X, Y = result_2d['x'], result_2d['y']
    im1 = axes[0].pcolormesh(X, Y, np.abs(result_2d['u']), shading='auto', cmap='viridis')
    axes[0].set_title('Test 2: |u| (point source)')
    axes[0].set_aspect('equal')
    plt.colorbar(im1, ax=axes[0])
    # max(1, ...) avoids a zero slice step when fewer than 10 rays are traced
    for ray in result_2d['rays'][::max(1, len(result_2d['rays'])//10)]:
        axes[0].plot(ray['x'], ray['y'], 'w-', alpha=0.3, linewidth=0.7)

    im2 = axes[1].pcolormesh(X, Y, np.real(result_2d['u']), shading='auto', cmap='RdBu_r')
    axes[1].set_title('Test 2: Re(u)')
    axes[1].set_aspect('equal')
    plt.colorbar(im2, ax=axes[1])

    # ------------------------------------------------------------------
    # Test 3: 1D with a caustic (fold) – use J integration
    # ------------------------------------------------------------------
    print("\nTest 3: 1D with a caustic (fold) – Airy pattern expected")
    symbol_caustic = xi**2 - x

    x0_c = np.linspace(-1, 1, 25)
    initial_caustic = {
        'x': x0_c,
        'p_x': np.ones_like(x0_c),
        'S': np.zeros_like(x0_c),
        'a': np.ones_like(x0_c)
    }

    result_no_corr = wkb_approximation(
        symbol_caustic, initial_caustic,
        order=1,
        epsilon=0.1,
        resolution=300,
        domain=(-1.5, 1.5),
        caustic_correction='none'
    )

    result_corr = wkb_approximation(
        symbol_caustic, initial_caustic,
        order=1,
        epsilon=0.1,
        resolution=300,
        domain=(-1.5, 1.5),
        caustic_correction='auto',
        caustic_threshold=1e-2
    )

    if verbose:
        n_c = len(result_corr.get('caustics', []))
        print(f" → detected {n_c} caustics (note: 0 if J is not integrated)")

    fig3, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
    x_plot = result_no_corr['x']
    axes[0].plot(x_plot, np.abs(result_no_corr['u']), 'b-', label='standard WKB')
    axes[0].plot(x_plot, np.abs(result_corr['u']), 'r--', label='with caustic correction')
    axes[0].set_ylabel('|u|')
    axes[0].set_title('Test 3: 1D caustic (fold) – amplitude')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(x_plot, np.real(result_no_corr['u']), 'b-', label='standard')
    axes[1].plot(x_plot, np.real(result_corr['u']), 'r--', label='corrected')
    axes[1].set_xlabel('x')
    axes[1].set_ylabel('Re(u)')
    axes[1].set_title('Real part')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    if result_corr.get('caustics'):
        x_c = result_corr['caustics'][0].position[0]
        axes[0].axvline(x_c, color='k', linestyle=':', alpha=0.7)
        axes[1].axvline(x_c, color='k', linestyle=':', alpha=0.7)

    # ------------------------------------------------------------------
    # Test 4: Comparison of orders for 1D oscillator
    # ------------------------------------------------------------------
    print("\nTest 4: Compare WKB orders (0,1,2) for 1D oscillator")
    solutions, fig4 = compare_orders(
        symbol_1d, initial_1d,
        max_order=2,
        epsilon=0.1,
        resolution=200,
        domain=(-3, 3)
    )

    # ------------------------------------------------------------------
    # Test 5: Amplitude decomposition for 2D point source
    # ------------------------------------------------------------------
    print("\nTest 5: Amplitude decomposition for 2D point source")
    result_2d_high = wkb_approximation(
        symbol_2d, initial_2d,
        order=2,
        epsilon=0.2,
        resolution=(60, 60),
        domain=((-2, 2), (-2, 2))
    )
    fig5 = plot_amplitude_decomposition(result_2d_high)

    # ------------------------------------------------------------------
    # Test 6: Phase space plot
    # ------------------------------------------------------------------
    print("\nTest 6: Phase space plot")
    fig6 = plot_phase_space(result_1d)

    # ------------------------------------------------------------------
    # Test 7: Caustic analysis plot (if detected)
    # ------------------------------------------------------------------
    if result_corr.get('caustics'):
        print("\nTest 7: Caustic analysis plot")
        fig7 = plot_caustic_analysis(result_corr)
    else:
        print("\nTest 7: Skipped (no caustic in test 3)")

    plt.show()

    print("\n" + "="*60)
    print(" All tests completed. Close the figure windows to exit.")
    print("="*60)

    return {
        'test1_1d_harmonic': result_1d,
        'test2_2d_point': result_2d,
        'test3_caustic_1d': (result_no_corr, result_corr),
    }


if __name__ == "__main__":
    _run_tests(verbose=True)