functions{
  // Returns m! (= 1 * 2 * ... * m, with 0! = 1) as a real.
  // Computed iteratively. A negative m is rejected because the factorial is
  // undefined there (a recursive definition would never terminate).
  // In double precision the result overflows to +inf for m > 170.
  real factorial(int m) {
    real result;
    if (m < 0)
      reject("factorial: m must be non-negative; found m = ", m);
    result <- 1;
    for (k in 2:m)
      result <- result * k;
    return result;
  }
  
  // Returns the matrix power a^n of a square matrix a (a^0 is the identity
  // of size rows(a)). n is declared real for compatibility with existing
  // callers, but it must hold a non-negative integer value. A negative or
  // fractional n is rejected, because repeated "n - 1" never reaches 0 for such values.
  matrix matrix_pow(matrix a, real n) {
    matrix[rows(a), rows(a)] result;
    real k;
    if (n < 0 || floor(n) != n)
      reject("matrix_pow: n must be a non-negative integer; found n = ", n);
    result <- diag_matrix(rep_vector(1, rows(a)));
    k <- 0;
    while (k < n) {
      result <- a * result;
      k <- k + 1;
    }
    return result;
  }
  
  // Approximates the matrix exponential exp(mat * t) by its Taylor series
  // truncated after the 50th-order term:
  //   I + sum_{i=1}^{50} (mat * t)^i / i!
  // Each term is obtained from the previous one as term_i = term_{i-1} * mat * t / i.
  // This takes one matrix product per term instead of recomputing mat^i.
  // It also never forms t^i or i! on their own; those overflow
  // (50! is about 3e64, and t^50 is inf once |t| exceeds about 1e6), and an
  // inf times a zero entry would turn the result into NaN.
  // No scaling is applied, so accuracy degrades when mat * t has a large
  // norm; expmp() (scaling and squaring) is the more robust choice.
  matrix expmt(matrix mat, real t) {
    matrix[rows(mat), rows(mat)] out;
    matrix[rows(mat), rows(mat)] term;
    out <- diag_matrix(rep_vector(1.0, rows(mat)));
    term <- out;
    for (i in 1:50) {
      term <- term * mat * (t / i);
      out <- out + term;
    }
    return out;
  }
  
  // Matrix exponential exp(A) by scaling and squaring with Pade approximants.
  // This appears to be a port of the R function expm::expm.Higham08 without
  // its optional balancing step; the R fragments quoted below come from there.
  //
  // A:        square matrix to exponentiate.
  // padeC:    4 x 10 coefficient table. Row l (l = 1..4) is used when the
  //           1-norm of A is at most 0.015, 0.25, 0.95 and 2.1 respectively.
  //           Entries 1..(2l + 2) of that row are read.
  // padeCbig: 14 coefficients of the degree-13 approximant. They are used, after
  //           scaling A by 2^-s, when the 1-norm of A exceeds 2.1.
  //
  // Returns exp(A). Rejects if A is not square, or if padeC / padeCbig do not
  // have the sizes above. A with at most one row is handled by the
  // elementwise exp().
  matrix expmp(matrix A, matrix padeC, vector padeCbig) {
    int n;
    int l;
    real nA;
    real colsum;
    real s;
    real si;
    vector[4] theta;
    matrix[rows(A), rows(A)] I;
    matrix[rows(A), rows(A)] P;
    matrix[rows(A), rows(A)] U;
    matrix[rows(A), rows(A)] V;
    matrix[rows(A), rows(A)] X;
    matrix[rows(A), rows(A)] A2;
    matrix[rows(A), rows(A)] B;
    matrix[rows(A), rows(A)] B2;
    matrix[rows(A), rows(A)] B4;
    matrix[rows(A), rows(A)] B6;

    n <- rows(A);
    if (n != cols(A))
      reject("expmp: Matrix not square!");
    if (rows(padeC) != 4 || cols(padeC) != 10 || rows(padeCbig) != 14)
      reject("expmp: padeC must be 4 x 10 and padeCbig must have 14 elements");

    if (n <= 1) {
      X <- exp(A);
    } else {
      // 1-norm of A (maximum absolute column sum); R: Matrix::norm(A, "1")
      nA <- 0;
      for (j in 1:n) {
        colsum <- 0;
        for (i in 1:n)
          colsum <- colsum + fabs(A[i, j]);
        if (colsum > nA)
          nA <- colsum;
      }

      I <- diag_matrix(rep_vector(1, n));

      if (nA <= 2.1) {
        // Small norm: use the lowest-order approximant whose threshold
        // covers nA (R: l <- which.max(nA <= t)). l must be set to 0 before
        // the search, because its value is read inside the loop.
        theta[1] <- 0.015;
        theta[2] <- 0.25;
        theta[3] <- 0.95;
        theta[4] <- 2.1;
        l <- 0;
        for (k in 1:4) {
          if (l == 0 && nA <= theta[k])
            l <- k;
        }

        // U collects the odd-power terms and V the even-power terms of the approximant
        A2 <- A * A;
        P <- I;
        U <- padeC[l, 2] * I;
        V <- padeC[l, 1] * I;
        for (k in 1:l) {
          P <- P * A2;
          U <- U + padeC[l, 2 * k + 2] * P;
          V <- V + padeC[l, 2 * k + 1] * P;
        }
        U <- A * U;
        // Solve (V - U) X = V + U; R: solve(V - U, V + U).
        // A linear solve is cheaper and more accurate than forming the inverse.
        X <- (V - U) \ (V + U);
      } else {
        // Large norm: scale A down by 2^s so the degree-13 approximant applies
        s <- log2(nA / 5.4);
        B <- A;
        if (s > 0) {
          s <- ceil(s);
          B <- B / (2^s);
        }

        B2 <- B * B;
        B4 <- B2 * B2;
        B6 <- B2 * B4;
        U <- B * (B6 * (padeCbig[14] * B6 + padeCbig[12] * B4 + padeCbig[10] * B2)
                  + padeCbig[8] * B6 + padeCbig[6] * B4 + padeCbig[4] * B2
                  + padeCbig[2] * I);
        V <- B6 * (padeCbig[13] * B6 + padeCbig[11] * B4 + padeCbig[9] * B2)
             + padeCbig[7] * B6 + padeCbig[5] * B4 + padeCbig[3] * B2
             + padeCbig[1] * I;
        X <- (V - U) \ (V + U);

        // Undo the scaling: exp(A) = exp(B)^(2^s), i.e. square X s times
        if (s > 0) {
          si <- 0;
          while (si < s) {
            X <- X * X;
            si <- si + 1;
          }
        }
      }
    }
    return X;
  }
  
  // Kronecker product of mata and matb. For mata of size m x n and matb of
  // size p x q, the result is (m*p) x (n*q), and its (i, j) block of size
  // p x q equals mata[i, j] * matb.
  matrix kron_prod(matrix mata, matrix matb) {
    int nrow_b;
    int ncol_b;
    int row_off;
    int col_off;
    matrix[rows(mata) * rows(matb), cols(mata) * cols(matb)] K;
    nrow_b <- rows(matb);
    ncol_b <- cols(matb);
    for (i in 1:rows(mata)) {
      // first row of the i-th block row, minus one
      row_off <- (i - 1) * nrow_b;
      for (j in 1:cols(mata)) {
        // first column of the j-th block column, minus one
        col_off <- (j - 1) * ncol_b;
        for (r in 1:nrow_b) {
          for (c in 1:ncol_b) {
            K[row_off + r, col_off + c] <- mata[i, j] * matb[r, c];
          }
        }
      }
    }
    return K;
  }
}
data {
  // Pade coefficient tables, passed unchanged to expmp() in the model block.
  // expmp reads padeC[l, 1..(2l+2)] for l in 1..4 and all 14 entries of
  // padeCbig.
  matrix[4,10] padeC;
  vector[14] padeCbig;
  // Number of states of each node. AMI_n * E_n sizes the joint AE chain and
  // AN_n * SS_n the joint ASS chain. P_n, LC_n, LV_n and CO_n size the CPTs.
  int<lower=1> AMI_n;
  int<lower=1> E_n;
  int<lower=1> AN_n;
  int<lower=1> SS_n;
  int<lower=1> P_n;
  int<lower=1> LC_n;
  int<lower=1> LV_n;
  int<lower=1> CO_n;
  // 0/1 pattern of the joint AE intensity matrix. An off-diagonal entry
  // (j, k) equal to 1 receives a share (AE_theta) of the exit rate of state j.
  // An entry equal to 0 is fixed at 0. The model also indexes this matrix when
  // building the ASS intensities. NOTE(review): that is only valid if
  // AN_n * SS_n <= AMI_n * E_n and both chains share the pattern — confirm.
  matrix[AMI_n*E_n, AMI_n*E_n]scratch_joint_inten;
  // D_N: length of the discrete trajectories (P, LC, LV, CO).
  // AE_N / ASS_N: number of observations of the AE / ASS chains.
  int<lower=1> D_N;
  int<lower=1> AE_N;
  int<lower=1> ASS_N;
  // Observed joint states of the AE / ASS chains. NOTE(review): the upper
  // bound 4 presumably equals AMI_n * E_n (resp. AN_n * SS_n) — confirm.
  int<lower=1,upper = 4> AE_states[AE_N];
  int<lower=1,upper = 4> ASS_states[ASS_N];
  // Discrete trajectories (*_states). For step t = 1..D_N-1, X_Y_states[t] is
  // the parent (Y) state used to pick the CPT of X for the step t -> t+1.
  int<lower=1,upper = 2> P_states[D_N];
  int<lower=1,upper = 2> P_AMI_states[D_N-1];
  
  int<lower=1,upper = 2> LC_states[D_N];
  int<lower=1,upper = 2> LC_P_states[D_N-1];
  int<lower=1,upper = 2> LV_states[D_N];
  int<lower=1,upper = 2> LV_LC_states[D_N-1];
  int<lower=1,upper = 2> CO_states[D_N];
  int<lower=1,upper = 2> CO_LV_states[D_N-1];
  
  // Observation times of the AE / ASS observations. Their differences are
  // used as time spans in the matrix exponentials. NOTE(review): presumably
  // sorted ascending — confirm.
  real<lower=0> AE_traj_mis[AE_N];
  real<lower=0> ASS_traj_mis[ASS_N];
}
parameters{
  // Exit rate of each joint AE state, for each CO state. The model uses
  // -AE_dig[i, j] as the diagonal of the i-th AE intensity matrix.
  real<lower = 0, upper = 20>AE_dig[CO_n, AMI_n * E_n];
  // For each CO state and joint AE state: how that exit rate is split over
  // the allowed destinations (entries of scratch_joint_inten equal to 1).
  simplex[AMI_n - 1 + E_n - 1] AE_theta[CO_n, AMI_n * E_n];
  // The same two quantities for the joint ASS chain.
  real<lower = 0, upper = 20>ASS_dig[CO_n, AN_n * SS_n];
  simplex[AN_n - 1 + SS_n - 1] ASS_theta[CO_n, AN_n * SS_n];
  // X_tran[parent state, previous X state] is the distribution of the next
  // X state (see the categorical_log calls in the model block).
  simplex[P_n]P_tran[AMI_n,P_n];                                  // CPTs for P
  simplex[LC_n]LC_tran[P_n,LC_n];                                 // CPTs for LC
  simplex[LV_n]LV_tran[LC_n,LV_n];                                // CPTs for LV
  simplex[CO_n]CO_tran[LV_n,CO_n];                                // CPTs for CO
}
model{
  // Intensity matrix of each joint chain, one per CO state
  real AE_inten[CO_n, AMI_n * E_n, AMI_n * E_n];
  real ASS_inten[CO_n, AN_n * SS_n, AN_n * SS_n];
  int ind;
  int s_ind;
  real t_from;
  // Intensity matrices currently in force for AE / ASS
  matrix[AMI_n * E_n, AMI_n * E_n] A;
  matrix[AN_n * SS_n, AN_n * SS_n] D;
  // Transition probabilities from the previously observed state
  row_vector[AMI_n * E_n] p_AE;
  row_vector[AN_n * SS_n] p_ASS;
  vector[AMI_n * E_n] dis;
  vector[AN_n * SS_n] ASS_dis;
  vector[AMI_n - 1 + E_n - 1] alpha;
  vector[AN_n - 1 + SS_n - 1] ASS_alpha;
  vector[P_n] P_alpha;
  vector[LC_n] LC_alpha;
  vector[LV_n] LV_alpha;
  vector[CO_n] CO_alpha;
  real gammaAE[AE_N - 1];
  real gammaASS[ASS_N - 1];
  real gammaP[D_N - 1];
  real gammaLC[D_N - 1];
  real gammaLV[D_N - 1];
  real gammaCO[D_N - 1];

  // Flat Dirichlet hyperparameters
  alpha <- rep_vector(1, AMI_n - 1 + E_n - 1);
  ASS_alpha <- rep_vector(1, AN_n - 1 + SS_n - 1);
  P_alpha <- rep_vector(1, P_n);
  LC_alpha <- rep_vector(1, LC_n);
  LV_alpha <- rep_vector(1, LV_n);
  CO_alpha <- rep_vector(1, CO_n);

  // Priors for the AE and ASS rate parameters
  for (i in 1:CO_n) {
    for (j in 1:(AMI_n * E_n)) {
      AE_theta[i, j] ~ dirichlet(alpha);
      AE_dig[i, j] ~ gamma(2, 2);
    }
    for (j in 1:(AN_n * SS_n)) {
      ASS_theta[i, j] ~ dirichlet(ASS_alpha);
      ASS_dig[i, j] ~ gamma(2, 2);
    }
  }

  // Build one AE intensity matrix per CO state. Off-diagonal entries flagged 1
  // in scratch_joint_inten take successive shares theta * dig of the exit
  // rate, entries flagged 0 are zero, and the diagonal is -dig.
  for (i in 1:CO_n) {
    for (j in 1:(AMI_n * E_n)) {
      ind <- 1;
      for (k in 1:(AMI_n * E_n)) {
        if (scratch_joint_inten[j, k] == 0)
          AE_inten[i, j, k] <- 0;
        else if (scratch_joint_inten[j, k] == 1) {
          AE_inten[i, j, k] <- AE_theta[i, j, ind] * AE_dig[i, j];
          ind <- ind + 1;
        }
      }
      AE_inten[i, j, j] <- -AE_dig[i, j];
    }
  }

  // Same construction for ASS.
  // NOTE(review): this reuses the AE pattern scratch_joint_inten (size
  // AMI_n * E_n). That is only valid if AN_n * SS_n <= AMI_n * E_n and both
  // chains share the same pattern — confirm, or pass a separate pattern.
  for (i in 1:CO_n) {
    for (j in 1:(AN_n * SS_n)) {
      ind <- 1;
      for (k in 1:(AN_n * SS_n)) {
        if (scratch_joint_inten[j, k] == 0)
          ASS_inten[i, j, k] <- 0;
        else if (scratch_joint_inten[j, k] == 1) {
          ASS_inten[i, j, k] <- ASS_theta[i, j, ind] * ASS_dig[i, j];
          ind <- ind + 1;
        }
      }
      ASS_inten[i, j, j] <- -ASS_dig[i, j];
    }
  }

  // Priors for the CPTs
  for (i in 1:AMI_n)
    for (j in 1:P_n)
      P_tran[i, j] ~ dirichlet(P_alpha);
  for (i in 1:P_n)
    for (j in 1:LC_n)
      LC_tran[i, j] ~ dirichlet(LC_alpha);
  for (i in 1:LC_n)
    for (j in 1:LV_n)
      LV_tran[i, j] ~ dirichlet(LV_alpha);
  for (i in 1:LV_n)
    for (j in 1:CO_n)
      CO_tran[i, j] ~ dirichlet(CO_alpha);

  // Likelihood of the continuous chains. The intensity is piecewise
  // constant: CO_states[k] selects the matrix used on the slice (k - 1, k]
  // (slice 1 also covers t <= 1). s_ind is the slice of the current position.
  // Between two observations, the state distribution is carried across every
  // slice boundary in between, each slice using its own matrix.

  // ---- AE ----
  s_ind <- 1;
  while (s_ind < AE_traj_mis[1])
    s_ind <- s_ind + 1;
  A <- to_matrix(AE_inten[CO_states[s_ind],,]);
  for (i in 2:AE_N) {
    // start from the state observed at the previous time
    p_AE <- rep_row_vector(0, AMI_n * E_n);
    p_AE[AE_states[i-1]] <- 1;
    t_from <- AE_traj_mis[i-1];
    // cross each slice boundary lying strictly before the current time
    while (s_ind < AE_traj_mis[i]) {
      p_AE <- p_AE * expmp(A * (s_ind - t_from), padeC, padeCbig);
      t_from <- s_ind;
      s_ind <- s_ind + 1;
      A <- to_matrix(AE_inten[CO_states[s_ind],,]);
    }
    p_AE <- p_AE * expmp(A * (AE_traj_mis[i] - t_from), padeC, padeCbig);
    dis <- p_AE';
    // make dis an exact simplex despite round-off
    dis[AMI_n * E_n] <- 1 - sum(head(dis, AMI_n * E_n - 1));
    gammaAE[i-1] <- categorical_log(AE_states[i], dis);
  }

  // ---- ASS ----
  s_ind <- 1;
  while (s_ind < ASS_traj_mis[1])
    s_ind <- s_ind + 1;
  D <- to_matrix(ASS_inten[CO_states[s_ind],,]);
  for (i in 2:ASS_N) {
    p_ASS <- rep_row_vector(0, AN_n * SS_n);
    p_ASS[ASS_states[i-1]] <- 1;
    t_from <- ASS_traj_mis[i-1];
    while (s_ind < ASS_traj_mis[i]) {
      p_ASS <- p_ASS * expmp(D * (s_ind - t_from), padeC, padeCbig);
      t_from <- s_ind;
      s_ind <- s_ind + 1;
      D <- to_matrix(ASS_inten[CO_states[s_ind],,]);
    }
    p_ASS <- p_ASS * expmp(D * (ASS_traj_mis[i] - t_from), padeC, padeCbig);
    ASS_dis <- p_ASS';
    // renormalise the last entry of the ASS distribution (size AN_n * SS_n)
    ASS_dis[AN_n * SS_n] <- 1 - sum(head(ASS_dis, AN_n * SS_n - 1));
    gammaASS[i-1] <- categorical_log(ASS_states[i], ASS_dis);
  }

  // Likelihood of the discrete chains: X_tran[parent state, previous state]
  // is the distribution of the next state
  for (i in 2:D_N) {
    gammaP[i-1] <- categorical_log(P_states[i], P_tran[P_AMI_states[i-1], P_states[i-1]]);
    gammaLC[i-1] <- categorical_log(LC_states[i], LC_tran[LC_P_states[i-1], LC_states[i-1]]);
    gammaLV[i-1] <- categorical_log(LV_states[i], LV_tran[LV_LC_states[i-1], LV_states[i-1]]);
    gammaCO[i-1] <- categorical_log(CO_states[i], CO_tran[CO_LV_states[i-1], CO_states[i-1]]);
  }
  increment_log_prob(gammaAE);
  increment_log_prob(gammaASS);
  increment_log_prob(gammaP);
  increment_log_prob(gammaLC);
  increment_log_prob(gammaLV);
  increment_log_prob(gammaCO);
}




