
utils

This module contains miscellaneous utilities used in the various steps of isotropic error generation.

isotropic.utils

bisection

This module contains functions for the bisection algorithm used to calculate the inverse $F^{-1}$.

get_theta(F, a, b, x, eps)

Finds the value of theta such that $F(\theta) = x$ using the bisection method. This function assumes that $F$ is an increasing function in the interval $[a, b]$ and that $F(a) \leq x \leq F(b)$.

The bisection method is a root-finding method that repeatedly bisects an interval and then selects a subinterval in which a root exists.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `F` | `Callable` | Function for which to compute the inverse. | required |
| `a` | `float` | Lower bound of the interval. | required |
| `b` | `float` | Upper bound of the interval. | required |
| `x` | `float \| ArrayLike` | Value for which to find the inverse. | required |
| `eps` | `float` | Tolerance for convergence. | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | The value of $\theta$ such that $F(\theta) = x$. |

Source code in src/isotropic/utils/bisection.py
def get_theta(
    F: Callable, a: float, b: float, x: float | ArrayLike, eps: float
) -> float:
    """
    Finds the value of theta such that $F(\\theta) = x$ using the bisection method.
    This function assumes that $F$ is an increasing function in the interval $[a, b]$
    and that $F(a) \\leq x \\leq F(b)$.

    The bisection method is a root-finding method that repeatedly bisects an interval
    and then selects a subinterval in which a root exists.

    Parameters
    ----------
    F : Callable
        Function for which to compute the inverse.
    a : float
        Lower bound of the interval.
    b : float
        Upper bound of the interval.
    x : float | ArrayLike
        Value for which to find the inverse.
    eps : float
        Tolerance for convergence.

    Returns
    -------
    float
        The value of $\\theta$ such that $F(\\theta) = x$.
    """
    while b - a > eps:
        c = (a + b) / 2.0
        Fc = F(c)
        if Fc <= x:
            a = c
        else:
            b = c
    return (a + b) / 2.0
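
As a usage illustration, the sketch below repeats the bisection loop as a local, scalar-only copy (so it runs without the package installed) and inverts a simple increasing function. The function `F` used here, $F(\theta) = (1 - \cos\theta)/2$ on $[0, \pi]$, is an assumption chosen for the example and is not part of the library.

```python
import math
from typing import Callable


def get_theta(
    F: Callable[[float], float], a: float, b: float, x: float, eps: float
) -> float:
    """Local copy of the bisection loop shown above, for scalar x."""
    while b - a > eps:
        c = (a + b) / 2.0
        if F(c) <= x:
            a = c  # F(c) is at or below the target, so the root lies in [c, b]
        else:
            b = c  # F(c) overshoots the target, so the root lies in [a, c]
    return (a + b) / 2.0


def F(theta: float) -> float:
    """Illustrative increasing function (an assumption): F(0) = 0, F(pi) = 1."""
    return (1.0 - math.cos(theta)) / 2.0


# F(pi / 2) = 0.5, so the inverse of 0.5 should be close to pi / 2
theta = get_theta(F, 0.0, math.pi, x=0.5, eps=1e-9)
```

Each iteration halves the bracketing interval, so the number of evaluations of `F` grows only logarithmically as `eps` shrinks.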

distribution

This module contains functions for the relevant probability distributions.

double_factorial_jax(n)

Helper function to compute double factorial:

n!! = n * (n-2) * (n-4) * ... * 1 (if n is odd) or 2 (if n is even).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `n` | `int` | The integer for which to compute the double factorial. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | The value of the double factorial n!! |

Source code in src/isotropic/utils/distribution.py
def double_factorial_jax(n: int) -> Array:
    """
    Helper function to compute double factorial:

        n!! = n * (n-2) * (n-4) * ... * 1 (if n is odd) or 2 (if n is even).

    Parameters
    ----------
    n : int
        The integer for which to compute the double factorial.

    Returns
    -------
    Array
        The value of the double factorial n!!
    """
    # works for numbers as large as 9**6
    return jnp.where(n <= 0, 1, jnp.prod(jnp.arange(n, 0, -2, dtype=jnp.uint64)))
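
For checking values by hand, a plain-Python reference can be useful. The helper below is a hypothetical stand-in written for this example (it is not part of the package); it follows the same convention as the listing above of returning 1 for `n <= 0`, and it relies on Python's arbitrary-precision integers rather than a fixed-width dtype.

```python
def double_factorial(n: int) -> int:
    """Reference n!! using Python integers; returns 1 for n <= 0."""
    result = 1
    for k in range(n, 0, -2):  # n, n-2, n-4, ... down to 1 (odd n) or 2 (even n)
        result *= k
    return result


odd = double_factorial(7)   # 7 * 5 * 3 * 1
even = double_factorial(8)  # 8 * 6 * 4 * 2
```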

double_factorial_ratio_jax(num, den)

Computes the ratio of double factorials:

num!! / den!!

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num` | `int` | The numerator for the double factorial. | required |
| `den` | `int` | The denominator for the double factorial. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | The value of the ratio num!! / den!! |

Notes

For very large numbers, this is numerically stable only when |num - den| is small (about 5 or less). The implementation raises a ValueError when |num - den| > 4.

Source code in src/isotropic/utils/distribution.py
def double_factorial_ratio_jax(num: int, den: int) -> Array:
    """
    Computes the ratio of double factorials:

        num!! / den!!

    Parameters
    ----------
    num : int
        The numerator for the double factorial.
    den : int
        The denominator for the double factorial.

    Returns
    -------
    Array
        The value of the ratio num!! / den!!

    Notes
    -----
    For very large numbers, this is numerically stable only when |num - den| is
    small (about 5 or less). A ValueError is raised when |num - den| > 4.
    """
    warnings.warn(
        "This is an experimental implementation. There are known issues with using this for numbers larger than 2**8",
        UserWarning,
    )
    if abs(num - den) > 4:
        raise ValueError("num and den should be close to each other")
    num_elems = jnp.arange(num, 0, -2, dtype=jnp.uint64)
    den_elems = jnp.arange(den, 0, -2, dtype=jnp.uint64)

    len_diff = den_elems.shape[0] - num_elems.shape[0]

    # Ensure both num_elems and den_elems have the same length
    if len_diff > 0:
        num_elems = jnp.concatenate((num_elems, jnp.ones(len_diff, dtype=jnp.uint64)))
    else:
        den_elems = jnp.concatenate((den_elems, jnp.ones(-len_diff, dtype=jnp.uint64)))

    num_len = num_elems.shape[0]
    den_len = den_elems.shape[0]

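    # NOTE: only num_len // 2 symmetric pairs are formed below, so when num_len
    # is odd the middle factor is left out of the product.
    ratio_elems = jnp.zeros(num_len // 2)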

    for k in jnp.arange(0, num_len // 2, 1):
        ratio_elems = ratio_elems.at[k].set(
            (num_elems[k] * num_elems[num_len - 1 - k])
            / (den_elems[k] * den_elems[den_len - 1 - k])
        )
    ratio = jnp.prod(ratio_elems)
    return ratio
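
The idea behind computing the ratio factor by factor, rather than dividing two very large double factorials, can be sketched in plain Python. This is a simplified variant written for illustration, not the library's implementation: it pairs the k-th factor of the numerator with the k-th factor of the denominator (the listing above pairs first and last elements instead), and the function name is hypothetical.

```python
def double_factorial_ratio(num: int, den: int) -> float:
    """num!! / den!! accumulated as a running product of paired factors."""
    num_terms = list(range(num, 0, -2))
    den_terms = list(range(den, 0, -2))
    # Pad the shorter list with 1s so the factors can be paired one-to-one
    length = max(len(num_terms), len(den_terms))
    num_terms += [1] * (length - len(num_terms))
    den_terms += [1] * (length - len(den_terms))
    ratio = 1.0
    for p, q in zip(num_terms, den_terms):
        ratio *= p / q  # individual factors stay moderate when num and den are close
    return ratio


small = double_factorial_ratio(9, 8)       # 945 / 384
large = double_factorial_ratio(1001, 999)  # telescopes to 1001
```

Because no intermediate value ever approaches the size of `num!!` itself, the running product stays within floating-point range even when the individual double factorials would not.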

double_factorial_ratio_scipy(num, den)

Compute the ratio of double factorials num!! / den!!.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `num` | `int` | The numerator double factorial. | required |
| `den` | `int` | The denominator double factorial. | required |

Returns:

| Type | Description |
| --- | --- |
| `float` | The ratio of the double factorials. |

Notes

This only works for numbers up to 300.

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If num or den is greater than 300. |

Source code in src/isotropic/utils/distribution.py
def double_factorial_ratio_scipy(num: int, den: int) -> float:
    """
    Compute the ratio of double factorials num!! / den!!.

    Parameters
    ----------
    num : int
        The numerator double factorial.
    den : int
        The denominator double factorial.

    Returns
    -------
    float
        The ratio of the double factorials.

    Notes
    -----
    This only works for numbers up to 300.

    Raises
    ------
    ValueError
        If num or den is greater than 300.
    """
    if num > 300 or den > 300:
        raise ValueError("This only works for numbers up to 300")
    return factorial2(num) / factorial2(den)

normal_integrand(theta, d, sigma)

Computes the function $g(\theta)$, which is integrated to obtain $F(\theta)$, the distribution function for the angle $\theta$ in a normal distribution:

$$g(\theta) = \frac{(d-1)!! \times (1-\sigma^2) \times \sin^{d-1}(\theta)}{\pi \times (d-2)!! \times (1+\sigma^2-2\sigma\cos(\theta))^{(d+1)/2}}$$

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `theta` | `float` | Angle parameter(s). | required |
| `d` | `int` | Dimension parameter. | required |
| `sigma` | `float` | Sigma parameter (should be in valid range). | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Value(s) of the function evaluated at `theta`. |

Source code in src/isotropic/utils/distribution.py
def normal_integrand(theta: float, d: int, sigma: float) -> Array:
    """
    Computes the function g(θ) that is integrated to calculate F(θ) which is the
    distribution function for the angle θ in a normal distribution:

    $$g(\\theta) = \\frac{(d-1)!! \\times (1-\\sigma^2) \\times \\sin^{d-1}(\\theta)}{\\pi \\times (d-2)!! \\times (1+\\sigma^2-2\\sigma\\cos(\\theta))^{(d+1)/2}}$$

    Parameters
    ----------
    theta : float
        Angle parameter(s).
    d : int
        Dimension parameter.
    sigma : float
        Sigma parameter (should be in valid range).

    Returns
    -------
    Array
        Value(s) of the function evaluated at `theta`.
    """

    # TODO: Convert inputs to JAX arrays once @jit works
    # theta = jnp.asarray(theta)
    # d = jnp.asarray(d, dtype=jnp.int32)
    # sigma = jnp.asarray(sigma)

    # factorial components
    numerator_factorial = factorial2(d - 1)
    denominator_factorial = factorial2(d - 2)

    # Numerator components
    one_minus_sigma_sq = 1.0 - sigma**2
    sin_theta_power = jnp.power(jnp.sin(theta), d - 1)

    # Denominator components
    denominator_base = 1.0 + sigma**2 - 2.0 * sigma * jnp.cos(theta)
    denominator_power = jnp.power(denominator_base, (d + 1) / 2.0)

    # Combine all terms
    numerator = numerator_factorial * one_minus_sigma_sq * sin_theta_power
    denominator = jnp.pi * denominator_factorial * denominator_power

    result = numerator / denominator

    return result
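
A quick way to sanity-check the formula is to evaluate it at a point where the value is easy to compute by hand. The snippet below is a scalar, pure-Python transcription of the expression for $g(\theta)$ written for this purpose; the helper names `double_factorial` and `g` are illustrative and are not the library's JAX implementation.

```python
import math


def double_factorial(n: int) -> int:
    """n!! for n >= 0, with 0!! = 1."""
    result = 1
    for k in range(n, 0, -2):
        result *= k
    return result


def g(theta: float, d: int, sigma: float) -> float:
    """Scalar transcription of the formula for g(theta)."""
    numerator = (
        double_factorial(d - 1) * (1.0 - sigma**2) * math.sin(theta) ** (d - 1)
    )
    base = 1.0 + sigma**2 - 2.0 * sigma * math.cos(theta)
    denominator = math.pi * double_factorial(d - 2) * base ** ((d + 1) / 2.0)
    return numerator / denominator


# With sigma = 0 the denominator base equals 1, so g reduces to
# (d-1)!! / (pi * (d-2)!!) * sin^(d-1)(theta); for d = 3 at theta = pi/2 this is 2/pi.
value = g(math.pi / 2, d=3, sigma=0.0)
```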

linalg

jax_null_space(A)

Compute the null space of a matrix $A$ using JAX.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `A` | `ArrayLike` | The input matrix for which to compute the null space. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | The basis vectors of the null space of A. |

Notes

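See also:

- `scipy.linalg.null_space` for the reference implementation in SciPy.
- https://github.com/jax-ml/jax/pull/14486 for an old JAX implementation.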

Source code in src/isotropic/utils/linalg.py
def jax_null_space(A: ArrayLike) -> Array:
    """
    Compute the null space of a matrix $A$ using JAX.

    Parameters
    ----------
    A : ArrayLike
        The input matrix for which to compute the null space.

    Returns
    -------
    Array
        The basis vectors of the null space of A.

    Notes
    -----
    See also:

    - `scipy.linalg.null_space` for the reference implementation in SciPy.
    - [https://github.com/jax-ml/jax/pull/14486](https://github.com/jax-ml/jax/pull/14486) for an old JAX implementation.
    """
    u, s, vh = svd(A, full_matrices=True)
    M, N = u.shape[0], vh.shape[1]
    rcond = jnp.finfo(s.dtype).eps * max(M, N)
    tol = jnp.amax(s, initial=0.0) * rcond
    num = jnp.sum(s > tol, dtype=int)
    Q = vh[num:, :].T.conj()
    return Q
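
In equation form, the steps in the listing above amount to the following, for an $M \times N$ matrix $A$ with singular values $s_i$ and with $\varepsilon$ the machine epsilon of the singular values' dtype. The symbols $\tau$ (tolerance) and $r$ (numerical rank) are introduced here for illustration only:

```latex
A = U \,\mathrm{diag}(s)\, V^{H}, \qquad
\tau = \varepsilon \cdot \max(M, N) \cdot \max_i s_i, \qquad
r = \#\{\, i : s_i > \tau \,\}, \qquad
Q = \left( V^{H}_{[r:,\,:]} \right)^{H}
```

The columns of $Q$ are the right singular vectors associated with singular values at or below the tolerance, so $Q$ has $N - r$ columns.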

simpsons

This module contains functions for estimating the integral of a function using Simpson's rule.

simpsons_rule(f, a, b, C, tol)

Estimates the integral of a function using Simpson's rule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `f` | `Callable` | Function to integrate. | required |
| `a` | `float` | Lower limit of integration. | required |
| `b` | `float` | Upper limit of integration. | required |
| `C` | `float` | Bound on 4th derivative of f. | required |
| `tol` | `float` | Desired tolerance for the integral estimate. | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Estimated value of the integral. |

Source code in src/isotropic/utils/simpsons.py
def simpsons_rule(f: Callable, a: float, b: float, C: float, tol: float) -> Array:
    """
    Estimates the integral of a function using Simpson's rule.

    Parameters
    ----------
    f : Callable
        Function to integrate.
    a : float
        Lower limit of integration.
    b : float
        Upper limit of integration.
    C : float
        Bound on 4th derivative of f.
    tol : float
        Desired tolerance for the integral estimate.

    Returns
    -------
    Array
        Estimated value of the integral.
    """
    # Estimate minimum number of intervals needed for given tolerance
    n: int = int(jnp.ceil(((180 * tol) / (C * (b - a) ** 5)) ** (-0.25)))
    if n % 2 == 1:
        n += 1  # Simpson's rule requires even n

    x: Array = jnp.linspace(a, b, n + 1)
    y: Array = f(x)

    S: Array = y[0] + y[-1] + 4 * jnp.sum(y[1:-1:2]) + 2 * jnp.sum(y[2:-2:2])
    integral: Array = (b - a) / (3 * n) * S
    return integral
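
The role of `C` and `tol` is easiest to see on an integral with a known value. The sketch below is a pure-Python counterpart of the routine above, written for scalar-valued `f` so that it runs without JAX; it is an illustration under that assumption, not the library code. It integrates $\sin$ over $[0, \pi]$, where the exact answer is 2 and the fourth derivative is bounded by 1.

```python
import math
from typing import Callable


def simpsons_rule(
    f: Callable[[float], float], a: float, b: float, C: float, tol: float
) -> float:
    """Composite Simpson's rule with n chosen from the standard error bound."""
    # C * (b - a)^5 / (180 * n^4) <= tol  =>  n >= (C * (b - a)^5 / (180 * tol))^(1/4)
    n = math.ceil((C * (b - a) ** 5 / (180.0 * tol)) ** 0.25)
    if n % 2 == 1:
        n += 1  # Simpson's rule needs an even number of intervals
    h = (b - a) / n
    y = [f(a + i * h) for i in range(n + 1)]
    # Endpoints weight 1, odd interior points weight 4, even interior points weight 2
    S = y[0] + y[-1] + 4.0 * sum(y[1:-1:2]) + 2.0 * sum(y[2:-2:2])
    return h / 3.0 * S


# |d^4/dx^4 sin(x)| = |sin(x)| <= 1, so C = 1 is a valid bound here
estimate = simpsons_rule(math.sin, 0.0, math.pi, C=1.0, tol=1e-6)
```

A looser bound `C` still gives a valid result; it only makes `n` larger than strictly necessary.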

state_transforms

This module contains functions for transforming the quantum state.

add_isotropic_error(Phi_sp, e2, theta_zero)

Add isotropic error to state $\Phi$ given $e_2$ and $\theta_0$

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `Phi_sp` | `ArrayLike` | state to which isotropic error is added (in spherical form) | required |
| `e2` | `ArrayLike` | vector $e_2$ in $S_{d-1}$ with uniform distribution | required |
| `theta_zero` | `float` | angle $\theta_0$ in $[0,\pi]$ with density function $f(\theta_0)$ | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | statevector in spherical form after adding isotropic error |

Source code in src/isotropic/utils/state_transforms.py
def add_isotropic_error(Phi_sp: Array, e2: Array, theta_zero: float) -> Array:
    """
    Add isotropic error to state $\\Phi$ given $e_2$ and $\\theta_0$

    Parameters
    ----------
    Phi_sp : ArrayLike
        state to which isotropic error is added (in spherical form)
    e2 : ArrayLike
        vector $e_2$ in $S_{d-1}$ with uniform distribution
    theta_zero : float
        angle $\\theta_0$ in $[0,\\pi]$ with density function $f(\\theta_0)$

    Returns
    -------
    Array
        statevector in spherical form after adding isotropic error
    """
    Psi_sp = (Phi_sp * jnp.cos(theta_zero)) + (
        (jnp.sum(e2, axis=0)) * jnp.sin(theta_zero)
    )
    return Psi_sp
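
Geometrically, the update rotates $\Phi$ by the angle $\theta_0$ towards $e_2$. The plain-Python sketch below uses made-up unit vectors to illustrate this; unlike the listing above, where `e2` is summed along axis 0 before use, `e2` here is assumed to already be a single vector orthogonal to `Phi_sp`.

```python
import math

# Illustrative unit vectors on the real hypersphere; e2 is orthogonal to Phi_sp
Phi_sp = [1.0, 0.0, 0.0, 0.0]
e2 = [0.0, 1.0, 0.0, 0.0]
theta_zero = math.pi / 6

# Psi = Phi * cos(theta_0) + e2 * sin(theta_0)
Psi_sp = [
    p * math.cos(theta_zero) + e * math.sin(theta_zero) for p, e in zip(Phi_sp, e2)
]

norm = math.sqrt(sum(v * v for v in Psi_sp))          # stays 1 for orthonormal inputs
overlap = sum(p * q for p, q in zip(Phi_sp, Psi_sp))  # equals cos(theta_0)
```

Since $\cos^2\theta_0 + \sin^2\theta_0 = 1$, the perturbed vector stays on the unit hypersphere whenever `Phi_sp` and `e2` are orthonormal.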

generate_and_add_isotropic_error(Phi, sigma=0.9, key=random.PRNGKey(0))

Generate and add isotropic error to a given statevector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `Phi` | `ArrayLike` | The input statevector as a complex JAX array of dimension $2^n$, for n-qubits. | required |
| `sigma` | `float` | The standard deviation for the isotropic error, by default 0.9. | `0.9` |
| `key` | `ArrayLike` | Random key for reproducibility, by default random.PRNGKey(0). | `PRNGKey(0)` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | The perturbed statevector after adding isotropic error. |

Source code in src/isotropic/utils/state_transforms.py
def generate_and_add_isotropic_error(
    Phi: ArrayLike, sigma: float = 0.9, key: ArrayLike = random.PRNGKey(0)
) -> Array:
    """
    Generate and add isotropic error to a given statevector.

    Parameters
    ----------
    Phi : ArrayLike
        The input statevector as a complex JAX array of dimension $2^n$, for n-qubits.
    sigma : float, optional
        The standard deviation for the isotropic error, by default 0.9.
    key : ArrayLike, optional
        Random key for reproducibility, by default random.PRNGKey(0).

    Returns
    -------
    Array
        The perturbed statevector after adding isotropic error.
    """
    Phi_spherical = statevector_to_hypersphere(Phi)
    basis = get_orthonormal_basis(
        Phi_spherical
    )  # gives d vectors with d+1 elements each
    _, coeffs = get_e2_coeffs(
        d=basis.shape[0],  # gives d coefficients for the d vectors above
        F_j=F_j,
        key=key,
    )
    e2 = jnp.expand_dims(coeffs, axis=-1) * basis

    def g(theta):
        return normal_integrand(theta, d=Phi_spherical.shape[0], sigma=sigma)

    x = random.uniform(key, shape=(), minval=0, maxval=1)
    theta_zero = get_theta_zero(x=x, g=g)
    Psi_spherical = add_isotropic_error(Phi_spherical, e2=e2, theta_zero=theta_zero)
    Psi = hypersphere_to_statevector(Psi_spherical)
    return Psi
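
The step that turns a uniform draw `x` into the angle $\theta_0$ is an inverse-transform sampling step: solve $F(\theta_0) = x$, where $F$ is the integral of $g$. The helper `get_theta_zero` is not shown in this section, so the sketch below is only one plausible way to combine Simpson's rule and bisection, under several assumptions: it is pure Python, it fixes $d = 3$ for simplicity, it uses a fixed number of Simpson intervals instead of the `C`/`tol` rule, and it uses the standard-library RNG in place of `jax.random`. All function names are illustrative.

```python
import math
import random


def g(theta: float, sigma: float) -> float:
    """g(theta) from the normal_integrand formula for d = 3, where (d-1)!!/(d-2)!! = 2."""
    base = 1.0 + sigma**2 - 2.0 * sigma * math.cos(theta)
    return 2.0 * (1.0 - sigma**2) * math.sin(theta) ** 2 / (math.pi * base**2)


def F(theta: float, sigma: float, n: int = 200) -> float:
    """F(theta) = integral of g over [0, theta], via Simpson's rule with fixed even n."""
    h = theta / n
    y = [g(i * h, sigma) for i in range(n + 1)]
    return h / 3.0 * (y[0] + y[-1] + 4.0 * sum(y[1:-1:2]) + 2.0 * sum(y[2:-2:2]))


def sample_theta_zero(x: float, sigma: float, eps: float = 1e-6) -> float:
    """Solve F(theta) = x on [0, pi] by bisection (F is increasing because g >= 0)."""
    a, b = 0.0, math.pi
    while b - a > eps:
        c = (a + b) / 2.0
        if F(c, sigma) <= x:
            a = c
        else:
            b = c
    return (a + b) / 2.0


rng = random.Random(0)  # stdlib RNG standing in for jax.random
x = rng.random()        # uniform draw in [0, 1)
theta_zero = sample_theta_zero(x, sigma=0.5)
```

For $d = 3$ the integral of $g$ over $[0, \pi]$ equals 1, so every `x` in $[0, 1)$ has a solution in the bracket.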

hypersphere_to_statevector(S)

Generate the statevector $\Phi$ from hypersphere $S$

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `S` | `Array` | hypersphere as a real JAX array of dimension $2^{n+1}$ for n qubits | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | statevector as a complex JAX array of dimension $2^n$ |

Source code in src/isotropic/utils/state_transforms.py
def hypersphere_to_statevector(S: Array) -> Array:
    """
    Generate the statevector $\\Phi$ from hypersphere $S$

    Parameters
    ----------
    S : Array
        hypersphere as a real JAX array of dimension $2^{n+1}$ for n qubits

    Returns
    -------
    Array
        statevector as a complex JAX array of dimension $2^n$
    """
    Phi = jnp.zeros(int(2 ** (log(S.shape[0], 2) - 1)), dtype=complex)
    for x in range(Phi.shape[0]):
        Phi = Phi.at[x].set(S[2 * x] + 1j * S[2 * x + 1])
    return Phi

statevector_to_hypersphere(Phi)

Generate the hypersphere from statevector $\Phi$

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `Phi` | `Array` | statevector as a complex JAX array of dimension $2^n$, for n-qubits | required |

Returns:

| Type | Description |
| --- | --- |
| `Array` | hypersphere as a real JAX array of dimension $2^{n+1}$ |

Source code in src/isotropic/utils/state_transforms.py
def statevector_to_hypersphere(Phi: Array) -> Array:
    """
    Generate the hypersphere from statevector $\\Phi$

    Parameters
    ----------
    Phi : Array
        statevector as a complex JAX array of dimension $2^n$, for n-qubits

    Returns
    -------
    Array
        hypersphere as a real JAX array of dimension $2^{n+1}$
    """
    S = jnp.zeros(int(2 ** (log(Phi.shape[0], 2) + 1)), dtype=float)
    for x in range(S.shape[0] // 2):
        S = S.at[2 * x].set(Phi[x].real)
        S = S.at[2 * x + 1].set(Phi[x].imag)
    return S
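
The two conversions are inverses of each other: the hypersphere vector interleaves the real and imaginary parts of each amplitude. The pure-Python sketch below mirrors that layout with lists instead of JAX arrays, as an illustration only; the function names are reused here for readability and these are not the library functions.

```python
def statevector_to_hypersphere(Phi: list[complex]) -> list[float]:
    """Interleave real and imaginary parts: [Re(Phi_0), Im(Phi_0), Re(Phi_1), ...]."""
    S = []
    for amplitude in Phi:
        S.extend([amplitude.real, amplitude.imag])
    return S


def hypersphere_to_statevector(S: list[float]) -> list[complex]:
    """Inverse map: pair consecutive entries back into complex amplitudes."""
    return [complex(S[2 * k], S[2 * k + 1]) for k in range(len(S) // 2)]


Phi = [0.6 + 0.0j, 0.0 + 0.8j]  # one-qubit state with |0.6|^2 + |0.8|^2 = 1
S = statevector_to_hypersphere(Phi)
roundtrip = hypersphere_to_statevector(S)
```

A normalized statevector of dimension $2^n$ therefore maps to a real unit vector of dimension $2^{n+1}$, which is why the error model can work on a real hypersphere.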