E: State stabilization with SNAP gates and displacement gates#

The use of feedback GRAPE applied to the Jaynes-Cummings scenario allows us to discover strategies extending the lifetime of a range of quantum states. However, for more complex quantum states such as kitten states, the infidelity becomes significant after just a few dissipative evolution steps in spite of the feedback [cf. Fig. 6(c)]. This raises the question of whether the limited quality of the stabilization is to be attributed to a failure of our feedback-GRAPE learning algorithm to properly explore the control-parameter landscape or, rather, to the limited expressivity of the controls. With the goal of addressing this question, we test our method on the state-stabilization task using a more expressive control scheme.

# ruff: noqa
from feedback_grape.fgrape import optimize_pulse
from feedback_grape.utils.operators import cosm, sinm, identity
from feedback_grape.utils.states import coherent
import jax.numpy as jnp
import jax

jax.config.update("jax_enable_x64", True)

Initialize states#

from feedback_grape.utils.fidelity import ket2dm

# Size argument passed to coherent/create/destroy/identity; it is the matrix
# dimension of the cavity operators and states (a truncation size, not a
# number of modes).
N_cav = 30
# Number of phases handed to the SNAP gate; snap_gate pads the remaining
# N_cav - N_snap diagonal entries with ones.
N_snap = 15

# Amplitude used for the two coherent states below.
alpha = 2
# Unnormalized sum of coherent(N_cav, alpha) and coherent(N_cav, -alpha).
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

# Target passed to the optimizer and fidelity checks; presumably the density
# matrix |psi><psi| of the normalized state -- confirm against ket2dm.
rho_target = ket2dm(psi_target)
# Parity Operator
from feedback_grape.utils.operators import create, destroy


def parity_operator(N_cav):
    """Return the parity operator exp(i * pi * a_dag a) for the cavity.

    Args:
        N_cav: size argument forwarded to ``create`` and ``destroy``.

    Returns:
        The matrix exponential of ``1j * pi`` times the product
        ``create(N_cav) @ destroy(N_cav)``.
    """
    number_op = create(N_cav) @ destroy(N_cav)
    return jax.scipy.linalg.expm(1j * jnp.pi * number_op)
# Confirm that the target state (labelled "kitten2" in the printed output)
# has even parity, i.e. that the parity expectation value Tr(P rho) is +1.
# The check uses Tr(P rho) directly, the same quantity that is printed below,
# so it does not rely on rho being pure (rho @ rho == rho).
parity_op = parity_operator(N_cav)
parity_check = jnp.isclose(jnp.trace(parity_op @ rho_target), 1.0)
print("Parity check for the kitten2 state:", parity_check)
print("parity_check trace :", jnp.real(jnp.trace(parity_op @ rho_target)))
Parity check for the kitten2 state: True
parity_check trace : 1.0000000000000013

Initialize the parameterized Gates#

def displacement_gate(alphas):
    """Build the displacement gate exp(alpha * a_dag - conj(alpha) * a).

    Args:
        alphas: pair ``(alpha_re, alpha_im)`` holding the real and imaginary
            parts of the complex displacement amplitude ``alpha``.

    Returns:
        The matrix exponential of the displacement generator, built from
        ``create(N_cav)`` and ``destroy(N_cav)`` with the module-level
        ``N_cav``.
    """
    re_part, im_part = alphas
    amplitude = re_part + 1j * im_part
    generator = amplitude * create(N_cav) - amplitude.conj() * destroy(N_cav)
    return jax.scipy.linalg.expm(generator)


def displacement_gate_dag(alphas):
    """Return the Hermitian adjoint of the displacement gate.

    Args:
        alphas: pair ``(alpha_re, alpha_im)`` holding the real and imaginary
            parts of the complex displacement amplitude, exactly as accepted
            by ``displacement_gate``.

    Returns:
        The conjugate transpose of ``displacement_gate(alphas)``.
    """
    # Delegate to displacement_gate so both gates share one definition of
    # the generator; only the adjoint is taken here.
    return displacement_gate(alphas).conj().T
def snap_gate(phase_list):
    """Build a diagonal SNAP gate from a list of phases.

    The first ``len(phase_list)`` diagonal entries are ``exp(1j * phase)``;
    the remaining ``N_cav - len(phase_list)`` entries are 1, so those levels
    are left unchanged.

    Args:
        phase_list: sequence of phases, one per leading diagonal entry.

    Returns:
        A diagonal matrix with ``N_cav`` diagonal entries.
    """
    untouched = jnp.ones(shape=(N_cav - len(phase_list)))
    phase_factors = jnp.exp(1j * jnp.array(phase_list))
    return jnp.diag(jnp.concatenate((phase_factors, untouched)))

