E: State stabilization with SNAP gates and displacement gates
Applying feedback GRAPE to the Jaynes-Cummings scenario allows us to discover strategies that extend the lifetime of a range of quantum states. However, for more complex quantum states such as kitten states, the infidelity becomes significant after just a few dissipative evolution steps in spite of the feedback [cf. Fig. 6(c)]. This raises the question of whether the limited quality of the stabilization should be attributed to a failure of our feedback-GRAPE learning algorithm to properly explore the control-parameter landscape or, rather, to the limited expressivity of the controls. To address this question, we test our method on the state-stabilization task using a more expressive control scheme.
# ruff: noqa
from feedback_grape.fgrape import optimize_pulse
from feedback_grape.utils.operators import cosm, sinm, identity
from feedback_grape.utils.states import coherent
import jax.numpy as jnp
import jax
# Enable 64-bit precision in JAX; the gate parameters below are drawn with dtype=jnp.float64,
# which JAX would otherwise silently downcast to float32.
jax.config.update("jax_enable_x64", True)
Initialize states
from feedback_grape.utils.fidelity import ket2dm

N_cav = 30  # number of cavity modes (dimension of the truncated cavity space)
N_snap = 15  # number of levels whose phase the SNAP gate sets (length of its phase list)
alpha = 2  # amplitude of each coherent component of the kitten state

# Target: the even superposition |alpha> + |-alpha>, normalized *before*
# building the density matrix so that rho_target has unit trace.
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)
psi_target /= jnp.linalg.norm(psi_target)
rho_target = ket2dm(psi_target)
# Parity Operator
from feedback_grape.utils.operators import create, destroy
def parity_operator(N_cav):
    """Return the photon-number parity operator P = exp(i * pi * n) as a complex matrix.

    Assuming the standard truncated ladder operators (n = a^dagger a is
    diag(0, 1, ..., N_cav - 1)), P is exactly diag((-1)^n). Building that
    diagonal directly avoids a dense matrix exponential (O(N_cav^3)) and the
    round-off it leaves on the +/-1 entries.

    Args:
        N_cav: dimension of the truncated cavity space.

    Returns:
        An (N_cav, N_cav) complex diagonal matrix with alternating +1 / -1.
    """
    signs = jnp.where(jnp.arange(N_cav) % 2 == 0, 1.0, -1.0)
    # Add 0j so the result stays complex, like the expm-based construction.
    return jnp.diag(signs + 0j)
# Confirm that the kitten2 state has even parity: <P> = Tr(P rho) must be +1.
parity_op = parity_operator(N_cav)
# Use Tr(P rho) directly. Tr(P rho rho) would only coincide with <P> for pure
# states. Take the real part so the boolean check and the printed value are
# the same real number.
parity_expectation = jnp.real(jnp.trace(parity_op @ rho_target))
parity_check = jnp.isclose(parity_expectation, 1.0)
print("Parity check for the kitten2 state:", parity_check)
print("parity_check trace :", parity_expectation)
Parity check for the kitten2 state: True
parity_check trace : 1.0000000000000013
Initialize the parameterized gates
def displacement_gate(alphas):
    """Build the cavity displacement operator D(alpha).

    ``alphas`` is the pair ``(Re alpha, Im alpha)``. The result is
    expm(alpha * a^dagger - conj(alpha) * a), where a^dagger and a come from
    ``create(N_cav)`` and ``destroy(N_cav)``.
    """
    re_part, im_part = alphas
    amplitude = re_part + 1j * im_part
    generator = amplitude * create(N_cav) - amplitude.conj() * destroy(N_cav)
    return jax.scipy.linalg.expm(generator)
def displacement_gate_dag(alphas):
    """Build the adjoint D(alpha)^dagger of the cavity displacement operator.

    ``alphas`` is the pair ``(Re alpha, Im alpha)``. The result is the
    conjugate transpose of expm(alpha * a^dagger - conj(alpha) * a), built from
    ``create(N_cav)`` and ``destroy(N_cav)``. Mathematically, when create is the
    adjoint of destroy, this equals D(alpha)^-1 = D(-alpha).
    """
    alpha_re, alpha_im = alphas
    alpha = alpha_re + 1j * alpha_im
    generator = alpha * create(N_cav) - alpha.conj() * destroy(N_cav)
    # Adjoint = conjugate transpose of the displacement operator.
    return jax.scipy.linalg.expm(generator).conj().T
