Stabilization of 5 qubits using an RNN control model
In this example, we train an RNN model to stabilize 5 qubits against decay and decoherence, using 4 generalized measurements plus a unitary per error-correction cycle, and compare its performance to the uncontrolled case. Mathematically, the model is powerful enough to express the Laflamme code. To demonstrate GPU acceleration, we chose a larger batch size and fewer training iterations. The code takes 16 min on an NVIDIA Quadro RTX 6000 node.
# ruff: noqa
import os
os.sys.path.append("../../../..")
# ruff: noqa
from feedback_grape.fgrape import optimize_pulse
from feedback_grape.fgrape import Gate, Decay # type: ignore
from feedback_grape.utils.states import basis # type: ignore
from feedback_grape.utils.fidelity import ket2dm # type: ignore
from feedback_grape.utils.operators import sigmaz, sigmam # type: ignore
from feedback_grape.utils.modeling import embed # type: ignore
import jax
import jax.numpy as jnp
# Experimental parameters
# Number of qubits; the Hilbert-space dimension used throughout is 2**N_qubits.
N_qubits = 5
# Number of measurement gates applied in each cycle (the measurement gate is
# repeated N_meas times when system_params is assembled).
N_meas = 4
# The rates enter the jump operators as sqrt(rate) prefactors.
gamma_z = 0.1 # dephasing rate
gamma_m = 0.1 # decay rate
# Equal superposition of the first and the last computational basis state.
psi_target = (basis(2**N_qubits, 0) + basis(2**N_qubits, 2**N_qubits - 1)) / jnp.sqrt(2) # (|11...1> + |00...0>)/sqrt(2)
# Density matrix of the target state; used both as initial state and as target.
rho_target = ket2dm(psi_target)
# Model parameters
# Forwarded to optimize_pulse as rnn_hidden_size.
rnn_hidden_size = 16
# Training parameters
# Number of time steps used during training.
num_time_steps = 3
# One reward weight per training time step, all equal.
reward_weights = [1.0]*num_time_steps
convergence_threshold = None # no early stopping
N_training_iterations = 250 # relatively small number of iterations for demo purposes
batch_size = 128
learning_rate = 3e-3
# Evaluation parameters
# Number of time steps used for evaluation; also the final time of the
# uncontrolled reference simulation.
eval_time_steps = 10
eval_batch_size = 128
# All the operators we need
def generate_hermitian(params, dim):
    """
    Build a dim x dim Hermitian matrix from dim^2 real parameters.

    The parameters are arranged row by row into a real dim x dim matrix X.
    The symmetric part of X becomes the real part of the result and the
    antisymmetric part becomes the imaginary part, so H = Re + 1j * Im
    satisfies H† = H.

    params: array-like holding exactly dim^2 real numbers (a JAX/NumPy array
        or a plain Python sequence).
    dim: number of rows and columns of the result.

    Returns the complex (dim, dim) JAX array H.
    Raises AssertionError if the number of parameters is not dim^2.
    """
    assert len(params) == dim**2, "Number of real parameters must be dim^2 for an NxN Hermitian matrix."
    # jnp.asarray leaves JAX arrays unchanged and lets plain sequences be
    # reshaped as well.
    X = jnp.asarray(params).reshape(dim, dim)
    # Real part: the symmetric part of X
    Re = 0.5 * (X + X.T)
    # Imaginary part: the antisymmetric part of X
    Im = 0.5 * (X - X.T)
    # Symmetric real part plus antisymmetric imaginary part is Hermitian
    return Re + 1j * Im
def generate_unitary(params, dim):
    """
    Map dim^2 real parameters to a dim x dim matrix U = exp(-i H).

    H is the Hermitian matrix that generate_hermitian builds from the same
    parameters, so the matrix exponential is unitary.

    params: sequence/array of exactly dim^2 real parameters.
    dim: number of rows and columns of the result.

    Raises AssertionError if the number of parameters is not dim^2.
    """
    n_params = len(params)
    assert n_params == dim**2, "Number of real parameters must be dim^2 for an NxN unitary matrix."
    generator = generate_hermitian(params, dim)
    return jax.scipy.linalg.expm(-1j * generator)