povm_measure_operator (callable):
#

- It should take a measurement outcome and list of params as input
- The measurement outcome options are either 1 or -1
from feedback_grape.utils.operators import create, destroy


def povm_measure_operator(measurement_outcome, params):
    """
    Measurement operator Mm for a measurement of the cavity state.

    Returns Mm itself ( NOT the POVM element Em = Mm_dag @ Mm ).

    Args:
        measurement_outcome: outcome m; the value 1 selects the cosine
            branch, any other value selects the sine branch.
        params: pair ``(gamma, delta)`` defining the operator
            ``gamma * a_dag a + delta / 2 * identity`` that is fed to
            ``cosm`` / ``sinm``.
    """
    gamma, delta = params
    number_term = gamma * create(N_cav) @ destroy(N_cav)
    offset_term = delta / 2 * identity(N_cav)
    phase_op = number_term + offset_term
    # Both branches are evaluated and jnp.where picks one based on the outcome.
    is_plus_one = measurement_outcome == 1
    return jnp.where(is_plus_one, cosm(phase_op), sinm(phase_op))

Initialize RNN of choice#

import flax.linen as nn


# You can do whatever you want inside so long as you maintain the hidden_size and output_size shapes
class RNN(nn.Module):
    """GRU-based network mapping a measurement and a hidden state to outputs.

    One call runs a GRU cell, dropout, two Dense+ReLU layers of width
    ``hidden_size`` and a final Dense+ReLU layer of width ``output_size``.
    """

    hidden_size: int  # number of features in the hidden state
    output_size: int  # number of features in the output (inferred from the number of parameters) just provide those attributes to the class

    @nn.compact
    def __call__(self, measurement, hidden_state):
        """Run one step of the network.

        Args:
            measurement: input for this step; a 1-D input is reshaped to a
                single row of shape (1, -1).
            hidden_state: carry passed to the GRU cell.

        Returns:
            A tuple ``(output[0], new_hidden_state)``: the first row of the
            final layer's output, and the activations of the second
            Dense+ReLU layer (not the raw GRU carry).
        """

        # Promote a 1-D measurement to a batch containing a single row.
        if measurement.ndim == 1:
            measurement = measurement.reshape(1, -1)

        ###############
        ### Free to change whatever you want below as long as hidden layers have size self.hidden_size
        ### and output layer has size self.output_size
        ###############

        gru_cell = nn.GRUCell(
            features=self.hidden_size,
            gate_fn=nn.sigmoid,
            activation_fn=nn.tanh,
        )
        # The returned key is discarded.
        # NOTE(review): presumably this only requires/advances the 'dropout'
        # RNG stream before Dropout uses it -- confirm it is needed.
        self.make_rng('dropout')

        # Only the first element returned by the GRU cell is used.
        new_hidden_state, _ = gru_cell(hidden_state, measurement)
        # Dropout is created with deterministic=False, i.e. it is not
        # switched off by this module.
        new_hidden_state = nn.Dropout(rate=0.1, deterministic=False)(
            new_hidden_state
        )
        # Two Dense + ReLU layers of width hidden_size process the GRU output.
        # The variable new_hidden_state is reassigned here, so the value
        # returned as the hidden state is the output of these layers.
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        new_hidden_state = nn.Dense(
            features=self.hidden_size,
            kernel_init=nn.initializers.glorot_uniform(),
        )(new_hidden_state)
        new_hidden_state = nn.relu(new_hidden_state)
        # Output layer producing output_size values (the gate parameters);
        # the bias starts at 0.1 and the final ReLU makes every output
        # non-negative.
        output = nn.Dense(
            features=self.output_size,
            kernel_init=nn.initializers.glorot_uniform(),
            bias_init=nn.initializers.constant(0.1),
        )(new_hidden_state)
        output = nn.relu(output)

        ###############
        ### Do not change the return statement
        ###############

        return output[0], new_hidden_state

In this notebook, we decrease the convergence threshold and evaluate for num_time_steps = 2#