def snap_gate(phase_list):
    """SNAP gate: a diagonal unitary on the N_cav-dimensional cavity space.

    The first ``len(phase_list)`` diagonal entries are ``exp(1j * theta_k)``.
    The remaining ``N_cav - len(phase_list)`` entries are 1, so those levels
    are left unchanged.
    """
    phases = jnp.exp(1j * jnp.array(phase_list))
    untouched = jnp.ones(shape=(N_cav - len(phase_list)))
    return jnp.diag(jnp.concatenate((phases, untouched)))
povm_measure_operator (callable):
- It takes a measurement outcome and a list of parameters as input.
- The measurement outcome is either 1 or -1.
from feedback_grape.utils.operators import create, destroy
def povm_measure_operator(measurement_outcome, params):
    """Measurement operator M_m for the cavity, given outcome m in {1, -1}.

    With ``params = (gamma, delta)`` and X = gamma * a^dagger a + (delta / 2) * I,
    this returns cos(X) when m == 1 and sin(X) otherwise (matrix cosine/sine).
    Note that this is M_m itself, NOT the POVM element E_m = M_m^dagger M_m.
    """
    gamma, delta = params
    phase_op = gamma * create(N_cav) @ destroy(N_cav) + delta / 2 * identity(N_cav)
    return jnp.where(measurement_outcome == 1, cosm(phase_op), sinm(phase_op))
Initialize the RNN of choice
import flax.linen as nn
# You can do whatever you want inside, as long as you maintain the hidden_size and output_size shapes.
class RNN(nn.Module):
    """GRU controller: maps (measurement, hidden state) to (gate parameters, next hidden state)."""

    hidden_size: int  # number of features in the hidden (recurrent) state
    output_size: int  # number of output features (the number of controlled parameters); just provide these attributes to the class

    @nn.compact
    def __call__(self, measurement, hidden_state):
        # Promote a single unbatched measurement vector to a batch of one.
        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)
        ###############
        ### Free to change whatever you want below as long as hidden layers have size self.hidden_size
        ### and output layer has size self.output_size
        ###############
        gru_cell = nn.GRUCell(
            features=self.hidden_size,
            gate_fn=nn.sigmoid,
            activation_fn=nn.tanh,
        )
        # The GRU output is the recurrent carry for the next time step. Keep it
        # separate from the MLP head below. Previously the carry was overwritten
        # by the head's dropout/ReLU activations, so the GRU's gated memory never
        # reached the next step intact.
        new_hidden_state, _ = gru_cell(hidden_state, measurement)

        # MLP head: reads the hidden state (which carries information from
        # previous time steps) and is optimized to output the best gate params.
        # nn.Dropout draws its own 'dropout' RNG, so no explicit make_rng call
        # is needed. It is always active (deterministic=False), so the caller
        # must provide a 'dropout' RNG.
        features = nn.Dropout(rate=0.1, deterministic=False)(new_hidden_state)
        features = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(features)
        features = nn.relu(features)
        features = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(features)
        features = nn.relu(features)
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(0.1),
        )(features)
        # The final ReLU keeps every emitted parameter non-negative.
        output = nn.relu(output)
        ###############
        ### Do not change the return statement
        ###############
        # output[0] is the first (for a single measurement, the only) batch row.
        return output[0], new_hidden_state
In this notebook, we decrease the convergence threshold and optimize with num_time_steps = 2
# Open question from development: if tsave = jnp.linspace(0, 1, 1) = [0.0],
# is the decay applied at all? Presumably not, because the first time step
# holds the original, undecayed state. TODO confirm.
from feedback_grape.fgrape import Decay, Gate

