B. State purification with qubit-mediated measurement

from feedback_grape.fgrape import optimize_pulse
import jax.numpy as jnp

The cavity starts in a mixed state; the goal is to purify it.

We maximize the purity of the cavity state, \(\operatorname{tr}(\rho_{\text{cav}}^2)\).

In the following, we consider an adaptive measurement scheme, demonstrated in a series of experiments on Rydberg atoms interacting with microwave cavities. In this scheme, the cavity is coupled to an ancilla qubit, which can then be read out to update our knowledge of the quantum state of the cavity.

image

# Initial cavity state: a thermal state with mean photon number n_average,
# truncated to the lowest N_cavity Fock levels and renormalized.
n_average = 2
N_cavity = 30
# From n_average = 1 / (exp(beta) - 1)  =>  beta = ln(1 + 1 / n_average)
# (jnp.log is the natural logarithm).
beta = jnp.log(1 / n_average + 1)
# Unnormalized Boltzmann weights exp(-beta * k) for k = 0 .. N_cavity - 1.
diags = jnp.exp(-beta * jnp.arange(N_cavity))
# Renormalize so the truncated populations sum to one.
normalized_diags = diags / diags.sum()
# The thermal density matrix is diagonal in the Fock basis.
rho_cav = jnp.diag(normalized_diags)

Note that no target state (rho_final) is needed here: the reward we maximize is the purity \(\operatorname{tr}(\rho_{\text{cav}}^2)\), which depends only on the evolved cavity state itself.

This differs from fidelity-based rewards, which measure how close two states are.

Next, we construct the POVM.

from feedback_grape.utils.operators import cosm, sinm

povm_measure_operator (callable):

- It takes a measurement outcome and a list of parameters as input.
- The measurement outcome is either 1 or -1.
from feedback_grape.utils.operators import create, destroy, identity
import jax


def povm_measure_operator(measurement_outcome, params):
    """
    Kraus operator M_m of the qubit-mediated cavity measurement.

    Note: this is M_m itself, NOT the POVM element E_m = M_m^dagger @ M_m.

    Args:
        measurement_outcome: expected to be 1 or -1. The value 1 selects
            cos(gamma * n + delta / 2); any other value selects
            sin(gamma * n + delta / 2).
        params: sequence whose first two entries are (gamma, delta).

    Returns:
        An N_cavity x N_cavity operator, with n = create @ destroy the
        cavity number operator. Because cos^2 + sin^2 = 1, the two outcomes
        together form a complete measurement.
    """
    gamma, delta = params[0], params[1]
    n_op = create(N_cavity) @ destroy(N_cavity)
    phase = gamma * n_op + delta / 2 * identity(N_cavity)
    # Both branches are built; jnp.where picks one (keeps the function traceable).
    is_plus = measurement_outcome == 1
    return jnp.where(is_plus, cosm(phase), sinm(phase))
# Q: Why does the RNN output each value twice in a row (pairs of identical values)?
# A: It is evaluated once during forward propagation and again during back-propagation.
from feedback_grape.fgrape import Gate

# Measurement gate built from povm_measure_operator. Its two parameters
# (gamma, delta) start from values drawn uniformly in [0, pi) with a fixed seed.
measure = Gate(
    measurement_flag=True,
    gate=povm_measure_operator,
    initial_params=jax.random.uniform(
        key=jax.random.PRNGKey(42),
        minval=0.0,
        maxval=jnp.pi,
        shape=(2,),
        dtype=jnp.float64,
    ),
    # Optional per-parameter bounds, currently disabled:
    # param_constraints=[[0, jnp.pi], [-2 * jnp.pi, 2 * jnp.pi]],
)

# The system consists of this single measurement gate.
system_params = [measure]

# Compare two reward schedules: rewarding purity at every time step versus
# only at the final one. Each tuple is assumed to hold one weight per time
# step, so its length is used as num_time_steps and as the index of the
# last optimized step. Deriving both from the tuple keeps them in sync.
# After the loop, `result` holds the run for the last schedule; the cells
# below read from it.
for reward_weights in [(1, 1, 1, 1, 1), (0, 0, 0, 0, 1)]:
    result = optimize_pulse(
        U_0=rho_cav,
        C_target=None,  # purity is the goal, so no target state is given
        system_params=system_params,
        num_time_steps=len(reward_weights),
        reward_weights=reward_weights,
        mode="lookup",
        goal="purity",
        max_iter=1000,
        # Extremely small threshold: presumably disables early stopping so
        # all max_iter iterations run.
        convergence_threshold=1e-20,
        learning_rate=0.01,
        evo_type="density",
        batch_size=10,
        eval_time_steps=20,
    )

    # purity_each_timestep[0] is the initial state (its batch mean equals the
    # initial purity), so index len(reward_weights) is the last optimized
    # step. The final line prints the batch-averaged purity per evaluated step.
    print(
        f"reward weights: {reward_weights}\n"
        f" purity@t={len(reward_weights)}: "
        f"{result.purity_each_timestep[len(reward_weights)]}\n"
        f" purity_each_timestep: "
        f"{jnp.mean(jnp.array(result.purity_each_timestep), axis=1)}\n"
    )