# Note if tsave = jnp.linspace(0, 1, 1) = [0.0] then the decay is not applied ?
# because the first time step has the original non decayed state
from feedback_grape.fgrape import Decay, Gate

key = jax.random.PRNGKey(42)
# Derive one independent subkey per gate. Passing the same key to several
# jax.random.uniform calls returns the same numbers for equal shapes, which
# would give the measurement, displacement and displacement_dag gates
# identical initial parameters.
measure_key, displacement_key, snap_key, displacement_dag_key = jax.random.split(
    key, 4
)

# Answer: In documentation, clarify that the initial_params are the params up to the
# point where measurement occurs, compared with other modes where the initial_params
# are the initial params for the entire system for all time steps. --> this is already fixed in
# example a and an explanation of the mechanism may be provided in the docs

# Measurement gate with two parameters (gamma, delta).
measure = Gate(
    gate=povm_measure_operator,
    initial_params=jax.random.uniform(
        measure_key,
        shape=(2,),  # 2 for gamma and delta
        minval=-jnp.pi / 2,
        maxval=jnp.pi / 2,
        dtype=jnp.float64,
    ),
    measurement_flag=True,
)

# Displacement gate with two parameters (real and imaginary part of alpha).
displacement = Gate(
    gate=displacement_gate,
    initial_params=jax.random.uniform(
        displacement_key,
        shape=(2,),
        minval=-jnp.pi / 2,
        maxval=jnp.pi / 2,
        dtype=jnp.float64,
    ),
    measurement_flag=False,
)

# SNAP gate with N_snap phases.
snap = Gate(
    gate=snap_gate,
    initial_params=jax.random.uniform(
        snap_key,
        shape=(N_snap,),
        minval=-jnp.pi / 2,
        maxval=jnp.pi / 2,
        dtype=jnp.float64,
    ),
    measurement_flag=False,
)

# Adjoint displacement gate with its own two parameters.
displacement_dag = Gate(
    gate=displacement_gate_dag,
    initial_params=jax.random.uniform(
        displacement_dag_key,
        shape=(2,),
        minval=-jnp.pi / 2,
        maxval=jnp.pi / 2,
        dtype=jnp.float64,
    ),
    measurement_flag=False,
)

# Decay with collapse operator sqrt(0.005) * destroy(N_cav).
decay = Decay(c_ops=[jnp.sqrt(0.005) * destroy(N_cav)])

# Sequence applied in every time step, in this order.
system_params = [decay, measure, decay, displacement, snap, displacement_dag]


# Run the optimization once per reward-weight setting and report the
# resulting fidelities. `result` keeps the outcome of the last setting.
for reward_weights in ([1.0, 1.0], [0.0, 1.0]):
    result = optimize_pulse(
        # Problem definition: start from the target state and try to keep it.
        U_0=rho_target,
        C_target=rho_target,
        system_params=system_params,
        evo_type="density",
        goal="fidelity",
        # Training setup.
        mode="nn",
        rnn=RNN,
        rnn_hidden_size=30,
        num_time_steps=2,
        reward_weights=reward_weights,
        batch_size=16,
        learning_rate=0.01,
        max_iter=1000,
        convergence_threshold=1e-6,
        # Evaluation setup.
        eval_batch_size=16,
        eval_time_steps=5,
    )

    # Prints the per-trajectory fidelities at index 2 and the mean over
    # axis 1 of the per-time-step fidelities.
    print(
        f"reward weights: {reward_weights}\n"
        f" fidelity@t=2: {result.fidelity_each_timestep[2]}\n"
        f" fidelity_each_timestep: {jnp.mean(jnp.array(result.fidelity_each_timestep), axis=1)}\n"
    )
reward weights: [1.0, 1.0]
 fidelity@t=2: [0.92591808 0.92591749 0.92591797 0.92591771 0.92591646 0.92591733
 0.92591811 0.9262644  0.9262644  0.92591815 0.92591816 0.92626412
 0.92591816 0.92591692 0.9259181  0.92591816]
 fidelity_each_timestep: [1.00000004 0.96135187 0.92598273 0.8934548  0.86350734 0.83590455]

reward weights: [0.0, 1.0]
 fidelity@t=2: [0.92582739 0.92583587 0.9258323  0.92579993 0.92583189 0.92583109
 0.92583476 0.92457332 0.92457332 0.92583351 0.9258302  0.925819
 0.92583449 0.92583398 0.92583212 0.9258347 ]
 fidelity_each_timestep: [1.00000004 0.9610333  0.92567237 0.89315248 0.86321307 0.83561829]