key = jax.random.PRNGKey(42)
# Give every gate its own subkey. Drawing every initial_params from the same
# key made the three shape-(2,) draws (measure, displacement,
# displacement_dag) bit-for-bit identical, so those gates started correlated.
key_measure, key_disp, key_snap, key_disp_dag = jax.random.split(key, 4)


def _uniform_init(rng_key, size):
    """Return `size` float64 values drawn uniformly from [-pi/2, pi/2)."""
    return jax.random.uniform(
        rng_key,
        shape=(size,),
        minval=-jnp.pi / 2,
        maxval=jnp.pi / 2,
        dtype=jnp.float64,
    )


# NOTE(docs): the documentation should clarify that initial_params here are
# the params up to the point where the measurement occurs, whereas in other
# modes they are the initial params of the entire system for all time steps.
# This is already handled in example a; the mechanism could be explained in
# the docs.
measure = Gate(
    gate=povm_measure_operator,
    initial_params=_uniform_init(key_measure, 2),  # gamma and delta
    measurement_flag=True,
)
displacement = Gate(
    gate=displacement_gate,
    initial_params=_uniform_init(key_disp, 2),  # Re(alpha), Im(alpha)
    measurement_flag=False,
)
snap = Gate(
    gate=snap_gate,
    initial_params=_uniform_init(key_snap, N_snap),  # one phase per SNAP level
    measurement_flag=False,
)
displacement_dag = Gate(
    gate=displacement_gate_dag,
    initial_params=_uniform_init(key_disp_dag, 2),  # Re(alpha), Im(alpha)
    measurement_flag=False,
)
# Photon loss via the single collapse operator sqrt(0.005) * a
# (presumably a loss rate of 0.005 per unit time; confirm units).
decay = Decay(c_ops=[jnp.sqrt(0.005) * destroy(N_cav)])
# Per-step sequence: decay, measurement, decay, displacement, SNAP,
# inverse displacement (presumably applied in list order; confirm).
system_params = [decay, measure, decay, displacement, snap, displacement_dag]
# Train once per reward weighting over the 2 training time steps. Only the
# last run's `result` survives the loop and is inspected in later cells.
for reward_weights in [[1.0, 1.0], [0.0, 1.0]]:
    result = optimize_pulse(
        U_0=rho_target,
        C_target=rho_target,
        system_params=system_params,
        num_time_steps=2,
        reward_weights=reward_weights,
        mode="nn",
        goal="fidelity",
        max_iter=1000,
        convergence_threshold=1e-6,
        learning_rate=0.01,
        evo_type="density",
        batch_size=16,
        rnn=RNN,
        rnn_hidden_size=30,
        eval_batch_size=16,
        eval_time_steps=5,
    )
    # fidelity_each_timestep holds one array per evaluated time step (one
    # entry per batch trajectory). Averaging over axis 1 summarizes each time
    # step by its batch-mean fidelity, so label it as such.
    mean_fidelity = jnp.mean(jnp.array(result.fidelity_each_timestep), axis=1)
    print(
        f"reward weights: {reward_weights}\n"
        f" fidelity@t=2: {result.fidelity_each_timestep[2]}\n"
        f" mean fidelity per timestep: {mean_fidelity}\n"
    )
