import jax.profiler

# Expose a profiling endpoint (e.g. for TensorBoard) while the script runs.
jax.profiler.start_server(9999)

import numpy as onp
import jax.numpy as jnp
from functools import partial
from jax import random
from jax.nn.initializers import (xavier_normal, xavier_uniform, glorot_normal,
                                 glorot_uniform, uniform, normal, lecun_uniform,
                                 lecun_normal, kaiming_uniform, kaiming_normal)
# NOTE(review): the original imported `swish` twice and also imported
# `partial` from jax (shadowing functools.partial; `jax.partial` has been
# removed from JAX) — both duplicates are dropped here.
from jax.nn import (softplus, selu, gelu, glu, swish, relu, relu6, elu, sigmoid)
from jax import vmap, grad, pmap, value_and_grad, jit
from jax.experimental.ode import odeint

# Load precomputed inputs on the host, then move each onto the device once.
coupling_matrix_ = onp.load('./coupling_matrix.npy')
epi_array_ = onp.load('./epi_array.npy')
mobilitypopulation_array_scaled_ = onp.load('./mobilitypopulation_array_scaled.npy')
coupling_matrix = jnp.asarray(coupling_matrix_)
epi_array = jnp.asarray(epi_array_)
mobilitypopulation_array_scaled = jnp.asarray(mobilitypopulation_array_scaled_)


def inv_softplus(x):
    """Inverse of softplus: returns y such that softplus(y) == x (x > 0)."""
    return x + jnp.log(-jnp.expm1(-x))


key = random.PRNGKey(0)

# Fully-connected net: 7 inputs -> 14 -> 14 -> 7 -> 1 scalar output.
layers = [7, 14, 14, 7, 1]
activations = [swish, swish, swish, softplus]
weight_initializer = kaiming_uniform
bias_initializer = normal


def init_layers(nn_layers, nn_weight_initializer_, nn_bias_initializer_):
    """Initialize all layer weights/biases and pack them into one flat vector.

    Args:
        nn_layers: layer widths, e.g. [7, 14, 14, 7, 1].
        nn_weight_initializer_: factory returning a JAX weight initializer.
        nn_bias_initializer_: factory returning a JAX bias initializer.

    Returns:
        1-D jnp array: per layer, in_*out_ weights followed by out_ biases,
        all concatenated in layer order (the layout `nnet` unpacks).
    """
    # BUG FIX: the original ignored its three arguments and read the
    # module-level globals instead; use the parameters so the function is
    # actually reusable. The module-level call passes the same globals, so
    # the produced values are unchanged.
    init_w = nn_weight_initializer_()
    init_b = nn_bias_initializer_()
    params = []
    for in_, out_ in zip(nn_layers[:-1], nn_layers[1:]):
        # NOTE(review): the PRNG key is derived from the fan-in and reused
        # for both weights and biases — deterministic but statistically
        # questionable; kept as-is to reproduce the original parameters.
        key = random.PRNGKey(in_)
        weights = init_w(key, (in_, out_)).reshape((in_ * out_,))
        biases = init_b(key, (out_,))
        params.append(jnp.concatenate((weights, biases)))
    return jnp.concatenate(params)


def nnet(nn_layers, nn_activations, nn_params, x):
    """Evaluate the MLP defined by `nn_layers` on the input vector `x`.

    Args:
        nn_layers: layer widths.
        nn_activations: one activation per weight layer.
        nn_params: flat parameter vector as produced by `init_layers`.
        x: 1-D input of length nn_layers[0].

    Returns:
        Column-vector output of shape (nn_layers[-1], 1).
    """
    n_s = 0  # running offset into the flat parameter vector
    x_in = jnp.expand_dims(x, axis=1)  # column vector, shape (in_, 1)
    for in_, out_, act_ in zip(nn_layers[:-1], nn_layers[1:], nn_activations):
        n_w = in_ * out_
        n_b = out_
        n_t = n_w + n_b
        # NOTE(review): weights were stored with shape (in_, out_) but are
        # reshaped (not transposed) to (out_, in_) here; self-consistent for
        # learned parameters, so kept as-is.
        weights = nn_params[n_s:n_s + n_w].reshape((out_, in_))
        biases = jnp.expand_dims(nn_params[n_s + n_w:n_s + n_t], axis=1)
        x_in = act_(jnp.matmul(weights, x_in) + biases)
        n_s += n_t
    return x_in