reward weights: (1, 1, 1, 1, 1)
 purity@t=5: [0.97211249 0.9988749  0.99234986 0.92729246 0.97211249 0.99356438
 0.9988749  0.99234986 0.99234986 0.5209344 ]
 purity_each_timestep: [0.20000209 0.30247793 0.52649781 0.74846354 0.92343833 0.93608156
 0.93201872 0.98886305 0.98556798 0.98166295 0.98577724 0.98420409
 0.97898881 0.9876874  0.99039578 0.98847357 0.99185657 0.99063171
 0.98832074 0.98476494 0.97691869]

reward weights: (0, 0, 0, 0, 1)
 purity@t=5: [0.97011251 0.98526631 0.96122304 0.99754393 0.96122304 0.93864283
 0.98526631 0.96122304 0.97112369 0.99430915]
 purity_each_timestep: [0.20000209 0.31919734 0.42498204 0.63902323 0.75627459 0.97259339
 0.98623023 0.9879479  0.98603863 0.98877424 0.98480941 0.98982122
 0.99196245 0.99290787 0.99324596 0.99359925 0.99374561 0.99375388
 0.99375182 0.99397963 0.9944511 ]
# Fidelity of the last run. The recorded output is None, presumably because
# no target (C_target=None) was given for the purity goal.
print(result.final_fidelity)
None
# NOTE(review): 0.9163363647226792 was noted here previously but does not match
# this cell's recorded output; likely a leftover from an earlier run.
# The printed value matches the last entry of the batch-averaged
# purity_each_timestep of the last run.
print(result.final_purity)
0.9944510950702666
from feedback_grape.utils.purity import purity

# Observation: the highest purity (~0.995) can be reached when the initial
# params that initialize the lookup table are drawn from [0, pi] rather
# than [-pi, pi].
# Compare the purity of the initial thermal state with that of each final
# state in the batch (result.final_state comes from the last run above).
print("initial purity:", purity(rho=rho_cav))
for i, state in enumerate(result.final_state):
    print(f"Purity of state {i}:", purity(rho=state))
initial purity: 0.2000020860488993
Purity of state 0: 0.9976963389616152
Purity of state 1: 0.9991528026159782
Purity of state 2: 0.9807734068777262
Purity of state 3: 0.9999966853406007
Purity of state 4: 0.9857513139831835
Purity of state 5: 0.9999961126866741
Purity of state 6: 0.9991528026159782
Purity of state 7: 0.9857513139831828
Purity of state 8: 0.9972458444854705
Purity of state 9: 0.9989943291522569
result.returned_params
[[Array([[1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029],
         [1.09333296, 0.05908029]], dtype=float64)],
 [Array([[1.09323741e+00, 4.51206584e-01],
         [1.57079633e+00, 8.24752039e-17],
         [1.09323741e+00, 4.51206584e-01],
         [1.57079633e+00, 8.24752039e-17],
         [1.09323741e+00, 4.51206584e-01],
         [1.57079633e+00, 8.24752039e-17],
         [1.57079633e+00, 8.24752039e-17],
         [1.09323741e+00, 4.51206584e-01],
         [1.09323741e+00, 4.51206584e-01],
         [1.57079633e+00, 8.24752039e-17]], dtype=float64)],
 [Array([[ 1.57079633e+00,  2.97152495e-15],
         [ 1.02955772e+00, -8.74809684e-03],
         [ 1.57079633e+00,  2.97152495e-15],
         [ 1.40156193e+00,  7.14115220e-01],
         [ 1.57079633e+00,  2.97152495e-15],
         [ 1.40156193e+00,  7.14115220e-01],
         [ 1.02955772e+00, -8.74809684e-03],
         [ 1.57079633e+00,  2.97152495e-15],
         [ 1.57079632e+00,  1.78988435e-07],
         [ 1.40156193e+00,  7.14115220e-01]], dtype=float64)],
 [Array([[1.38466876, 0.19230081],
         [1.30807422, 0.00858833],
         [1.57739266, 1.19786011],
         [0.78540056, 1.57078098],
         [1.57739266, 1.19786011],
         [0.78540056, 1.57078098],
         [1.30807422, 0.00858833],
         [1.57739266, 1.19786011],
         [1.1046952 , 1.52435026],
         [0.78540056, 1.57078098]], dtype=float64)],
 [Array([[ 8.23742464e-01, -2.54610945e-01],
         [ 7.85398163e-01,  9.02871536e-15],
         [ 1.29391358e+00,  5.20452918e-01],
         [ 1.17809723e+00,  7.85398159e-01],
         [ 1.29391358e+00,  5.20452918e-01],
         [ 1.17809723e+00,  7.85398159e-01],
         [ 7.85398163e-01,  9.02871536e-15],
         [ 1.29391358e+00,  5.20452918e-01],
         [ 1.33637017e+00,  1.04223380e+00],
         [ 1.48290516e+00,  5.06418573e-01]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)],
 [Array([[1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552],
         [1.34060419, 0.04715552]], dtype=float64)]]