reward weights: [1.0, 1.0]
fidelity@t=2: [0.92591808 0.92591749 0.92591797 0.92591771 0.92591646 0.92591733
0.92591811 0.9262644 0.9262644 0.92591815 0.92591816 0.92626412
0.92591816 0.92591692 0.9259181 0.92591816]
fidelity_each_timestep: [1.00000004 0.96135187 0.92598273 0.8934548 0.86350734 0.83590455]
reward weights: [0.0, 1.0]
fidelity@t=2: [0.92582739 0.92583587 0.9258323 0.92579993 0.92583189 0.92583109
0.92583476 0.92457332 0.92457332 0.92583351 0.9258302 0.925819
0.92583449 0.92583398 0.92583212 0.9258347 ]
fidelity_each_timestep: [1.00000004 0.9610333 0.92567237 0.89315248 0.86321307 0.83561829]
# Per-timestep fidelities of the most recent run (the last reward_weights in the loop, [0.0, 1.0]):
# a list with one array per evaluated time step, presumably one entry per evaluation trajectory.
result.fidelity_each_timestep
[Array([1.00000004, 1.00000004, 1.00000004, 1.00000004, 1.00000004,
1.00000004, 1.00000004, 1.00000004, 1.00000004, 1.00000004,
1.00000004, 1.00000004, 1.00000004, 1.00000004, 1.00000004,
1.00000004], dtype=float64),
Array([0.96133345, 0.96134252, 0.96133877, 0.96130492, 0.96133823,
0.96133749, 0.96134138, 0.95891675, 0.95891675, 0.96133999,
0.96133644, 0.96132489, 0.96134099, 0.96134045, 0.96133853,
0.96134127], dtype=float64),
Array([0.92582739, 0.92583587, 0.9258323 , 0.92579993, 0.92583189,
0.92583109, 0.92583476, 0.92457332, 0.92457332, 0.92583351,
0.9258302 , 0.925819 , 0.92583449, 0.92583398, 0.92583212,
0.9258347 ], dtype=float64),
Array([0.89319695, 0.8932051 , 0.89320165, 0.89317064, 0.8932011 ,
0.89319945, 0.89320405, 0.89282879, 0.89282879, 0.8932026 ,
0.89319986, 0.89318807, 0.8932038 , 0.89320339, 0.89320146,
0.89320399], dtype=float64),
Array([0.86317697, 0.86318483, 0.86318148, 0.86315166, 0.86318077,
0.86317838, 0.86318383, 0.86345479, 0.86345479, 0.86318216,
0.86317992, 0.86316766, 0.86318359, 0.86318326, 0.86318123,
0.86318373], dtype=float64),
Array([0.83552755, 0.83553515, 0.83553192, 0.8355031 , 0.83553106,
0.83552793, 0.83553416, 0.83624387, 0.83624387, 0.83553233,
0.83553056, 0.83551776, 0.83553401, 0.83553371, 0.83553155,
0.83553408], dtype=float64)]
# Final states of the most recent run: a leading batch axis of N_cav x N_cav density matrices (evo_type="density").
result.final_state.shape
(16, 30, 30)
from feedback_grape.utils.fidelity import ket2dm

# Re-create the kitten2 target (same definition as at the top of the notebook).
N_cav = 30  # number of cavity modes (dimension of the truncated cavity space)
N_snap = 15  # number of levels whose phase the SNAP gate sets
alpha = 2  # amplitude of each coherent component

# Normalize the even superposition |alpha> + |-alpha> before forming rho_target.
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)
psi_target /= jnp.linalg.norm(psi_target)
rho_target = ket2dm(psi_target)
from feedback_grape.utils.fidelity import fidelity

# Baseline: the initial state (U_0 = rho_target) compared with the target.
baseline = fidelity(C_target=rho_target, U_final=rho_target, evo_type="density")
print("initial fidelity:", baseline)

# Fidelity of every final state in the batch of the most recent run.
for i, state in enumerate(result.final_state):
    state_fidelity = fidelity(C_target=rho_target, U_final=state, evo_type="density")
    print(f"fidelity of state {i}:", state_fidelity)
initial fidelity: 1.0000000353712848
fidelity of state 0: 0.8355275465204998
fidelity of state 1: 0.8355351498978077
fidelity of state 2: 0.8355319238279622
fidelity of state 3: 0.8355030964952107
fidelity of state 4: 0.8355310579595714
fidelity of state 5: 0.8355279280397181
fidelity of state 6: 0.835534161963775
fidelity of state 7: 0.8362438681035523
fidelity of state 8: 0.8362438681035523
fidelity of state 9: 0.8355323320419621
fidelity of state 10: 0.8355305634776103
fidelity of state 11: 0.8355177584837092
fidelity of state 12: 0.8355340086298296
fidelity of state 13: 0.8355337137576336
fidelity of state 14: 0.8355315498567129
fidelity of state 15: 0.8355340755941594