result.fidelity_each_timestep
[Array([1.00000004, 1.00000004, 1.00000004, 1.00000004, 1.00000004,
        1.00000004, 1.00000004, 1.00000004, 1.00000004, 1.00000004,
        1.00000004, 1.00000004, 1.00000004, 1.00000004, 1.00000004,
        1.00000004], dtype=float64),
 Array([0.96133345, 0.96134252, 0.96133877, 0.96130492, 0.96133823,
        0.96133749, 0.96134138, 0.95891675, 0.95891675, 0.96133999,
        0.96133644, 0.96132489, 0.96134099, 0.96134045, 0.96133853,
        0.96134127], dtype=float64),
 Array([0.92582739, 0.92583587, 0.9258323 , 0.92579993, 0.92583189,
        0.92583109, 0.92583476, 0.92457332, 0.92457332, 0.92583351,
        0.9258302 , 0.925819  , 0.92583449, 0.92583398, 0.92583212,
        0.9258347 ], dtype=float64),
 Array([0.89319695, 0.8932051 , 0.89320165, 0.89317064, 0.8932011 ,
        0.89319945, 0.89320405, 0.89282879, 0.89282879, 0.8932026 ,
        0.89319986, 0.89318807, 0.8932038 , 0.89320339, 0.89320146,
        0.89320399], dtype=float64),
 Array([0.86317697, 0.86318483, 0.86318148, 0.86315166, 0.86318077,
        0.86317838, 0.86318383, 0.86345479, 0.86345479, 0.86318216,
        0.86317992, 0.86316766, 0.86318359, 0.86318326, 0.86318123,
        0.86318373], dtype=float64),
 Array([0.83552755, 0.83553515, 0.83553192, 0.8355031 , 0.83553106,
        0.83552793, 0.83553416, 0.83624387, 0.83624387, 0.83553233,
        0.83553056, 0.83551776, 0.83553401, 0.83553371, 0.83553155,
        0.83553408], dtype=float64)]
result.final_state.shape
(16, 30, 30)
from feedback_grape.utils.fidelity import ket2dm

# This cell repeats the definitions from the "Initialize states" cell with
# the same values, so rho_target is available again for the fidelity check.

# Size argument passed to coherent/create/destroy/identity; it is the matrix
# dimension of the cavity operators and states (a truncation size, not a
# number of modes).
N_cav = 30
# Number of phases handed to the SNAP gate.
N_snap = 15

# Amplitude used for the two coherent states below.
alpha = 2
# Unnormalized sum of coherent(N_cav, alpha) and coherent(N_cav, -alpha).
psi_target = coherent(N_cav, alpha) + coherent(N_cav, -alpha)

# Normalize psi_target before constructing rho_target
psi_target = psi_target / jnp.linalg.norm(psi_target)

# Target used in the fidelity calls; presumably the density matrix
# |psi><psi| of the normalized state -- confirm against ket2dm.
rho_target = ket2dm(psi_target)
from feedback_grape.utils.fidelity import fidelity

# Reference value: fidelity of the target state with itself.
initial_fidelity = fidelity(
    C_target=rho_target, U_final=rho_target, evo_type="density"
)
print("initial fidelity:", initial_fidelity)

# Fidelity of each entry of result.final_state with respect to the target.
for i, final_rho in enumerate(result.final_state):
    state_fidelity = fidelity(
        C_target=rho_target, U_final=final_rho, evo_type="density"
    )
    print(f"fidelity of state {i}:", state_fidelity)
initial fidelity: 1.0000000353712848
fidelity of state 0: 0.8355275465204998
fidelity of state 1: 0.8355351498978077
fidelity of state 2: 0.8355319238279622
fidelity of state 3: 0.8355030964952107
fidelity of state 4: 0.8355310579595714
fidelity of state 5: 0.8355279280397181
fidelity of state 6: 0.835534161963775
fidelity of state 7: 0.8362438681035523
fidelity of state 8: 0.8362438681035523
fidelity of state 9: 0.8355323320419621
fidelity of state 10: 0.8355305634776103
fidelity of state 11: 0.8355177584837092
fidelity of state 12: 0.8355340086298296
fidelity of state 13: 0.8355337137576336
fidelity of state 14: 0.8355315498567129
fidelity of state 15: 0.8355340755941594