
utils

This module contains miscellaneous utilities useful for various steps involved in the isotropic error generation.

isotropic.utils

bisection

This module contains functions for the bisection algorithm to calculate $F^{-1}$.

get_theta(F, a, b, x, eps)

Find the value of $\theta$ such that $F(\theta) = x$ using the bisection method.

Parameters:

Name Type Description Default
F Callable

Function for which to compute the inverse.

required
a float

Lower bound of the interval.

required
b float

Upper bound of the interval.

required
x float | ArrayLike

Value for which to find the inverse.

required
eps float

Tolerance for convergence.

required

Returns:

Type Description
float

The value of $\theta$ such that $F(\theta) = x$.

Notes

This function assumes that $F$ is an increasing function in the interval $[a, b]$ and that $F(a) \leq x \leq F(b)$. The bisection method is a root-finding method that repeatedly bisects an interval and then selects a subinterval in which a root exists.

Source code in src/isotropic/utils/bisection.py
def get_theta(
    F: Callable, a: float, b: float, x: float | ArrayLike, eps: float
) -> float:
    """
    Find the value of $\\theta$ such that $F(\\theta) = x$ using the bisection method.

    Parameters
    ----------
    F : Callable
        Function for which to compute the inverse.
    a : float
        Lower bound of the interval.
    b : float
        Upper bound of the interval.
    x : float | ArrayLike
        Value for which to find the inverse.
    eps : float
        Tolerance for convergence.

    Returns
    -------
    float
        The value of $\\theta$ such that $F(\\theta) = x$.

    Notes
    -----
    This function assumes that $F$ is an increasing function in the interval $[a, b]$
    and that $F(a) \\leq x \\leq F(b)$. The bisection method is a root-finding method
    that repeatedly bisects an interval and then selects a subinterval in which a root exists.
    """

    def cond_fn(state):  # numpydoc ignore=GL08
        a, b = state
        return (b - a) > eps

    def body_fn(state):  # numpydoc ignore=GL08
        a, b = state
        c = (a + b) / 2.0
        Fc = F(c)
        a_new = jnp.where(Fc <= x, c, a)
        b_new = jnp.where(Fc <= x, b, c)
        return (a_new, b_new)

    a, b = jax.lax.while_loop(cond_fn, body_fn, (a, b))
    return (a + b) / 2.0
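Usage is straightforward once an increasing $F$ is in hand. The sketch below inlines the function from the listing above and inverts a toy $F(\theta) = \theta^2$ on $[0, 1]$; the toy $F$ is this sketch's assumption, not part of the module:

```python
from typing import Callable

import jax
import jax.numpy as jnp


def get_theta(F: Callable, a: float, b: float, x: float, eps: float) -> float:
    # Bisection as in the listing above: shrink [a, b] until b - a <= eps.
    def cond_fn(state):
        a, b = state
        return (b - a) > eps

    def body_fn(state):
        a, b = state
        c = (a + b) / 2.0
        Fc = F(c)
        a_new = jnp.where(Fc <= x, c, a)
        b_new = jnp.where(Fc <= x, b, c)
        return (a_new, b_new)

    a, b = jax.lax.while_loop(cond_fn, body_fn, (a, b))
    return (a + b) / 2.0


# Invert F(theta) = theta**2: F(theta) = 0.25 should give theta = 0.5
theta = get_theta(lambda t: t**2, a=0.0, b=1.0, x=0.25, eps=1e-6)
```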

data_generation

This module generates data for Grover's algorithm with isotropic error.

cli()

Command-line interface for data generation.

Source code in src/isotropic/utils/data_generation.py
def cli():
    """
    Command-line interface for data generation.
    """
    if len(sys.argv) == 1:
        # No arguments provided, show help and exit
        sys.argv.append("--help")
    app()
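The only logic here is defaulting to `--help` when invoked with no arguments. A minimal sketch of the same pattern with a stand-in `app` (the real module presumably wires `app` to a CLI framework object, which is an assumption of this sketch):

```python
import sys


def cli(app) -> None:
    # With no arguments, default to showing help rather than doing nothing.
    if len(sys.argv) == 1:
        sys.argv.append("--help")
    app()


def fake_app():
    # Stand-in for the real CLI app: just report what argv looks like.
    print("argv seen by app:", sys.argv[1:])


sys.argv = ["generate-data"]  # simulate a bare invocation
cli(fake_app)
```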

generate_data(min_qubits, max_qubits, min_iterations, max_iterations, min_sigma=None, max_sigma=None, num_sigma_points=2, sigma_values=None, data_dir='data', random_key=42, n_samples=1)

Generate data for Grover's algorithm with isotropic error and save to xarray files.

Parameters:

Name Type Description Default
min_qubits int

Minimum number of qubits.

required
max_qubits int

Maximum number of qubits.

required
min_iterations int

Minimum number of Grover iterations to simulate.

required
max_iterations int

Maximum number of Grover iterations to simulate.

required
min_sigma float

Minimum sigma value for isotropic error. Required if sigma_values is not provided.

None
max_sigma float

Maximum sigma value for isotropic error. Required if sigma_values is not provided.

None
num_sigma_points int

Number of sigma points to evaluate between min_sigma and max_sigma. Default is 2.

2
sigma_values list[float]

Explicit list of sigma values. If provided, min_sigma/max_sigma/num_sigma_points are ignored.

None
data_dir str

Directory to save the generated data files. Default is "data".

'data'
random_key int

Integer seed for the JAX PRNG root key. Default is 42.

42
n_samples int

Number of independent error samples to average per (sigma, iter) point. Default is 1 (single draw, high variance). Use larger values (e.g. 50) for smooth, reliable estimates of the expected success probability.

1

Returns:

Type Description
None

Saves the generated data to xarray files.

Notes

We average n_samples independent error draws via vmap to reduce single-sample Monte Carlo noise, which otherwise introduces large variance in the results. This is crucial for obtaining smooth, monotonic curves as sigma is scaled.

Source code in src/isotropic/utils/data_generation.py
def generate_data(
    min_qubits: int,
    max_qubits: int,
    min_iterations: int,
    max_iterations: int,
    min_sigma: Optional[float] = None,
    max_sigma: Optional[float] = None,
    num_sigma_points: int = 2,
    sigma_values: Optional[list[float]] = None,
    data_dir: str = "data",
    random_key: int = 42,
    n_samples: int = 1,
) -> None:
    """
    Generate data for Grover's algorithm with isotropic error and save to xarray files.

    Parameters
    ----------
    min_qubits : int
        Minimum number of qubits.
    max_qubits : int
        Maximum number of qubits.
    min_iterations : int
        Minimum number of Grover iterations to simulate.
    max_iterations : int
        Maximum number of Grover iterations to simulate.
    min_sigma : float, optional
        Minimum sigma value for isotropic error. Required if sigma_values is not provided.
    max_sigma : float, optional
        Maximum sigma value for isotropic error. Required if sigma_values is not provided.
    num_sigma_points : int, optional
        Number of sigma points to evaluate between min_sigma and max_sigma. Default is 2.
    sigma_values : list[float], optional
        Explicit list of sigma values. If provided, min_sigma/max_sigma/num_sigma_points
        are ignored.
    data_dir : str, optional
        Directory to save the generated data files. Default is "data".
    random_key : int, optional
        Integer seed for the JAX PRNG root key. Default is 42.
    n_samples : int, optional
        Number of independent error samples to average per ``(sigma, iter)`` point.
        Default is 1 (single draw, high variance). Use larger values (e.g. 50)
        for smooth, reliable estimates of the expected success probability.

    Returns
    -------
    None
        Saves the generated data to xarray files.

    Notes
    -----
    We average ``n_samples`` independent error draws via vmap to reduce single-sample
    Monte Carlo noise, which otherwise introduces large variance in the results. This
    is crucial for obtaining smooth, monotonic curves as ``sigma`` is scaled.
    """
    if sigma_values is not None:
        sigmas = jnp.array(sigma_values)
    elif min_sigma is not None and max_sigma is not None:
        sigmas = jnp.linspace(min_sigma, max_sigma, num_sigma_points)
    else:
        raise ValueError("Provide either sigma_values or both min_sigma and max_sigma.")
    if jnp.any(sigmas <= 0) or jnp.any(sigmas >= 1):
        raise ValueError("Sigma values must be in the range (0, 1).")

    os.makedirs(data_dir, exist_ok=True)

    # Loop over qubit counts (cannot vmap: each num_qubits yields different
    # array shapes, e.g. statevector length 2^n).
    for num_qubits in range(min_qubits, max_qubits + 1):
        # TODO: change hardcoded grover oracle
        oracle = jnp.eye(2**num_qubits).tolist()
        oracle[3][3] = -1
        U_w = Operator(oracle)
        marked_item = "0" * (num_qubits - 2) + "11"

        # Pre-compute all statevectors via Qiskit (not JAX-traceable).
        iterations_range = list(range(min_iterations, max_iterations + 1))
        statevectors = []
        # model: per-gate error; effective sigma after k gates is sigma^k
        total_gate_counts = []  # required for scaling sigma
        error_free_probs = []
        for iterations in iterations_range:
            circuit = get_grover_circuit(num_qubits, U_w, iterations)
            total_gate_count = sum(
                v for op, v in circuit.count_ops().items() if op != "barrier"
            )
            total_gate_counts.append(total_gate_count)
            sv = Statevector(circuit)
            statevectors.append(jnp.array(sv.data))
            error_free_probs.append(sv.probabilities_dict()[marked_item])

        Phi_batch = jnp.stack(statevectors)  # (num_iters, 2^n) complex
        error_free_batch = jnp.array(error_free_probs)
        gate_counts_batch = jnp.array(total_gate_counts)

        # Batch JAX computation: vmap over iterations and sigmas
        results = run_experiment_batch(
            Phi_batch=Phi_batch,
            marked_item=marked_item,
            sigmas=sigmas,
            gate_counts=gate_counts_batch,
            random_key=random_key,
            n_samples=n_samples,
        )
        # results shape: (num_iterations, num_sigma_points)

        # Save per-iteration xarray files
        for i, iterations in enumerate(iterations_range):
            error_success = jnp.append(results[i], error_free_batch[i])
            data = xr.Dataset(
                {
                    "success_probability": (["sigma"], error_success),
                    "iterations": iterations,
                    "gate_count": total_gate_counts[i],
                },
                coords={
                    "sigma": jnp.append(sigmas, jnp.array([1.0])),
                },
                attrs={
                    "num_qubits": num_qubits,
                    "marked_item": marked_item,
                },
            )
            data.to_netcdf(
                f"{data_dir}/grover_{num_qubits}_qubits_{iterations}_iterations.nc"
            )

run_experiment_batch(Phi_batch, marked_item, sigmas, gate_counts, random_key=42, n_samples=1)

Run batched experiment: vmap over iterations and sigmas.

For a fixed num_qubits all statevectors share the same shape, so the pure-JAX computation is vmapped over the iteration dimension (Phi_batch) and the sigma dimension in a single JIT-compiled XLA program.

Parameters:

Name Type Description Default
Phi_batch Array

Stack of complex statevectors, shape (num_iterations, 2**n).

required
marked_item str

The marked item to search for in binary string format.

required
sigmas Array

Sigma values to evaluate, shape (num_sigma_points,).

required
gate_counts Array

Total gate counts excluding barriers for each iteration, shape (num_iterations,).

required
random_key int

Integer seed for the JAX PRNG root key. Default is 42.

42
n_samples int

Number of independent error samples to average per (sigma, iter) point. Default is 1 (single draw, high variance). Use larger values (e.g. 50) for smooth, reliable estimates of the expected success probability.

1

Returns:

Type Description
Array

Success probabilities, shape (num_iterations, num_sigma_points).

Notes

We assume a per-gate error model, so the effective sigma after k gates is sigma^k. The gate_counts array is used to scale the sigma values for each iteration accordingly.

We average n_samples independent error draws via vmap to reduce single-sample Monte Carlo noise, which otherwise introduces large variance in the results. This is crucial for obtaining smooth, monotonic curves as sigma is scaled.

Source code in src/isotropic/utils/data_generation.py
def run_experiment_batch(
    Phi_batch: Array,
    marked_item: str,
    sigmas: Array,
    gate_counts: Array,
    random_key: int = 42,
    n_samples: int = 1,
) -> Array:
    """
    Run batched experiment: vmap over iterations and sigmas.

    For a fixed num_qubits all statevectors share the same shape, so the
    pure-JAX computation is vmapped over the iteration dimension (Phi_batch)
    and the sigma dimension in a single JIT-compiled XLA program.

    Parameters
    ----------
    Phi_batch : Array
        Stack of complex statevectors, shape ``(num_iterations, 2**n)``.
    marked_item : str
        The marked item to search for in binary string format.
    sigmas : Array
        Sigma values to evaluate, shape ``(num_sigma_points,)``.
    gate_counts : Array
        Total gate counts excluding barriers for each iteration, shape ``(num_iterations,)``.
    random_key : int, optional
        Integer seed for the JAX PRNG root key. Default is 42.
    n_samples : int, optional
        Number of independent error samples to average per ``(sigma, iter)`` point.
        Default is 1 (single draw, high variance). Use larger values (e.g. 50)
        for smooth, reliable estimates of the expected success probability.

    Returns
    -------
    Array
        Success probabilities, shape ``(num_iterations, num_sigma_points)``.

    Notes
    -----
    We assume a per-gate error model, so the effective sigma after k gates is sigma^k.
    The ``gate_counts`` array is used to scale the sigma values for each iteration
    accordingly.

    We average ``n_samples`` independent error draws via vmap to reduce single-sample
    Monte Carlo noise, which otherwise introduces large variance in the results. This
    is crucial for obtaining smooth, monotonic curves as ``sigma`` is scaled.
    """
    # Convert all statevectors to hypersphere (vmap over iterations)
    Phi_sp_batch = jax.vmap(statevector_to_hypersphere)(Phi_batch)

    # Orthonormal basis for each iteration (vmap)
    basis_batch = jax.vmap(get_orthonormal_basis)(Phi_sp_batch)

    d_phi = Phi_sp_batch.shape[1]
    log_factorial_ratio = jnp.log(double_factorial_ratio(d_phi - 2, d_phi - 3))
    marked_index = int(marked_item, 2)

    key = random.PRNGKey(random_key)  # root PRNG key; all randomness derives from this
    sigma_keys = random.split(key, num=sigmas.shape[0])  # one key per sigma

    # Delegate to the stable module-level compiled function so the JIT cache
    # is reused across all calls with the same (marked_index, n_samples, array shapes).
    results = _run_batch_compiled(
        Phi_sp_batch,
        basis_batch,
        sigmas,
        gate_counts,
        sigma_keys,
        log_factorial_ratio,
        marked_index,
        n_samples,
    )
    # Shape: (num_sigmas, num_iterations)

    return results.T  # (num_iterations, num_sigmas)
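Under the per-gate model, a circuit with k gates sees an effective sigma of sigma**k. The actual scaling happens inside `_run_batch_compiled`, which is not shown here, so the broadcast below is an illustration of the model rather than the module's code:

```python
import jax.numpy as jnp

sigmas = jnp.array([0.90, 0.99])   # per-gate sigma, shape (num_sigma_points,)
gate_counts = jnp.array([10, 50])  # gates per circuit, shape (num_iterations,)

# Effective sigma per (iteration, sigma) pair: sigma**k via broadcasting,
# giving shape (num_iterations, num_sigma_points).
effective = sigmas[None, :] ** gate_counts[:, None]
```

Deeper circuits (larger k) shrink the effective sigma, which is why `gate_counts` must accompany the raw sigma grid.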

distribution

This module contains functions for relevant probability distributions.

double_factorial_jax(n)

Helper function to compute double factorial.

Parameters:

Name Type Description Default
n int

The integer for which to compute the double factorial.

required

Returns:

Type Description
Array

The value of the double factorial n!! as a JAX array.

Notes

The double factorial is defined as:

n!! = n * (n-2) * (n-4) * ... * 1 (if n is odd) or 2 (if n is even).

Source code in src/isotropic/utils/distribution.py
def double_factorial_jax(n: int) -> Array:
    """
    Helper function to compute double factorial.

    Parameters
    ----------
    n : int
        The integer for which to compute the double factorial.

    Returns
    -------
    Array
        The value of the double factorial n!! as a JAX array.

    Notes
    -----
    The double factorial is defined as:

        n!! = n * (n-2) * (n-4) * ... * 1 (if n is odd) or 2 (if n is even).
    """
    # works for numbers as large as 9**6
    return jnp.where(n <= 0, 1, jnp.prod(jnp.arange(n, 0, -2, dtype=jnp.uint64)))
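For instance, inlining the function above:

```python
import jax.numpy as jnp


def double_factorial_jax(n: int):
    # n!! as a product over n, n-2, n-4, ...; n <= 0 maps to 1.
    return jnp.where(n <= 0, 1, jnp.prod(jnp.arange(n, 0, -2, dtype=jnp.uint64)))


odd = double_factorial_jax(7)   # 7 * 5 * 3 * 1 = 105
even = double_factorial_jax(6)  # 6 * 4 * 2 = 48
edge = double_factorial_jax(0)  # defined as 1
```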

double_factorial_ratio(num, den)

Compute the ratio of double factorials num!! / den!!.

Parameters:

Name Type Description Default
num int

The numerator double factorial.

required
den int

The denominator double factorial.

required

Returns:

Type Description
float

The ratio num!! / den!!.

Source code in src/isotropic/utils/distribution.py
def double_factorial_ratio(num: int, den: int) -> float:
    """
    Compute the ratio of double factorials num!! / den!!.

    Parameters
    ----------
    num : int
        The numerator double factorial.
    den : int
        The denominator double factorial.

    Returns
    -------
    float
        The ratio num!! / den!!.
    """
    num_list = list(range(num, 0, -2))
    den_list = list(range(den, 0, -2))
    # make sure both lists are the same length by padding the shorter one with 1s
    max_len = max(len(num_list), len(den_list))
    num_list += [1] * (max_len - len(num_list))
    den_list += [1] * (max_len - len(den_list))
    num_array = np.array(num_list)
    den_array = np.array(den_list)

    def ratio(a, b):  # numpydoc ignore=GL08
        return a / b

    result_array = np.vectorize(ratio)(num_array, den_array)
    return np.prod(result_array)
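Computing the ratio factor by factor (7/6 · 5/4 · 3/2 · 1/1 rather than 105/48) keeps intermediates small when the arguments are large. A condensed sketch of the same logic (the `np.vectorize` wrapper from the listing is replaced by plain array division, which is behaviorally equivalent):

```python
import numpy as np


def double_factorial_ratio(num: int, den: int) -> float:
    # Pair up the factors of num!! and den!!, padding the shorter factor
    # list with 1s, then multiply the element-wise ratios.
    num_list = list(range(num, 0, -2))
    den_list = list(range(den, 0, -2))
    max_len = max(len(num_list), len(den_list))
    num_list += [1] * (max_len - len(num_list))
    den_list += [1] * (max_len - len(den_list))
    return float(np.prod(np.array(num_list) / np.array(den_list)))


r = double_factorial_ratio(7, 6)  # 105 / 48 = 2.1875
```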

normal_integrand(theta, d, sigma, log_factorial_ratio=None)

Compute the function g(θ).

Parameters:

Name Type Description Default
theta float

Angle parameter(s).

required
d int

Dimension parameter.

required
sigma float

Sigma parameter (should be in valid range).

required
log_factorial_ratio float

Precomputed value of log((d-1)!! / (d-2)!!). When None (the default) it is computed on every call. Passing it in avoids redundant work when the integrand is evaluated many times for the same d (e.g. during numerical integration).

None

Returns:

Type Description
Array

Value(s) of the function evaluated at theta.

Notes

g(θ) is integrated to obtain F(θ), the distribution function for the angle θ under a normal distribution:

$$g(\theta) = \frac{(d-1)!! \times (1-\sigma^2) \times \sin^{d-1}(\theta)}{\pi \times (d-2)!! \times (1+\sigma^2-2\sigma\cos(\theta))^{(d+1)/2}}$$.

Source code in src/isotropic/utils/distribution.py
def normal_integrand(
    theta: float, d: int, sigma: float, log_factorial_ratio: float | None = None
) -> Array:
    """
    Compute the function g(θ).

    Parameters
    ----------
    theta : float
        Angle parameter(s).
    d : int
        Dimension parameter.
    sigma : float
        Sigma parameter (should be in valid range).
    log_factorial_ratio : float, optional
        Precomputed value of ``log((d-1)!! / (d-2)!!)``. When ``None``
        (the default) it is computed on every call. Passing it in avoids
        redundant work when the integrand is evaluated many times for the
        same ``d`` (e.g. during numerical integration).

    Returns
    -------
    Array
        Value(s) of the function evaluated at `theta`.

    Notes
    -----
    g(θ) is integrated to obtain F(θ), the distribution function for the
    angle θ under a normal distribution:

    $$g(\\theta) = \\frac{(d-1)!! \\times (1-\\sigma^2) \\times \\sin^{d-1}(\\theta)}{\\pi \\times (d-2)!! \\times (1+\\sigma^2-2\\sigma\\cos(\\theta))^{(d+1)/2}}$$.
    """

    # Compute in log-space to avoid numerical underflow for large d,
    # where sin^(d-1)(theta) and denominator^((d+1)/2) both underflow to 0.
    if log_factorial_ratio is None:
        log_factorial_ratio = jnp.log(double_factorial_ratio(d - 1, d - 2))

    denominator_base = 1.0 + sigma**2 - 2.0 * sigma * jnp.cos(theta)

    log_result = (
        log_factorial_ratio
        + jnp.log(1.0 - sigma**2)
        + (d - 1) * jnp.log(jnp.sin(theta))
        - jnp.log(jnp.pi)
        - ((d + 1) / 2.0) * jnp.log(denominator_base)
    )

    return jnp.exp(log_result)
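A quick sketch of the `log_factorial_ratio` caching parameter, inlining the listing above: passing the precomputed log ratio must reproduce the default path exactly. The plain `math.prod` double factorial here is a stand-in for the module's own helpers:

```python
import math

import jax.numpy as jnp


def _dfact(n: int) -> int:
    # Plain double factorial for small n, standing in for the module helper.
    return math.prod(range(n, 0, -2)) if n > 0 else 1


def normal_integrand(theta, d, sigma, log_factorial_ratio=None):
    # Evaluate g(theta) in log space, as in the listing above.
    if log_factorial_ratio is None:
        log_factorial_ratio = jnp.log(_dfact(d - 1) / _dfact(d - 2))
    denominator_base = 1.0 + sigma**2 - 2.0 * sigma * jnp.cos(theta)
    log_result = (
        log_factorial_ratio
        + jnp.log(1.0 - sigma**2)
        + (d - 1) * jnp.log(jnp.sin(theta))
        - jnp.log(jnp.pi)
        - ((d + 1) / 2.0) * jnp.log(denominator_base)
    )
    return jnp.exp(log_result)


theta = jnp.linspace(0.1, float(jnp.pi) - 0.1, 5)
g_default = normal_integrand(theta, d=8, sigma=0.3)
g_cached = normal_integrand(
    theta, d=8, sigma=0.3, log_factorial_ratio=jnp.log(_dfact(7) / _dfact(6))
)
```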

linalg

Linear algebra utilities for isotropic error analysis, implemented using JAX.

jax_null_space(A)

Compute the null space of a matrix $A$ using JAX.

Parameters:

Name Type Description Default
A ArrayLike

The input matrix for which to compute the null space.

required

Returns:

Type Description
Array

The basis vectors of the null space of A.

Notes

See also:

- scipy.linalg.null_space for the reference implementation in SciPy.
- https://github.com/jax-ml/jax/pull/14486 for an old JAX implementation.

Source code in src/isotropic/utils/linalg.py
def jax_null_space(A: ArrayLike) -> Array:
    """
    Compute the null space of a matrix $A$ using JAX.

    Parameters
    ----------
    A : ArrayLike
        The input matrix for which to compute the null space.

    Returns
    -------
    Array
        The basis vectors of the null space of A.

    Notes
    -----
    See also:

    - `scipy.linalg.null_space` for the reference implementation in SciPy.
    - [https://github.com/jax-ml/jax/pull/14486](https://github.com/jax-ml/jax/pull/14486) for an old JAX implementation.
    """
    u, s, vh = svd(A, full_matrices=True)
    M, N = u.shape[0], vh.shape[1]
    rcond = jnp.finfo(s.dtype).eps * max(M, N)
    tol = jnp.amax(s, initial=0.0) * rcond
    num = jnp.sum(s > tol, dtype=int)
    Q = vh[num:, :].T.conj()
    return Q
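A small check, inlining the listing above: for a rank-2 matrix of shape (2, 3), the null space is one-dimensional and every returned basis vector is annihilated by A:

```python
import jax.numpy as jnp
from jax.numpy.linalg import svd


def jax_null_space(A):
    # Right singular vectors with (near-)zero singular values span the null space.
    u, s, vh = svd(A, full_matrices=True)
    M, N = u.shape[0], vh.shape[1]
    rcond = jnp.finfo(s.dtype).eps * max(M, N)
    tol = jnp.amax(s, initial=0.0) * rcond
    num = jnp.sum(s > tol, dtype=int)
    return vh[num:, :].T.conj()


A = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # rank 2; null space = span(e3)
Q = jax_null_space(A)
```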

simpsons

This module contains functions for estimating the integral of a function using Simpson's rule.

simpsons_rule(f, a, b, C, tol)

Estimate the integral of a function using Simpson's rule.

Parameters:

Name Type Description Default
f Callable

Function to integrate.

required
a float

Lower limit of integration.

required
b float

Upper limit of integration.

required
C float

Bound on 4th derivative of f.

required
tol float

Desired tolerance for the integral estimate.

required

Returns:

Type Description
Array

Estimated value of the integral.

Source code in src/isotropic/utils/simpsons.py
def simpsons_rule(f: Callable, a: float, b: float, C: float, tol: float) -> Array:
    """
    Estimate the integral of a function using Simpson's rule.

    Parameters
    ----------
    f : Callable
        Function to integrate.
    a : float
        Lower limit of integration.
    b : float
        Upper limit of integration.
    C : float
        Bound on 4th derivative of f.
    tol : float
        Desired tolerance for the integral estimate.

    Returns
    -------
    Array
        Estimated value of the integral.
    """
    # Estimate minimum number of intervals needed for given tolerance
    # n: int = (jnp.ceil(((180 * tol) / (C * (b - a) ** 5)) ** (-0.25))).astype(int)
    # if n % 2 == 1:
    #     n += 1  # Simpson's rule requires even n
    # derived from the worst-case interval [0, π]. Since C=1 and tol=1e-15,
    # this gives n ≈ 36,100
    # Making n fixed allows jax compilation

    n = 36100
    x: Array = jnp.linspace(a, b, n + 1)
    y: Array = f(x)

    S: Array = y[0] + y[-1] + 4 * jnp.sum(y[1:-1:2]) + 2 * jnp.sum(y[2:-2:2])
    integral: Array = (b - a) / (3 * n) * S
    return integral
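A quick sanity check, inlining the listing above: the integral of sin over [0, π] is exactly 2. Note that with the fixed n, the `C` and `tol` arguments only document the tolerance analysis and do not affect the computation:

```python
from typing import Callable

import jax.numpy as jnp


def simpsons_rule(f: Callable, a: float, b: float, C: float, tol: float):
    # Composite Simpson's rule with a fixed, JIT-friendly number of intervals.
    n = 36100
    x = jnp.linspace(a, b, n + 1)
    y = f(x)
    # Endpoints weight 1, odd interior points weight 4, even interior points weight 2.
    S = y[0] + y[-1] + 4 * jnp.sum(y[1:-1:2]) + 2 * jnp.sum(y[2:-2:2])
    return (b - a) / (3 * n) * S


integral = simpsons_rule(jnp.sin, a=0.0, b=float(jnp.pi), C=1.0, tol=1e-15)
```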

state_transforms

This module contains functions for transforming the quantum state.

add_isotropic_error(Phi_sp, e2, theta_zero)

Add isotropic error to state $\Phi$ given $e_2$ and $\theta_0$.

Parameters:

Name Type Description Default
Phi_sp ArrayLike

State to which isotropic error is added (in spherical form).

required
e2 ArrayLike

Vector $e_2$ in $S_{d-1}$ with uniform distribution.

required
theta_zero float

Angle $\theta_0$ in $[0,\pi]$ with density function $f(\theta_0)$.

required

Returns:

Type Description
Array

Statevector in spherical form after adding isotropic error.

Source code in src/isotropic/utils/state_transforms.py
@jax.jit
def add_isotropic_error(Phi_sp: Array, e2: Array, theta_zero: float) -> Array:
    """
    Add isotropic error to state $\\Phi$ given $e_2$ and $\\theta_0$.

    Parameters
    ----------
    Phi_sp : ArrayLike
        State to which isotropic error is added (in spherical form).
    e2 : ArrayLike
        Vector $e_2$ in $S_{d-1}$ with uniform distribution.
    theta_zero : float
        Angle $\\theta_0$ in $[0,\\pi]$ with density function $f(\\theta_0)$.

    Returns
    -------
    Array
        Statevector in spherical form after adding isotropic error.
    """
    Psi_sp = (Phi_sp * jnp.cos(theta_zero)) + (
        (jnp.sum(e2, axis=0)) * jnp.sin(theta_zero)
    )
    return Psi_sp
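When the rows of e2 sum to a unit vector orthogonal to Phi_sp, the update Φ cos θ₀ + e₂ sin θ₀ stays on the unit hypersphere. A toy check in d = 3 (the toy vectors below are this sketch's assumption, not module output):

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_isotropic_error(Phi_sp, e2, theta_zero):
    # Rotate Phi_sp by theta_zero towards the direction given by the rows of e2.
    return Phi_sp * jnp.cos(theta_zero) + jnp.sum(e2, axis=0) * jnp.sin(theta_zero)


Phi_sp = jnp.array([1.0, 0.0, 0.0])  # unit vector
e2 = jnp.array([[0.0, 1.0, 0.0]])    # rows sum to a unit vector orthogonal to Phi_sp
Psi_sp = add_isotropic_error(Phi_sp, e2, theta_zero=0.3)
```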

generate_and_add_isotropic_error(Phi, sigma=0.9, key=random.PRNGKey(0))

Generate and add isotropic error to a given statevector.

Parameters:

Name Type Description Default
Phi ArrayLike

The input statevector as a complex JAX array of dimension $2^n$, for $n$ qubits.

required
sigma float

The standard deviation for the isotropic error, by default 0.9.

0.9
key ArrayLike

Random key for reproducibility, by default random.PRNGKey(0).

PRNGKey(0)

Returns:

Type Description
Array

The perturbed statevector after adding isotropic error.

Source code in src/isotropic/utils/state_transforms.py
def generate_and_add_isotropic_error(
    Phi: ArrayLike,
    sigma: float = 0.9,
    key: ArrayLike = random.PRNGKey(0),
) -> Array:
    """
    Generate and add isotropic error to a given statevector.

    Parameters
    ----------
    Phi : ArrayLike
        The input statevector as a complex JAX array of dimension $2^n$, for n qubits.
    sigma : float, optional
        The standard deviation for the isotropic error, by default 0.9.
    key : ArrayLike, optional
        Random key for reproducibility, by default random.PRNGKey(0).

    Returns
    -------
    Array
        The perturbed statevector after adding isotropic error.
    """

    Phi_spherical = statevector_to_hypersphere(Phi)
    basis = get_orthonormal_basis(
        Phi_spherical
    )  # gives d vectors with d+1 elements each
    theta, coeffs = get_e2_coeffs(
        d=basis.shape[0],  # gives d coefficients for the d vectors above
        key=key,
    )
    e2 = jnp.expand_dims(coeffs, axis=-1) * basis

    d = Phi_spherical.shape[0]
    log_factorial_ratio = jnp.log(double_factorial_ratio(d - 1, d - 2))

    def g(theta):  # numpydoc ignore=GL08
        return normal_integrand(
            theta, d=d, sigma=sigma, log_factorial_ratio=log_factorial_ratio
        )

    x = random.uniform(key, shape=(), minval=0, maxval=1)
    theta_zero = get_theta_zero(x=x, g=g)
    Psi_spherical = add_isotropic_error(Phi_spherical, e2=e2, theta_zero=theta_zero)
    Psi = hypersphere_to_statevector(Psi_spherical)
    return Psi

hypersphere_to_statevector(S)

Generate the statevector $\Phi$ from hypersphere $S$.

Parameters:

Name Type Description Default
S ArrayLike

Hypersphere as a real JAX array of dimension $2^{n+1}$ for n qubits.

required

Returns:

Type Description
Array

Statevector as a complex JAX array of dimension $2^n$.

Source code in src/isotropic/utils/state_transforms.py
@jax.jit
def hypersphere_to_statevector(S: Array) -> Array:
    """
    Generate the statevector $\\Phi$ from hypersphere $S$.

    Parameters
    ----------
    S : ArrayLike
        Hypersphere as a real JAX array of dimension $2^{n+1}$ for n qubits.

    Returns
    -------
    Array
        Statevector as a complex JAX array of dimension $2^n$.
    """

    Phi = S[0::2] + 1j * S[1::2]
    return Phi

statevector_to_hypersphere(Phi)

Generate the hypersphere from statevector $\Phi$.

Parameters:

Name Type Description Default
Phi ArrayLike

Statevector as a complex JAX array of dimension $2^n$, for $n$ qubits.

required

Returns:

Type Description
Array

Hypersphere as a real JAX array of dimension $2^{n+1}$.

Source code in src/isotropic/utils/state_transforms.py
@jax.jit
def statevector_to_hypersphere(Phi: Array) -> Array:
    """
    Generate the hypersphere from statevector $\\Phi$.

    Parameters
    ----------
    Phi : ArrayLike
        Statevector as a complex JAX array of dimension $2^n$, for n qubits.

    Returns
    -------
    Array
        Hypersphere as a real JAX array of dimension $2^{n+1}$.
    """
    S = jnp.column_stack([Phi.real, Phi.imag]).ravel()
    return S
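The two transforms above are inverses: interleaving real and imaginary parts preserves the norm, and the round trip recovers the original statevector. Inlining both listings:

```python
import jax
import jax.numpy as jnp


@jax.jit
def statevector_to_hypersphere(Phi):
    # Interleave real and imaginary parts: [re0, im0, re1, im1, ...].
    return jnp.column_stack([Phi.real, Phi.imag]).ravel()


@jax.jit
def hypersphere_to_statevector(S):
    # Inverse: even entries are real parts, odd entries are imaginary parts.
    return S[0::2] + 1j * S[1::2]


Phi = jnp.array([1.0 + 1.0j, 0.0 - 1.0j]) / jnp.sqrt(3.0)  # normalized 1-qubit state
S = statevector_to_hypersphere(Phi)       # real vector of length 4 on the unit sphere
Phi_back = hypersphere_to_statevector(S)  # round trip
```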