def generate_povm(measurement_outcome, params, dim):
    """
    Generate one element of a 2-outcome POVM (M_0, M_1) on a Hilbert space of
    dimension dim.

    Both elements are Hermitian and share the eigenbasis S:
        M_1 = S D S†
        M_0 = S sqrt(I - D^2) S†
    so that M_0† M_0 + M_1† M_1 = I. S is the unitary built from the first
    dim^2 parameters. D is diagonal with entries sin^2(theta_k), where theta_k
    are the last dim parameters, rescaled into [1e-8, 1 - 1e-8].
    The parametrization is intended to cover all such POVMs up to unitary
    equivalence, i.e., M_i -> U M_i for some unitary U.

    measurement_outcome: 1 selects M_1; any other value selects M_0.
        NOTE(review): which label the caller uses for the other outcome
        (0 or -1) is not visible here; the code only tests for == 1.
    params: dim * (dim + 1) real parameters (dim^2 for S, then dim for D).
    dim: Hilbert space dimension.

    Raises AssertionError if the number of parameters is not dim * (dim + 1).
    """
    assert len(params) == dim * (dim + 1), "Number of real parameters must be N * (N + 1) for an NxN POVM element."
    S = generate_unitary(params[0:dim*dim], dim=dim)  # First dim^2 parameters: eigenbasis
    d_vec = jnp.astype(jnp.square(jnp.sin(params[dim*dim:dim*(dim+1)])), jnp.complex128)  # Last dim parameters: eigenvalues
    d_vec = 1e-8 + (1 - 2e-8) * d_vec  # Avoid exactly 0 or 1 eigenvalues for numerical stability
    # Pick the diagonal for the requested element first, so that only one
    # matrix product is formed instead of one per outcome.
    diag = jnp.where(measurement_outcome == 1, d_vec, jnp.sqrt(1 - jnp.square(d_vec)))
    # Broadcasting S * diag scales the columns of S, i.e. it equals
    # S @ diag(diag) without building the diagonal matrix.
    return jnp.multiply(S, diag) @ S.conj().T
def jump_ops(N_qubits, gamma_z, gamma_m):
    """
    Return the list of jump operators for N_qubits qubits.

    The list holds one sigmaz-based operator per qubit, each scaled by
    sqrt(gamma_z), followed by one sigmam-based operator per qubit, each
    scaled by sqrt(gamma_m).
    """
    def on_qubit(op, site):
        # Dimensions left of, at, and right of the addressed qubit.
        # NOTE(review): presumably embed places `op` on the factor at index 1
        # of this 3-factor split; confirm against feedback_grape's embed.
        return embed(op, 1, (2**site, 2, 2**(N_qubits - site - 1)))

    dephasing = [gamma_z**0.5 * on_qubit(sigmaz(), q) for q in range(N_qubits)]
    decay = [gamma_m**0.5 * on_qubit(sigmam(), q) for q in range(N_qubits)]
    return dephasing + decay
# All the gates we need
dim = 2**N_qubits  # Hilbert space dimension
N_unitary_params = dim**2  # parameter count expected by generate_unitary
N_msmt_params = dim * (dim + 1)  # parameter count expected by generate_povm

# NOTE(review): the same PRNG key seeds both initial-parameter draws below.
# JAX recommends splitting keys instead of reusing them; it is left unchanged
# here so that the documented expected output stays reproducible.
key = jax.random.PRNGKey(2)

# Dissipation acting on every qubit
decay_gate = Decay(c_ops=jump_ops(N_qubits, gamma_z, gamma_m))

# Parametrized unitary (no measurement)
U_gate = Gate(
    gate=lambda params: generate_unitary(params, dim=dim),
    initial_params=jax.random.uniform(key, (N_unitary_params,), minval=0.0, maxval=1.0),
    measurement_flag=False,
)

# Parametrized two-outcome measurement
msmt_gate = Gate(
    gate=lambda msmt, params: generate_povm(msmt, params, dim=dim),
    initial_params=jax.random.uniform(key, (N_msmt_params,), minval=0.0, maxval=2*jnp.pi),
    measurement_flag=True,
)

# One cycle: decay, then N_meas measurements, then the unitary
system_params = [decay_gate, *([msmt_gate] * N_meas), U_gate]
# Train RNN
result = optimize_pulse(
    # Start in the target state and reward staying there
    U_0=rho_target,
    C_target=rho_target,
    system_params=system_params,
    evo_type="density",
    goal="fidelity",
    # Controller
    mode="nn",
    rnn_hidden_size=rnn_hidden_size,
    # Training
    num_time_steps=num_time_steps,
    reward_weights=reward_weights,
    max_iter=N_training_iterations,
    convergence_threshold=convergence_threshold,
    learning_rate=learning_rate,
    batch_size=batch_size,
    # Evaluation
    eval_batch_size=eval_batch_size,
    eval_time_steps=eval_time_steps,
    progress=True,
)

