Source code for qrisp.algorithms.gqsp.helper_functions

"""
********************************************************************************
* Copyright (c) 2026 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""

import jax
from jax import Array
import jax.numpy as jnp
from typing import Literal, TYPE_CHECKING

if TYPE_CHECKING:
    from jax.typing import ArrayLike


# numpy.polynomial.chebyshev.poly2cheb
# To be deprecated when available in jax.numpy
@jax.jit
def poly2cheb(poly: "ArrayLike") -> Array:
    """
    Convert a polynomial from monomial to Chebyshev basis.

    JAX version of `numpy.polynomial.chebyshev.poly2cheb <https://numpy.org/doc/2.3/reference/generated/numpy.polynomial.chebyshev.poly2cheb.html>`_.

    Convert an array representing the coefficients of a polynomial (relative
    to the monomial basis) ordered from lowest degree to highest, to an array
    of the coefficients of the equivalent Chebyshev series, ordered from
    lowest to highest degree.

    Parameters
    ----------
    poly : ArrayLike
        1-D array containing the polynomial coefficients, ordered from lowest
        order term to highest. Any array-like (JAX/NumPy array, list, tuple)
        is accepted.

    Returns
    -------
    cheb : Array
        1-D array containing the coefficients of the equivalent Chebyshev
        series ordered from lowest order term to highest.

    Examples
    --------

    >>> import jax.numpy as jnp
    >>> from qrisp.gqsp import poly2cheb
    >>> poly = jnp.array([-2., -8., 4., 12.])
    >>> cheb = poly2cheb(poly)
    >>> print(cheb)
    [0. 1. 2. 3.]

    """
    # Normalize the input so that ``.dtype`` exists even for lists/tuples.
    poly = jnp.asarray(poly)
    N = len(poly)

    # Row n of C holds the monomial-basis coefficients of T_n(x), hence
    # poly = C.T @ cheb. N is a static shape under jit, so the Python loop
    # below is unrolled at trace time.
    C = jnp.zeros((N, N), dtype=poly.dtype)
    C = C.at[0, 0].set(1)  # T_0(x) = 1
    if N > 1:
        C = C.at[1, 1].set(1)  # T_1(x) = x

    # Recurrence T_n(x) = 2 * x * T_{n-1}(x) - T_{n-2}(x).
    for n in range(2, N):
        # Multiplying by x moves each coefficient one degree up; the last
        # entry of row n-1 is dropped, which is exactly the wrap-around term
        # a cyclic shift would otherwise introduce at index 0.
        row = (-C[n - 2]).at[1:].add(2 * C[n - 1, :-1])
        C = C.at[n].set(row)

    # C is lower triangular (deg T_n = n), so C.T is upper triangular with a
    # non-zero diagonal and the system has a unique solution.
    cheb = jnp.linalg.solve(C.T, poly)
    return cheb
# numpy.polynomial.chebyshev.cheb2poly
# To be deprecated when available in jax.numpy
@jax.jit
def cheb2poly(cheb: "ArrayLike") -> Array:
    """
    Convert a polynomial from Chebyshev to monomial basis.

    JAX version of `numpy.polynomial.chebyshev.cheb2poly <https://numpy.org/doc/stable/reference/generated/numpy.polynomial.chebyshev.cheb2poly.html>`_.

    Convert an array representing the coefficients of a Chebyshev series,
    ordered from lowest degree to highest, to an array of the coefficients
    of the equivalent polynomial (relative to the monomial basis) ordered
    from lowest to highest degree.

    Parameters
    ----------
    cheb : ArrayLike
        1-D array containing the Chebyshev series coefficients, ordered from
        lowest order term to highest. Any array-like (JAX/NumPy array, list,
        tuple) is accepted.

    Returns
    -------
    poly : Array
        1-D array containing the coefficients of the equivalent polynomial
        (relative to the monomial basis), ordered from lowest order term to
        highest.

    Examples
    --------

    >>> import jax.numpy as jnp
    >>> from qrisp.gqsp import cheb2poly
    >>> cheb = jnp.array([0., 1., 2., 3.])
    >>> poly = cheb2poly(cheb)
    >>> print(poly)
    [-2. -8.  4. 12.]

    """
    # Normalize the input so that ``.dtype`` exists even for lists/tuples.
    cheb = jnp.asarray(cheb)
    N = len(cheb)

    # Row n of C holds the monomial-basis coefficients of T_n(x), hence
    # poly = cheb @ C (= C.T @ cheb). N is a static shape under jit, so the
    # Python loop below is unrolled at trace time.
    C = jnp.zeros((N, N), dtype=cheb.dtype)
    C = C.at[0, 0].set(1)  # T_0(x) = 1
    if N > 1:
        C = C.at[1, 1].set(1)  # T_1(x) = x

    # Recurrence T_n(x) = 2 * x * T_{n-1}(x) - T_{n-2}(x).
    for n in range(2, N):
        # Multiplying by x moves each coefficient one degree up; the last
        # entry of row n-1 is dropped, which is exactly the wrap-around term
        # a cyclic shift would otherwise introduce at index 0.
        row = (-C[n - 2]).at[1:].add(2 * C[n - 1, :-1])
        C = C.at[n].set(row)

    # Linear combination of the rows of C weighted by the Chebyshev
    # coefficients gives the monomial coefficients.
    poly = jnp.dot(cheb, C)
    return poly
def _rescale_poly(
    alpha: "ArrayLike",
    p: "ArrayLike",
    kind: Literal["Polynomial", "Chebyshev"] = "Polynomial",
) -> "ArrayLike":
    r"""
    Returns a new polynomial $\tilde{p}$ such that $\tilde{p}(z) = p(\alpha z)$,
    or equivalently $\tilde{p}(z/\alpha) = p(z)$.

    In the monomial basis this multiplies the coefficient of $z^k$ by
    $\alpha^k$. It is used to account for the scaling factor ``alpha`` of a
    block-encoding.

    Parameters
    ----------
    alpha : ArrayLike
        Scalar scaling factor.
    p : ArrayLike
        1-D array containing the polynomial coefficients, ordered from lowest
        order term to highest.
    kind : {"Polynomial", "Chebyshev"}
        The kind of ``p``. The default is ``"Polynomial"``. Any value other
        than ``"Chebyshev"`` is treated as the monomial basis.

    Returns
    -------
    ArrayLike
        1-D array containing the (rescaled) polynomial coefficients, in the
        same basis as the input, ordered from lowest order term to highest.

    """
    # Normalize so that plain lists/tuples support the arithmetic below.
    p = jnp.asarray(p)
    is_chebyshev = kind == "Chebyshev"

    # The per-degree scaling is only a simple elementwise product in the
    # monomial basis, so Chebyshev input takes a round trip through it.
    if is_chebyshev:
        p = cheb2poly(p)

    # Coefficient of z^k is multiplied by alpha^k.
    p = p * jnp.power(alpha, jnp.arange(len(p)))

    if is_chebyshev:
        p = poly2cheb(p)

    return p