I had a go at a few SEIR models; this is a rough diary of the process.

Classical SEIR Model

The model is described on the Compartmental Models Wikipedia page. One needs to be careful about when the SEIR process starts, though, as none of the parameters explicitly control that aspect; I found it easier to truncate a bit of the data history before minimizing an error between the predicted and actual cumulative “removed” numbers. Since the paths are deterministic, this error is cheap to evaluate, so rather than using Stan to optimize likelihoods, I found it easier to use black-box optimizers like PSO or GPO.
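
Concretely, here is a minimal sketch of that fitting step. It assumes the simulate function from the app code below is in scope; the data file path and the parameter bounds are placeholders:

import numpy as np
from pyswarm import pso

# Hypothetical observations: cumulative "removed" counts per day,
# e.g. cumulative deaths divided by an assumed IFR.
actual_removed = np.loadtxt('removed.csv')  # placeholder path

def sse(params):
    b_0, b_1, t_l, log_e_0 = params
    # simulate() is the SEIR integrator defined in the Dash app code below;
    # column 3 of its output is the cumulative removed compartment.
    sim = simulate(b_0=b_0, b_1=b_1, t_l=t_l, n=47e6,
                   e_0=np.exp(log_e_0), t=len(actual_removed))
    return np.sum((sim[:, 3] - actual_removed) ** 2)

# Bounds for (b_0, b_1, t_l, log_e_0) -- made-up values.
lb = [0.0, 0.0, 0.0, -7.0]
ub = [5.0, 1.0, 60.0, np.log(47e6)]
xopt, fopt = pso(sse, lb, ub, swarmsize=100, maxiter=50)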

I wrote a quick-and-dirty Dash app to eyeball good starting values for the parameters (shown below; the values are most likely nonsensical at the moment). I’m also playing around with Imperial’s model.

Based on their assumptions, I made a few tweaks to mine (changed the beta parameter back to an exponential decay and used their IFR assumption). The infection fatality ratio turned out to be a key assumption: setting it to 1% (following the Imperial model, based on Verity et al.) makes the results of this SEIR model quite close to the Imperial model results for the country I’m currently looking at. The data for the app comes from a GitHub repo; the link is included in the code below.
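
For reference, the system being integrated is the standard SEIR model with a time-varying contact rate; this mirrors the dydt function in the app code below:

\begin{aligned}
\beta(t) &= \beta_0 \exp\!\big(-q \,\max(0,\, t - t_l)\big), \\
\frac{dS}{dt} &= -\beta(t)\,\frac{SI}{N}, \qquad
\frac{dE}{dt} = \beta(t)\,\frac{SI}{N} - \alpha E, \\
\frac{dI}{dt} &= \alpha E - \gamma I, \qquad
\frac{dR}{dt} = \gamma I,
\end{aligned}

where α and γ are the reciprocals of the app’s “infectious period” and “removal period” sliders, and q plays the role of the b_1 parameter.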

Quick & Dirty Dash App Code
import requests
import dash, json
import numpy as np
import pandas as pd
import datetime as dt
from functools import lru_cache
import dash_core_components as dcc
import dash_html_components as html
import dash_bootstrap_components as dbc

from pyswarm import pso
from plotly.graph_objs import *
from scipy.integrate import odeint
from plotly import graph_objs as go
from dash.dependencies import Input, Output

@lru_cache(maxsize = 3)
def get_data(country = 'Spain'):
    assert country in ['Spain', 'United Kingdom']

    data = requests.get('https://pomber.github.io/covid19/timeseries.json')
    data = pd.DataFrame(data.json()[country])
    data = data.loc[data.deaths >= 10, :]
    data.date = pd.to_datetime(data.date)
    data.reset_index(inplace = True)

    if country == 'Spain':
        lockdown_date = dt.datetime(2020, 3, 15)
    if country == 'United Kingdom':
        lockdown_date = dt.datetime(2020, 3, 24)

    return data, lockdown_date

def dydt(y, t, *args):
    S, E, I, R = y; N = np.sum(y)
    b_0, b_1, a, g, t_l = args  # removal rate renamed to g to avoid shadowing the state vector y
    dxdt = np.zeros(4)

    # beta decays exponentially once t passes the intervention time t_l
    b = b_0 * np.exp(-b_1 * np.max([0, t - t_l]))

    dxdt[0] = -b*S*I/N
    dxdt[1] = b*S*I/N - a*E
    dxdt[2] = a*E - g*I
    dxdt[3] = g*I
    return dxdt

def simulate(b_0 = 0.99, b_1 = 0.1, t_l = 50, n = 1e6, e_0 = 10, i_0 = 0, t = 100, infec_prd = 5, recov_prd = 7):
    theta = (b_0, b_1, 1/infec_prd, 1/recov_prd, t_l)
    initial_state = [n - e_0 - i_0, e_0, i_0, 0]
    simulation = odeint(dydt, initial_state, np.arange(t), args = theta)
    return simulation

app = dash.Dash(__name__, external_stylesheets = [dbc.themes.YETI])

app.layout = html.Div(children=[

    html.Div(children=[
        html.H1(children='SEIR Visualization'),

        html.Label('Select country:'),

        dbc.ListGroup([
            dbc.ListGroupItem("United Kingdom", id="country_uk", n_clicks=0),
            dbc.ListGroupItem("Spain", id="country_sp", n_clicks=1)]
        ),

        dcc.Markdown('---'),

        html.Label('β before intervention:'),

        dcc.Slider(
                id='b_0',
                min=0,
                max=5,
                marks={i: str(np.round(i, 1)) for i in np.linspace(0, 5, 10)},
                value=1.15,
                step=0.01
            ),

        html.Label('q: (β after intervention = β exp[-q * time since intervention])'),

        dcc.Slider(
                id='b_1',
                min=0,
                max=1,
                marks={i: str(np.round(i, 1)) for i in np.linspace(0, 1, 10)},
                value=0.125,
                step=0.01
            ),

        html.Label('Log of initial population exposed:'),

        dcc.Slider(
                id='e_0',
                min=-7,
                max=np.log(47e6),
                marks={i: str(int(np.round(i))) for i in np.linspace(-7, np.log(47e6), 10)},
                step=0.01
            ),

        html.Label('Infectious Period:'),

        dcc.Slider(
                id='infec_prd',
                min=1,
                max=14,
                marks={i: str(int(np.round(i))) for i in np.linspace(1, 14, 7)},
                value=5.1,
                step=0.1
            ),

        html.Label('Removal Period:'),

        dcc.Slider(
                id='recov_prd',
                min=1,
                max=14,
                marks={i: str(int(np.round(i))) for i in np.linspace(1, 14, 7)},
                value=7.1,
                step=0.1
            )
        ], style = {
            "position": "fixed",
            "top": 0,
            "left": 0,
            "bottom": 0,
            "width": "19rem",
            "padding": "2rem 1rem",
            "background-color": "#f8f9fa",
        }),

    html.Div(children = [

        dbc.Alert(id = "prop-rem", color="primary"),

        dcc.Graph(id='viz-graph')], style = {
            "margin-left": "18rem",
            "margin-right": "2rem",
            "padding": "2rem 1rem",
        })
])

@app.callback(
    [Output("country_uk", "active"),
     Output("country_sp", "active")],
    [Input("country_uk", "n_clicks"),
     Input("country_sp", "n_clicks")]
)
def update_country(ncl_uk, ncl_sp):
    print((ncl_uk, ncl_sp))
    if ncl_uk is None:
        return False, True
    if (ncl_sp is None) or (ncl_sp >= ncl_uk):
        return False, True
    return True, False

@app.callback(
    Output("e_0", "value"),
    [Input("country_uk", "active"),
     Input("country_sp", "active")]
)
def update_initial_exposed(country_uk, country_sp):
    if country_uk:
        e_0 = 9
    if country_sp:
        e_0 = 10
    return e_0

@app.callback(
    [Output("viz-graph", "figure"),
     Output("prop-rem", "children")],
    [Input("country_uk", "active"),
     Input("country_sp", "active"),
     Input("b_0", "value"),
     Input("b_1", "value"),
     Input("e_0", "value"),
     Input("infec_prd", "value"),
     Input("recov_prd", "value")]
)
def update_histogram(country_uk, country_sp, b_0, b_1, e_0, infec_prd, recov_prd):

    if country_uk:
        country = 'United Kingdom'
    if country_sp:
        country = 'Spain'
    data, lockdown_date = get_data(country)

    ifr = 0.01  # assumed infection fatality ratio; deaths / ifr approximates the removed compartment
    n = len(data.deaths)
    t_l = np.argwhere(np.array(data.date == lockdown_date))[0, 0]  # index of the lockdown date in the data
    simulated_seir = simulate(b_0 = b_0, b_1 = b_1, t_l = t_l, n = 47e6, t = n, e_0 = np.exp(e_0), i_0 = 0, infec_prd = infec_prd, recov_prd = recov_prd)
    timestamps = data.loc[1:, 'date']
    simulated_removals_diff = np.diff(simulated_seir[:, 3])  # daily increments of the simulated removed compartment
    actual_removals_diff = np.diff(data.deaths)              # daily deaths; scaled by 1/ifr when plotted below

    return {
            'data': [
                {'x': timestamps, 'y': simulated_removals_diff, 'type': 'line', 'name': 'Predicted'},
                {'x': timestamps, 'y': actual_removals_diff/ifr, 'type': 'line', 'name': country},
            ],
            'layout': {
                'title': 'Removed Population',
                'yaxis': {'range': [1, 7], 'type': 'log'}
            }
        }, str('Proportion removed: ' + str(int(1000 - simulated_seir[-1, 0]/47e3)/10) + '%.')  # 100 * (1 - S_final / 47e6), to one decimal place

if __name__ == "__main__":
    app.run_server(debug=True, port=8080, host='0.0.0.0')
Stan SEIR Implementation
// stan, use with integrate_ode_rk45

real[] ode(real time, real[] state, real[] theta,
        real[] x_r, int[] x_i) {

    real dxdt[4]; real b; real a; real y;
    real S; real E; real I; real R; real N;

    if(time <= x_r[1]) {
        b = theta[1];
    } else {
        b = theta[2];
    }

    a = theta[3]; y = theta[4]; N = sum(state);
    S = state[1]; E = state[2]; I = state[3]; R = state[4];

    dxdt[1] = -b*S*I/N;
    dxdt[2] = b*S*I/N - a*E;
    dxdt[3] = a*E - y*I;
    dxdt[4] = y*I;
    return dxdt;
}

A Discrete Time SEIR Model

This was the first model I tried: an implementation of the discrete-time epidemiological (SEIR) model based on:

P. E. Lekone and B. F. Finkenstädt, “Statistical Inference in a Stochastic Epidemic SEIR Model with Control Intervention”, 2006.

I’ve made some changes to it. For example, in the code below an intervention does not lead to an exponential decay of the exposure probability; the intervention considered here (a ‘lockdown’) simply switches to a lower exposure probability. If the population is large, the paths are very close to those of the model’s continuous-time counterpart (the binomial variance is relatively small), so perhaps the stochastic treatment of the paths (and the resulting large number of hidden states) isn’t necessary here.
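
As a quick check on that last point, here is a minimal forward simulation of the discrete-time model with numpy. All parameter values below are made up for illustration; the lockdown simply switches the exposure rate from b_i to a lower b_m:

import numpy as np

rng = np.random.default_rng(0)

# Illustrative values only.
N, E_0, T, T_l = 1_000_000, 10, 150, 50
b_i, b_m = 0.9, 0.2          # exposure rates before / after the lockdown
p, y = 1 / 5.1, 1 / 7.1      # E -> I and I -> R rates

S, E, I, R = N, E_0, 0, 0
paths = []
for t in range(T):
    b = b_i if t <= T_l else b_m          # lower rate once the lockdown starts
    Pr = 1 - np.exp(-b * I / N)           # probability a susceptible is exposed at time t
    B = rng.binomial(S, Pr)               # new exposures
    C = rng.binomial(E, 1 - np.exp(-p))   # newly infectious
    D = rng.binomial(I, 1 - np.exp(-y))   # newly removed
    S, E, I, R = S - B, E + B - C, I + C - D, R + D
    paths.append((S, E, I, R))
paths = np.array(paths)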

JAGS Code
model {
    # priors: pre-intervention exposure rate and a lower post-intervention rate
    b_i ~ dexp(1)
    b_m ~ dunif(0, b_i)

    S[1] <- N
    E[1] <- E_0
    I[1] <- 0
    R[1] <- 0

    for(t in 1:(T - 1)) {

        S[t + 1] <- S[t] - B[t]
        E[t + 1] <- E[t] + B[t] - C[t]
        I[t + 1] <- I[t] + C[t] - D[t]
        R[t + 1] <- R[t] + D[t]

        B[t] ~ dbin(Pr[t], S[t])          # new exposures
        C[t] ~ dbin(1 - exp(-p), E[t])    # newly infectious
        D[t] ~ dbin(1 - exp(-y), I[t])    # newly removed

        b[t] <- ifelse(t <= T_l, b_m, b_i)
        Pr[t] <- 1 - exp(-b[t] * I[t] / N)
    }
}
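
For completeness, a minimal sketch of driving the model above from Python with pyjags. The file names and data values are placeholders, and treating the daily removal counts D as observed is an assumption on my part:

import numpy as np
import pyjags

jags_code = open('seir.jags').read()  # hypothetical file containing the model block above

observed_D = np.loadtxt('daily_removals.csv').astype(int)  # hypothetical observed removals
T = len(observed_D) + 1

data = dict(N=47_000_000, E_0=10, T=T, T_l=30,  # illustrative values
            p=1 / 5.1, y=1 / 7.1,
            D=observed_D)                       # supplying D makes it observed rather than latent

model = pyjags.Model(code=jags_code, data=data, chains=2, adapt=1000)
samples = model.sample(5000, vars=['b_i', 'b_m'])  # dict of arrays, shape (dim, iterations, chains)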

I tried to use quite a few probabilistic programming languages:

  • Stan: Didn’t work because discrete (integer) parameters aren’t supported. Marginalizing them out would, I think, be very expensive given the number of possible paths. Treating the counts as continuous and making truncated normal approximations was a nightmare (linear combinations of the parameters themselves had to be positive, and I ran into precision issues).
  • PyMC3: I’m a beginner with PyMC3 and my implementation was too inefficient. In the docs, the PyMC devs suggest using Theano’s scan instead of Python for-loops, but I couldn’t figure out how parameter declarations work in the backend. The code is still shown below if you’re interested.
  • TensorFlow Probability: I’m not used to the API and couldn’t find a Gibbs sampler.
  • JAGS: Implementation was very simple and sampling works like a charm. Code above.
PyMC3 Naive Implementation
import numpy as np
import pymc3 as pm
from tqdm import trange  # progress bar for the (slow) loop below

# p, y, N, E_0, T and T_l are assumed to have been defined beforehand
with pm.Model() as m:

    p_c = 1 - np.exp(-p)
    p_r = 1 - np.exp(-y)

    b_0 = pm.Exponential('b_0', lam = 1)
    b_1 = pm.Uniform('b_1', lower = 0, upper = b_0)

    S = [N, ]; E = [E_0, ]; I = [0, ]; R = [0, ]
    B = []; C = []; D = []; Pr = []

    for t in trange(T):
        b = b_0 if t < T_l else b_1
        t_now = str(t); t_next = str(t + 1)

        # this is inefficient; pm.math.exp keeps the expression symbolic
        Pr.append(pm.Deterministic('Pr_' + t_now, 1 - pm.math.exp(-b * I[t] / N)))
        B.append(pm.Binomial('B_' + t_now, S[t], Pr[t]))
        C.append(pm.Binomial('C_' + t_now, E[t], p_c))
        D.append(pm.Binomial('D_' + t_now, I[t], p_r))

        S.append(pm.Deterministic('S_' + t_next, S[t] - B[t]))
        E.append(pm.Deterministic('E_' + t_next, E[t] + B[t] - C[t]))
        I.append(pm.Deterministic('I_' + t_next, I[t] + C[t] - D[t]))
        R.append(pm.Deterministic('R_' + t_next, R[t] + D[t]))