# Single-sample network, and a batched version mapped over axis 0 of x
# (parameters broadcast via in_axes=None).
nn = jit(partial(nnet, layers, activations))
nn_batch = vmap(partial(nnet, layers, activations), (None, 0), 0)
# Trainable neural-network parameters, packed flat.
p_net = init_layers(layers, weight_initializer, bias_initializer)

# County-wise learnable scaling factors (softplus keeps them positive).
n_counties = coupling_matrix.shape[0]
init_b = bias_initializer()
p_scaling = softplus(200 * init_b(key, (n_counties,)))


def SEIRD_mobility_coupled(u, t, p_, mobility_, coupling_matrix_):
    """Right-hand side of the mobility-coupled SEIRD compartment model.

    Args:
        u: stacked per-county state, 16 compartments in this order:
           [s, e, id1..id7, d, ir1..ir5, r].
        t: scalar time; rounded to an integer day index into the
           time-dependent mobility/coupling data.
        p_: flat parameter vector: [κ, α, γ (inverse-softplus'd),
            n_counties coupling scalers, neural-net weights].
        mobility_: per-county mobility features with the day on the last
            axis — assumed shape compatible with nn_batch; TODO confirm.
        coupling_matrix_: inter-county coupling with the day on the last axis.

    Returns:
        du/dt stacked the same way as `u`.
    """
    s, e, id1, id2, id3, id4, id5, id6, id7, d, ir1, ir2, ir3, ir4, ir5, r = u
    κ, α, γ = softplus(p_[:3])
    # κ*α and γ*η are not independent: the probabilities of the e -> Id and
    # e -> Ir transitions have to add up to 1 (small eps guards γ == 0).
    η = -jnp.log(-jnp.expm1(-κ * α)) / (γ + 1.0e-8)
    # Day index for the time-dependent inputs.
    ind = jnp.rint(t.astype(jnp.float32))
    n_c = coupling_matrix_.shape[0]
    scaler_ = softplus(p_[3:3 + n_c])
    cm_ = jnp.expand_dims(scaler_, 1) * coupling_matrix_[..., ind.astype(jnp.int32)]
    # Per-county transmission rate β from the neural net on today's mobility.
    β = nn_batch(p_[3 + n_c:], mobility_[..., ind.astype(jnp.int32)])[:, 0, 0]
    # Infectious pool: first three Id stages plus all Ir stages.
    i = id1 + id2 + id3 + ir1 + ir2 + ir3 + ir4 + ir5
    # New infections: local term plus symmetric inter-county coupling.
    a = β * s * i + β * s * (jnp.matmul(i, cm_.T) + jnp.matmul(cm_, i))
    ds = -a
    de = a - κ * α * e - γ * η * e
    # Linear chain of 7 pre-death stages (Erlang-like delay), rate κ.
    d_id1 = κ * (α * e - id1)
    d_id2 = κ * (id1 - id2)
    d_id3 = κ * (id2 - id3)
    d_id4 = κ * (id3 - id4)
    d_id5 = κ * (id4 - id5)
    d_id6 = κ * (id5 - id6)
    d_id7 = κ * (id6 - id7)
    d_d = κ * id7
    # Linear chain of 5 pre-recovery stages, rate γ.
    d_ir1 = γ * (η * e - ir1)
    d_ir2 = γ * (ir1 - ir2)
    d_ir3 = γ * (ir2 - ir3)
    d_ir4 = γ * (ir3 - ir4)
    d_ir5 = γ * (ir4 - ir5)
    d_r = γ * ir5
    return jnp.stack([ds, de, d_id1, d_id2, d_id3, d_id4, d_id5, d_id6, d_id7,
                      d_d, d_ir1, d_ir2, d_ir3, d_ir4, d_ir5, d_r])