# Mean over axis 1 of the recorded fidelities, one value per time step
mean_fidelity_per_step = jnp.mean(jnp.array(result.fidelity_each_timestep), axis=1)

print(f"Iterations: {result.iterations}")
# Entry 0 is skipped; entries 1..5 are averaged
print(f"Average fidelity over 5 timesteps: {jnp.mean(mean_fidelity_per_step[1:6]):.2f}")
print(f"Fidelity across {eval_time_steps} timesteps: \n{mean_fidelity_per_step}")
# Expected output:
# Iterations: 250
# Average fidelity over 5 timesteps: 0.73
# Fidelity across 10 timesteps:
# [1. 0.75307206 0.78865803 0.74860818 0.69591505 0.6625737 0.63545622 0.61818136 0.58743217 0.54756492 0.51893844]
Iteration 10, Loss: -0.414569, T=33s, eta=744s
Iteration 20, Loss: -0.417313, T=68s, eta=749s
Iteration 30, Loss: -0.443282, T=102s, eta=730s
Iteration 40, Loss: -0.386187, T=136s, eta=702s
Iteration 50, Loss: -0.016632, T=170s, eta=672s
Iteration 60, Loss: 0.313386, T=204s, eta=641s
Iteration 70, Loss: 0.723101, T=238s, eta=609s
Iteration 80, Loss: 1.055035, T=272s, eta=576s
Iteration 90, Loss: 1.007991, T=306s, eta=542s
Iteration 100, Loss: 1.147771, T=340s, eta=509s
Iteration 110, Loss: 1.695535, T=375s, eta=476s
Iteration 120, Loss: 1.546634, T=409s, eta=443s
Iteration 130, Loss: 1.500587, T=443s, eta=409s
Iteration 140, Loss: 1.857865, T=477s, eta=375s
Iteration 150, Loss: 1.472231, T=511s, eta=342s
Iteration 160, Loss: 2.462461, T=545s, eta=308s
Iteration 170, Loss: 2.093630, T=579s, eta=274s
Iteration 180, Loss: 1.545639, T=613s, eta=240s
Iteration 190, Loss: 2.547595, T=647s, eta=206s
Iteration 200, Loss: 2.031090, T=682s, eta=173s
Iteration 210, Loss: 2.501371, T=716s, eta=139s
Iteration 220, Loss: 2.742482, T=750s, eta=105s
Iteration 230, Loss: 2.828192, T=784s, eta=71s
Iteration 240, Loss: 2.745591, T=817s, eta=37s
Iterations: 250
Average fidelity over 5 timesteps: 0.73
Fidelity across 10 timesteps: [1. 0.75307206 0.78865803 0.74860818 0.69591505 0.6625737
0.63545622 0.61818136 0.58743217 0.54756492 0.51893844]
# Simulation of the uncontrolled dynamics
# Reference run: the same jump operators and initial state as above, but with
# no Hamiltonian and no feedback, solved with dynamiqs' master-equation solver.
import dynamiqs as dq
from feedback_grape.utils.fidelity import fidelity
dq_result = dq.mesolve(
    H=jnp.zeros((dim, dim)), # No Hamiltonian
    jump_ops=jump_ops(N_qubits, gamma_z, gamma_m),
    rho0=rho_target,
    # eval_time_steps + 1 save points from t = 0 to t = eval_time_steps
    tsave=jnp.linspace(0, eval_time_steps, eval_time_steps + 1),
)
rho_t = dq_result.states.to_jax()
# Fidelity of every saved state with respect to the target state
fidelities = jnp.array([fidelity(C_target=rho_target, U_final=rho_ti, evo_type="density") for rho_ti in rho_t])
print(f"Fidelity of uncontrolled dynamics across {eval_time_steps} timesteps: \n{fidelities}")
# Entry 0 (the initial state) is skipped; entries 1..5 are averaged
print(f"Average fidelity over first 5 timesteps: {jnp.mean(fidelities[1:6]):.2f}")
# Expected output: Fidelity of uncontrolled dynamics across 10 timesteps:
# [1. 0.54488266 0.38306071 0.31783434 0.28817987 0.27384715 0.26740115 0.26570994 0.26725884 0.27118467 0.2769193 ]
# Average fidelity over first 5 timesteps: 0.36
Fidelity of uncontrolled dynamics across 10 timesteps: [1. 0.54488266 0.38306071 0.31783434 0.28817987 0.27384715
0.26740115 0.26570994 0.26725884 0.27118467 0.2769193 ]
Avergage fidelity over first 5 timesteps: 0.36