# Initial conditions (per county, population normalized to 1).
ifr = 0.007  # infection fatality ratio
n_counties = epi_array.shape[2]
n = jnp.tile(1.0, (n_counties,))
ic0 = epi_array[0, 0, :]  # initial active cases
d0 = epi_array[0, 1, :]   # initial cumulative deaths
r0 = d0 / ifr             # recovered implied by deaths and the IFR
s0 = n - ic0 - r0 - d0
e0 = ic0
# Spread the initial cases uniformly across the death/recovery chains.
id10 = id20 = id30 = id40 = id50 = id60 = id70 = ic0 * ifr / 7.0
ir10 = ir20 = ir30 = ir40 = ir50 = ic0 * (1.0 - ifr) / 5.0
u0 = jnp.array([s0, e0, id10, id20, id30, id40, id50, id60, id70, d0,
                ir10, ir20, ir30, ir40, ir50, r0])

# ODE parameters (stored pre-softplus via inv_softplus).
κ0_ = 0.97
α0_ = 0.00185
β0_ = 0.5  # NOTE(review): β0_, tb_, β1_ appear unused below — confirm
tb_ = 15
β1_ = 0.4
γ0_ = 0.24
p_ode = inv_softplus(jnp.array([κ0_, α0_, γ0_]))

# Initial model parameters: [ODE params, county scalers, net weights].
p_init = jnp.concatenate((p_ode, p_scaling, p_net))

# Integration time grid: one point per day.
t0 = jnp.linspace(0, float(epi_array.shape[0]), int(epi_array.shape[0]) + 1)
# LOSS function
def diff(sol_, data_):
    """Per-county, per-day squared error of new cases and new deaths.

    Args:
        sol_: ODE solution slice (time, compartment); column 0 is s,
              column 9 is cumulative deaths d.
        data_: observed (time, 2) array: daily new cases, daily new deaths.

    Returns:
        Per-day weighted squared-error vector.
    """
    # Daily incidence is the day-over-day difference of cumulative counts;
    # cumulative infections are 1 - s under unit population.
    l1 = jnp.square(jnp.ediff1d((1 - sol_[:, 0])) - data_[:, 0])
    # Deaths are far smaller in magnitude, hence the heavy reweighting.
    l2 = jnp.square(jnp.ediff1d(sol_[:, 9]) - data_[:, 1])
    return l1 + 20000 * l2


# Vectorize the loss over the trailing county axis of solution and data.
diff_v = vmap(diff, (2, 2))


def loss(data_, m_array_, coupling_matrix_, params_):
    """Total squared-error loss of the coupled SEIRD model against data_."""
    sol_ = odeint(SEIRD_mobility_coupled, u0, t0, params_, m_array_,
                  coupling_matrix_, rtol=1e-4, atol=1e-8)
    return jnp.sum(diff_v(sol_, data_))


# Close over the fixed data so only the parameter vector remains free.
loss_ = partial(loss, epi_array, mobilitypopulation_array_scaled, coupling_matrix)
grad_jit = jit(grad(loss_))
loss_jit = jit(loss_)

# NOTE(review): the %timeit lines below are IPython magics and are not valid
# plain Python (they would be a SyntaxError in a script); they are kept
# commented out as a record of measured timings.
# %timeit loss_jit(p_init).block_until_ready()
#   91.1 ms ± 4.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# %timeit grad_jit(p_init).block_until_ready()
#   24.6 s ± 1.13 s per loop (mean ± std. dev. of 7 runs, 1 loop each)