# Network: eight variables for HF
rm(list = ls())
library(rstan)
library(gtools)
library(LaplacesDemon)
library(erer)

# amalgamation A -> B
# Build a random joint intensity matrix for two continuous-time variables with
# structure A -> B: B has one conditional intensity matrix per state of A.
# Off-diagonal rates of both components are drawn from Gamma(2, 2).
# Joint states are ordered with A varying fastest:
# <a1, b1> , <a2, b1> , <a1, b2> , <a2, b2> , <a1, b3> , <a2, b3>.
#
# A.n : number of states of A
# B.n : number of states of B
# Returns an (A.n*B.n) x (A.n*B.n) intensity matrix whose rows sum to zero;
# only transitions that change exactly one of A or B are non-zero.
amalma <- function(A.n, B.n)
{
  # intensity matrix for A
  A = matrix(0, nrow = A.n, ncol = A.n)
  for(i in seq_len(A.n))
  {
    for(j in seq_len(A.n))
    {
      if(i != j)
        A[i, j] = rgamma(1, shape = 2, rate = 2)
    }
    A[i, i] = -sum(A[i, ])
  }
  # intensity matrices for B, one slice B[, , k] per state k of A
  B = array(0, dim = c(B.n, B.n, A.n))
  for(k in seq_len(A.n))
  {
    for(i in seq_len(B.n))
    {
      for(j in seq_len(B.n))
      {
        if(i != j)
          B[i, j, k] = rgamma(1, shape = 2, rate = 2)
      }
      B[i, i, k] = -sum(B[i, , k])
    }
  }

  # Decode joint index -> (state of A, state of B) arithmetically. A
  # paste()/strsplit() round trip cannot tell "a10, b1" ("101") apart once
  # either component has ten or more states.
  size = A.n * B.n
  a.of = (seq_len(size) - 1) %% A.n + 1
  b.of = (seq_len(size) - 1) %/% A.n + 1
  AB = matrix(0, nrow = size, ncol = size)
  for(i in seq_len(size))
  {
    for(j in seq_len(size))
    {
      if(i == j) next
      if(a.of[i] == a.of[j] && b.of[i] != b.of[j])
        AB[i, j] = B[b.of[i], b.of[j], a.of[i]]   # transition for B
      else if(a.of[i] != a.of[j] && b.of[i] == b.of[j])
        AB[i, j] = A[a.of[i], a.of[j]]             # transition for A
    }
    # diagonal is still 0 here, so this makes the row sum to zero
    AB[i, i] = -sum(AB[i, ])
  }
  AB
}
joint.inten <- function(A.n, B.n)
{
  # Draw a random joint "intensity" for two variables with structure A -> B,
  # in the parameterisation used throughout this script:
  #   diagonal     = -q, with q ~ Gamma(2, 2) the exit rate of the joint state;
  #   off-diagonal = jump probabilities ~ Dirichlet(1, ..., 1) over the states
  #                  that differ from the current one in exactly one of A, B
  #                  (all other off-diagonal entries are 0).
  # Rows therefore do NOT sum to zero; each row's off-diagonals sum to one.
  # Joint states are ordered with A varying fastest:
  # <a1, b1> , <a2, b1> , <a1, b2> , <a2, b2> , <a1, b3> , <a2, b3>.
  size = A.n * B.n
  # arithmetic decoding also works when A.n or B.n has two or more digits
  a.of = (seq_len(size) - 1) %% A.n + 1
  b.of = (seq_len(size) - 1) %/% A.n + 1
  AB = matrix(0, nrow = size, ncol = size)
  for(i in seq_len(size))
  {
    q = rgamma(1, shape = 2, rate = 2) # hyperparameters 2, 2
    # reachable targets: exactly one of the two components changes
    pro.ind = which(xor(a.of == a.of[i], b.of == b.of[i]))
    # one Dirichlet draw per row over the reachable targets
    if(length(pro.ind) > 0)
      AB[i, pro.ind] = rdirichlet(1, rep(1, length(pro.ind)))  # hyperparameters : 1
    AB[i, i] = -q
  }
  AB
}
# hyperparameter: 1,1
# Random CPT for a discrete child given a discrete parent: every slice
# cpt[from, , parent] is one draw from a flat Dirichlet over the child states.
# Returns a child.n x child.n x par.n array.
getTran <- function(par.n, child.n)
{
  flatRow <- function() rdirichlet(1, alpha = rep(1, child.n))
  cpt = array(0, dim = c(child.n, child.n, par.n))
  for(parent in 1:par.n)
  {
    for(from in 1:child.n)
      cpt[from, , parent] = flatRow()
  }
  cpt
}
# compute sufficient statistics
# Sufficient statistics and posterior-mean intensities for a continuous-time
# child whose transitions are indexed by a discrete parent.
# child      : sequence of child (joint) states, e.g. AE$AE
# child.par  : parent state for each step child[i] -> child[i + 1], e.g. AE$AE.CO
# child.n    : number of child states, e.g. AE$AMI.n * AE$E.n
# par.n      : number of parent states, e.g. AE$CO.n
# child.traj : time stamp of every entry of child, e.g. AE$AE.traj
# Returns list(est.inten, M.num, M.den), each a list with one
# child.n x child.n matrix per parent state:
#   est.inten : diagonal = -(2 + #exits) / (2 + time in state) (Gamma(2, 2) prior);
#               off-diagonal = transition counts plus pseudocounts, row-normalised
#   M.num     : transition counts with diagonal zeroed. NOTE(review): these
#               already include the +1 pseudocounts added below.
#   M.den     : elapsed time of each step, accumulated at [from, to]
computeStatC <- function(child, child.par, child.n, par.n, child.traj)
{
  zero = matrix(0, nrow = child.n, ncol = child.n)
  M.num = rep(list(zero), par.n)
  M.den = rep(list(zero), par.n)
  est.inten = rep(list(zero), par.n)
  # A trajectory with a single point has no steps; 1:(length(child) - 1)
  # would iterate over i = 1, 0 and fail on child[0].
  for(i in seq_len(max(length(child) - 1, 0)))
  {
    p = child.par[i]
    from = child[i]
    to = child[i + 1]
    M.num[[p]][from, to] = M.num[[p]][from, to] + 1
    M.den[[p]][from, to] = M.den[[p]][from, to] + child.traj[i + 1] - child.traj[i]
  }

  for(i in seq_len(par.n))
  {
    # self steps (e.g. the slice-boundary points generateDat inserts) add
    # dwell time but are not jumps
    diag(M.num[[i]]) = 0
    for(j in seq_len(child.n))
    {
      # hyperparameter alpha_c = 2, tau = 2
      est.inten[[i]][j, j] = -(2 + sum(M.num[[i]][j, ])) / (2 + sum(M.den[[i]][j, ]))
      for(k in seq_len(child.n))
      {
        # NOTE(review): `j + k != 5` hard-codes a 2 x 2 joint space
        # (child.n = 4): it skips the anti-diagonal pairs that change both
        # components. It is not correct for other sizes.
        if((j + k) != 5 && k != j)
          M.num[[i]][j, k] = M.num[[i]][j, k] + 1  # hyperparameter alpha_cc = 1
      }
      for(k in seq_len(child.n))
      {
        if(k != j)
          est.inten[[i]][j, k] = M.num[[i]][j, k] / sum(M.num[[i]][j, ])
      }
    }
  }
  return(list(est.inten = est.inten, M.num = M.num, M.den = M.den))
}
# AE.stat = computeStatC(AE, AE.CO, AMI.n*E.n, CO.n, AE.traj)
# child = AE
# child.par = AE.CO
# child.n = AMI.n*E.n
# par.n = CO.n
# child.traj = AE.traj
# compute likelihood
# Log-likelihood of a continuous-time trajectory summarised by (M.num, M.den)
# under est.inten (one matrix per parent state):
#   sum_x [ M_x log q_x - q_x T_x ] + sum_{x != x'} M_{xx'} log theta_{xx'}
# with q_x = |est.inten[x, x]|, theta_{xx'} = est.inten[x, x'],
# M_x the row total of M.num and T_x the row total of M.den.
computeLL <- function(est.inten, M.num,  M.den)
{
  total = 0
  for(p in 1:length(est.inten))
  {
    counts = M.num[[p]]
    exits = apply(counts, 1, sum)
    dwell = apply(M.den[[p]], 1, sum)
    rates = abs(diag(est.inten[[p]]))
    total = total + sum(exits * log(rates))
    total = total - sum(rates * dwell)
    # jump probabilities; zero cells (including the diagonal) become 1 so
    # their log is 0 rather than -Inf (which would give NaN when multiplied by 0)
    theta = est.inten[[p]]
    diag(theta) = 0
    theta[which(theta == 0)] = 1
    total = total + sum(counts * log(theta))
  }
  return(total)
}
# Sufficient statistics for a discrete-time child with a discrete parent,
# smoothed by a flat Dirichlet(1, ..., 1) prior on every CPT row.
# child     : sequence of child states, one per time slice
# child.par : parent state for each step child[i] -> child[i + 1]
# child.n   : number of child states
# par.n     : number of parent states
# Returns list(M.Num, M.Den), both child.n x child.n x par.n arrays, where
#   M.Num[j, k, p] = 1 + #(j -> k under p)
#   M.Den[j, k, p] = child.n + #(j -> anything under p)   (same for every k)
# so that M.Num / M.Den is the posterior-mean CPT (each [j, , p] sums to one).
computeStatD <- function(child, child.par, child.n, par.n)
{
  M.Num = array(0, dim = c(child.n, child.n, par.n))
  M.Den = array(0, dim = c(child.n, child.n, par.n))
  for(i in seq_len(max(length(child) - 1, 0)))
  {
    M.Num[child[i], child[i + 1], child.par[i]] = M.Num[child[i], child[i + 1], child.par[i]] + 1
    # The row total is replicated into every column so that the CPT can be
    # taken elementwise. Filling only columns 1 and 2 would be correct only
    # for binary children.
    M.Den[child[i], , child.par[i]] = M.Den[child[i], , child.par[i]] + 1
  }
  # prior: 1 per cell, child.n per row (equal to the former "+2" when child.n == 2)
  return(list(M.Num = 1 + M.Num, M.Den = M.Den + child.n))
}
# Simulate one continuous-time component forward from sim$curr.time to the
# slice boundary slice.end while its discrete parent stays at par.curr.
# sim       : list(state, curr.time, states, traj, par.states) -- current
#             state, clock, and the records accumulated so far
# inten     : list of per-parent matrices in joint.inten's parameterisation
#             (diagonal = -exit rate, off-diagonal = jump probabilities)
# Returns sim updated with every jump inside the slice, plus a boundary point
# at slice.end that carries the state in force there.
simulateSlice <- function(sim, inten, par.curr, slice.end)
{
  state = sim$state
  curr.time = sim$curr.time
  states = sim$states
  traj = sim$traj
  par.states = sim$par.states
  q = -inten[[par.curr]][state, state]
  while(curr.time <= slice.end)
  {
    # probability of at least one jump before the boundary
    p.jump = 1 - exp(-q * (slice.end - curr.time))
    if(rcat(1, c(p.jump, 1 - p.jump)) != 1)
      break
    # jump time conditioned to land before the boundary (rejection sampling)
    time = rexp(1, rate = q)
    while((time + curr.time) >= slice.end)
      time = rexp(1, rate = q)
    curr.time = curr.time + time
    traj = c(traj, curr.time)
    # next state drawn from the off-diagonal jump probabilities
    tran.dis = inten[[par.curr]][state, ]
    tran.dis[state] = 0
    state = rcat(1, tran.dis)
    states = c(states, state)
    par.states = c(par.states, par.curr)
    # The holding time in the new state is governed by the NEW state's exit
    # rate; keeping the previous rate until the next slice mis-simulates the chain.
    q = -inten[[par.curr]][state, state]
  }
  if(curr.time >= slice.end) {
    # defensive (kept from the original): drop a point at/after the boundary
    traj = head(traj, -1)
    states = head(states, -1)
    par.states = head(par.states, -1)
  } else {
    curr.time = slice.end
    traj = c(traj, curr.time)
    states = c(states, state)
    par.states = c(par.states, par.curr)
  }
  list(state = state, curr.time = curr.time, states = states,
       traj = traj, par.states = par.states)
}

generateDat <- function(N, AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)
{
  # Simulate N slices of the network. AE and ASS evolve in continuous time
  # with intensities indexed by the current CO state; P, LC, LV and CO are
  # updated once per slice along AMI (component of AE) -> P -> LC -> LV -> CO.
  # NOTE(review): LV.ini and the sizes AMI.n, E.n, AN.n, SS.n, P.n, LC.n,
  # LV.n, CO.n are not parameters; they are read from the global environment.

  # generate data from initial model (draw order: AE, ASS, P, LC, LV, CO)
  AE.curr = rcat(1, AE.ini)
  ASS.curr = rcat(1, ASS.ini)
  P.curr = rcat(1, P.ini)
  LC.curr = rcat(1, LC.ini)
  LV.curr = rcat(1, LV.ini)
  CO.curr = rcat(1, CO.ini)

  # discrete components: one state per slice, plus each slice's parent state
  P = P.curr
  LC = LC.curr
  LV = LV.curr
  CO = CO.curr
  P.AMI = c()
  LC.P = c()
  LV.LC = c()
  CO.LV = c()
  # continuous components: states, time stamps and CO state at each step
  AE.sim = list(state = AE.curr, curr.time = 0, states = AE.curr,
                traj = 0, par.states = NULL)
  ASS.sim = list(state = ASS.curr, curr.time = 0, states = ASS.curr,
                 traj = 0, par.states = NULL)
  # joint AE state -> AMI component (AMI varies fastest in the joint ordering)
  AMI.of = rep(1:AMI.n, E.n)

  for(i in seq_len(N))
  {
    # continuous components up to slice boundary i, with CO held fixed
    AE.sim = simulateSlice(AE.sim, AE.inten, CO.curr, i)
    ASS.sim = simulateSlice(ASS.sim, ASS.inten, CO.curr, i)

    # discrete components for slice i
    AMI.curr = AMI.of[AE.sim$state]
    P.curr = rcat(1, P.tran[P.curr, , AMI.curr])
    P.AMI = c(P.AMI, AMI.curr)
    P = c(P, P.curr)
    LC.curr = rcat(1, LC.tran[LC.curr, , P.curr])
    LC.P = c(LC.P, P.curr)
    LC = c(LC, LC.curr)
    LV.curr = rcat(1, LV.tran[LV.curr, , LC.curr])
    LV.LC = c(LV.LC, LC.curr)
    LV = c(LV, LV.curr)
    CO.curr = rcat(1, CO.tran[CO.curr, , LV.curr])
    CO.LV = c(CO.LV, LV.curr)
    CO = c(CO, CO.curr)
  }
  P = list(P = P, P.AMI = P.AMI, P.n = P.n, AMI.n = AMI.n, P.tran = P.tran)
  LC = list(LC = LC, LC.P = LC.P, LC.n = LC.n, P.n = P.n, LC.tran = LC.tran)
  LV = list(LV = LV, LV.LC = LV.LC, LV.n = LV.n, LC.n = LC.n, LV.tran = LV.tran)
  CO = list(CO = CO, CO.LV = CO.LV, CO.n =  CO.n, LV.n = LV.n, CO.tran = CO.tran)
  AE = list(AE = AE.sim$states, AE.CO = AE.sim$par.states, AE.traj = AE.sim$traj,
            AMI.n = AMI.n, E.n = E.n, CO.n = CO.n, AE.inten = AE.inten)
  ASS = list(ASS = ASS.sim$states, ASS.CO = ASS.sim$par.states, ASS.traj = ASS.sim$traj,
             AN.n = AN.n, SS.n = SS.n, CO.n = CO.n, ASS.inten = ASS.inten)

  return(list(P = P, LC = LC, LV = LV, CO = CO, AE = AE, ASS = ASS,
              P.tran = P.tran, LC.tran = LC.tran, LV.tran = LV.tran, CO.tran = CO.tran,
              AE.inten = AE.inten, ASS.inten = ASS.inten))
}
computeDis <- function(N, testDat, AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)
{
  # Complete-data experiment: simulate N training slices, estimate posterior
  # means with computeStatD / computeStatC, then score the fixed test set
  # testDat under the estimated and the true parameters.
  # Returns list(true.ll, est.ll).
  # NOTE(review): the test-set statistics include the prior pseudocounts
  # (M.Num / M.num from computeStatD / computeStatC), so both scores carry them.

  # sufficient statistics of all six nodes for one simulated data set
  allStats <- function(dat)
  {
    P = dat$P
    LC = dat$LC
    LV = dat$LV
    CO = dat$CO
    AE = dat$AE
    ASS = dat$ASS
    list(P = computeStatD(P$P, P$P.AMI, P$P.n, P$AMI.n),
         LC = computeStatD(LC$LC, LC$LC.P, LC$LC.n, LC$P.n),
         LV = computeStatD(LV$LV, LV$LV.LC, LV$LV.n, LV$LC.n),
         CO = computeStatD(CO$CO, CO$CO.LV, CO$CO.n, CO$LV.n),
         AE = computeStatC(AE$AE, AE$AE.CO, AE$AMI.n * AE$E.n, AE$CO.n, AE$AE.traj),
         ASS = computeStatC(ASS$ASS, ASS$ASS.CO, ASS$AN.n * ASS$SS.n, ASS$CO.n, ASS$ASS.traj))
  }
  # log-likelihood contribution of one discrete node
  discreteLL <- function(stat, tran) sum(stat$M.Num * log(tran))

  # data for estimating parameters
  learnDat = generateDat(N, AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)
  learn = allStats(learnDat)
  est.P.tran = learn$P$M.Num / learn$P$M.Den
  est.LC.tran = learn$LC$M.Num / learn$LC$M.Den
  est.LV.tran = learn$LV$M.Num / learn$LV$M.Den
  est.CO.tran = learn$CO$M.Num / learn$CO$M.Den

  # statistics of the held-out test data
  test = allStats(testDat)
  est.AE.LL = computeLL(learn$AE$est.inten, test$AE$M.num, test$AE$M.den)
  est.ASS.LL = computeLL(learn$ASS$est.inten, test$ASS$M.num, test$ASS$M.den)
  AE.LL = computeLL(AE.inten, test$AE$M.num, test$AE$M.den)
  ASS.LL = computeLL(ASS.inten, test$ASS$M.num, test$ASS$M.den)

  est.ll = discreteLL(test$P, est.P.tran) +
    discreteLL(test$LC, est.LC.tran) +
    discreteLL(test$LV, est.LV.tran) +
    discreteLL(test$CO, est.CO.tran) + est.AE.LL + est.ASS.LL
  ll = discreteLL(test$P, P.tran) +
    discreteLL(test$LC, LC.tran) +
    discreteLL(test$LV, LV.tran) +
    discreteLL(test$CO, CO.tran) + AE.LL + ASS.LL
  return(list(true.ll = ll, est.ll = est.ll))
}

# true model
# State-space sizes of the eight variables: AMI/E and AN/SS are the two
# components of the continuous-time joint nodes AE and ASS; P, LC, LV and CO
# are discrete-time nodes. NOTE(review): generateDat reads these sizes (and
# LV.ini below) from the global environment rather than from its arguments.
AMI.n = 2
E.n = 2
AN.n = 2
SS.n = 2
P.n = 2
LC.n = 2
LV.n = 2
CO.n = 2
# intensity matrix for AMI and E, AN and SS: one joint matrix per state of CO,
# in joint.inten's parameterisation (diagonal = -exit rate,
# off-diagonal = jump probabilities)
AE.inten = list()
ASS.inten = list()
for(i in 1:CO.n)
{
  AE.inten[[i]] = joint.inten(AMI.n, E.n)
  ASS.inten[[i]] = joint.inten(AN.n, SS.n)
}
# CPT for P, LC, LV, CO (rows drawn from a flat Dirichlet by getTran)
P.tran = getTran(AMI.n, P.n)
LC.tran = getTran(P.n, LC.n)
LV.tran = getTran(LC.n, LV.n)
CO.tran = getTran(LV.n, CO.n)
# initial distribution of every node, each a flat-Dirichlet draw
AE.ini = as.vector(rdirichlet(1, alpha = rep(1, AMI.n * E.n)))
ASS.ini = as.vector(rdirichlet(1, alpha = rep(1, AN.n * SS.n)))
P.ini = as.vector(rdirichlet(1, alpha = rep(1, P.n)))
LC.ini = as.vector(rdirichlet(1, alpha = rep(1, LC.n)))
LV.ini = as.vector(rdirichlet(1, alpha = rep(1, LV.n)))
CO.ini = as.vector(rdirichlet(1, alpha = rep(1, CO.n)))

# fixed held-out test set (10000 slices) passed to every computeDis call below
testDat = generateDat(10000, AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)


# complete case
# Learn from fully observed data of each length in N.slice and compare the
# test-set log-likelihood under estimated vs. true parameters.
N.slice = c(10, 20, 40, 80, 160, 320, 640, 1024, 2048, 4096)
N.slice = c(2000)  # NOTE: overrides the sweep above; remove to run every length
true.ll = c()
est.ll = c()
for(i in seq_along(N.slice))
{
  ll = computeDis(N.slice[i], testDat, AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)
  true.ll = c(true.ll, ll$true.ll)
  est.ll = c(est.ll, ll$est.ll)
}
abs(true.ll - est.ll)
true.ll
dataVar8 = data.frame(N.slice = log10(N.slice), true.ll = true.ll, est.ll = est.ll, diff = abs(true.ll - est.ll))
write.csv(dataVar8, "~/Papers/LearnHTBN/Var8(complete)LL.diff.csv", row.names = FALSE, quote = FALSE)

# relative log-likelihood gap per sequence length
abs(true.ll - est.ll) / abs(true.ll)


# incomplete case
getScratch <- function(A.n, B.n)
{
  # Structural template of a joint intensity for two variables A -> B:
  # -1 on the diagonal, 1 where exactly one of A or B changes, 0 elsewhere.
  # Joint states are ordered with A varying fastest:
  # <a1, b1> , <a2, b1> , <a1, b2> , <a2, b2> , <a1, b3> , <a2, b3>.
  size = A.n * B.n
  # arithmetic decoding also works when A.n or B.n has two or more digits
  a.of = (seq_len(size) - 1) %% A.n + 1
  b.of = (seq_len(size) - 1) %/% A.n + 1
  same.a = outer(a.of, a.of, "==")
  same.b = outer(b.of, b.of, "==")
  AB = matrix(as.numeric(xor(same.a, same.b)), nrow = size, ncol = size)
  diag(AB) = -1
  AB
}
# estimate the intensity from posterior distribution
# Point estimates are posterior modes, read off a 20-break histogram (the
# midpoint of the first most populated bin).
# child.n     : number of joint child states
# par.n       : number of parent states
# poste.diag  : exit-rate draws indexed [draw, parent, state]
#               (draw-first layout presumably from rstan::extract -- confirm)
# poste.theta : jump-probability draws indexed [draw, parent, state, t], where
#               t counts the 1-entries of scratch.joint.inten's row in column order
# scratch.joint.inten : structural template, e.g. from getScratch()
# Returns a list with one matrix per parent state: diagonal = -mode(exit rate),
# reachable off-diagonals = mode(jump probability), all other entries 0.
estInten <- function(child.n, par.n, poste.diag, poste.theta, scratch.joint.inten)
{
  # plot = FALSE: hist() would otherwise draw every posterior it summarises
  postMode <- function(x)
  {
    h = hist(x, breaks = 20, plot = FALSE)
    h$mids[which.max(h$counts)]
  }
  est.inten = list()
  for(i in seq_len(par.n))
  {
    est = scratch.joint.inten
    est[, ] = 0
    for(j in seq_len(child.n))
    {
      est[j, j] = -postMode(poste.diag[, i, j])
      targets = which(scratch.joint.inten[j, seq_len(child.n)] == 1)
      for(ind in seq_along(targets))
        est[j, targets[ind]] = postMode(poste.theta[, i, j, ind])
    }
    est.inten[[i]] = est
  }
  est.inten
}
# estimate CPTs from posterior distribution
# Each cell is the posterior mode of poste[, parent, from, to], read off a
# 20-break histogram (the midpoint of the first most populated bin).
# NOTE(review): cell-wise modes are not renormalised, so rows need not sum to 1.
# Returns a child.n x child.n x par.n array indexed [from, to, parent].
estTran <- function(par.n, child.n, poste)
{
  # plot = FALSE: hist() would otherwise draw every posterior it summarises
  postMode <- function(x)
  {
    h = hist(x, breaks = 20, plot = FALSE)
    h$mids[which.max(h$counts)]
  }
  est.tran = array(0, dim = c(child.n, child.n, par.n))
  for(i in seq_len(par.n))
  {
    for(j in seq_len(child.n))
      for(k in seq_len(child.n))
        est.tran[j, k, i] = postMode(poste[, i, j, k])
  }
  est.tran
}
# generate missing
# Thin a continuous trajectory to observations at Poisson(rate) times.
# traj   : increasing time stamps of the full trajectory
# states : state in force from traj[i] until traj[i + 1]
# rate   : observation rate (mean gap 1 / rate)
# Returns list(traj.mis, states.mis): observation times starting at 0 and
# strictly below max(traj), and the state in force at each of them.
generateMiss <- function(traj, states, rate)
{
  t.end = max(traj)
  curr = 0
  traj.mis = c(curr)
  # curr is always the largest observation time so far
  while(curr < t.end)
  {
    curr = curr + rexp(1, rate = rate)
    traj.mis = c(traj.mis, curr)
  }
  # drop the first time that reached or passed the end of the trajectory
  traj.mis = traj.mis[-length(traj.mis)]
  # Half-open lookup traj[i] <= t < traj[i + 1] -> states[i]. A closed-interval
  # scan would count a time lying exactly on an interior boundary twice.
  states.mis = states[findInterval(traj.mis, traj)]
  return(list(traj.mis = traj.mis, states.mis = states.mis))
}
run <- function(N, rate, testDat,  AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini){
  # One replicate of the incomplete-data experiment:
  #  1. simulate N complete training slices from the true model;
  #  2. observe AE and ASS only at Poisson(rate) times (generateMiss);
  #  3. fit Var8.stan (one chain) and take posterior modes of every parameter;
  #  4. score the fixed test set testDat under estimated vs. true parameters.
  # Returns list(est.ll, true.ll, est.AE.inten, est.ASS.inten, est.P.tran).

  # Pade coefficient tables passed unchanged to the Stan model
  # (presumably for a matrix-exponential approximation -- confirm in Var8.stan)
  padeC = rbind(c(120, 60, 12, 1, 0, 0, 0, 0, 0, 0),
                c(30240, 15120, 3360, 420, 30, 1, 0, 0, 0, 0),
                c(17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1, 0, 0),
                c(17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1))
  padeCbig = c(64764752532480000, 32382376266240000, 7771770303897600,
               1187353796428800, 129060195264000, 10559470521600,
               670442572800, 33522128640, 1323241920, 40840800,
               960960, 16380, 182, 1)

  # 1. complete training data
  learnDat = generateDat(N, AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)
  AE.info = learnDat$AE
  LC.info = learnDat$LC
  P.info = learnDat$P
  LV.info = learnDat$LV
  CO.info = learnDat$CO
  ASS.info = learnDat$ASS

  # state-space sizes come from the simulated data instead of being hard-coded
  AMI.n = AE.info$AMI.n
  E.n = AE.info$E.n
  P.n = P.info$P.n
  LC.n = LC.info$LC.n
  LV.n = LV.info$LV.n
  CO.n = CO.info$CO.n
  AN.n = ASS.info$AN.n
  SS.n = ASS.info$SS.n
  CO = CO.info$CO

  # 2. partially observed continuous components
  mis = generateMiss(AE.info$AE.traj, AE.info$AE, rate)
  AE.mis = mis$states.mis
  AE.traj.mis = mis$traj.mis
  mis = generateMiss(ASS.info$ASS.traj, ASS.info$ASS, rate)
  ASS.mis = mis$states.mis
  ASS.traj.mis = mis$traj.mis

  # 3. Stan fit. The AE template is what Var8.stan receives; ASS gets its own
  # template for decoding, since its joint space need not match AE's.
  scratch.joint.inten = getScratch(AMI.n, E.n)
  scratch.ASS.inten = getScratch(AN.n, SS.n)
  dat = list(padeCbig = padeCbig, padeC = padeC, scratch_joint_inten = scratch.joint.inten,
             AMI_n = AMI.n, E_n = E.n, CO_n = CO.n,
             AE_N = length(AE.mis), CO_states = CO,
             AE_states = AE.mis, AE_traj_mis = AE.traj.mis,
             AN_n = AN.n, SS_n = SS.n, ASS_N = length(ASS.mis),
             ASS_states = ASS.mis, ASS_traj_mis = ASS.traj.mis,
             P_AMI_states = P.info$P.AMI, P_states = P.info$P, P_n = P.n,
             LC_P_states = LC.info$LC.P, LC_states = LC.info$LC, LC_n = LC.n,
             LV_LC_states = LV.info$LV.LC, LV_states = LV.info$LV, LV_n = LV.n,
             CO_LV_states = CO.info$CO.LV,
             D_N = length(CO))

  fit = stan("Var8.stan", data = dat, iter = 1000, chains = 1)
  la = extract(fit)

  est.P.tran = estTran(AMI.n, P.n, la$P_tran)
  est.LC.tran = estTran(P.n, LC.n, la$LC_tran)
  est.LV.tran = estTran(LC.n, LV.n, la$LV_tran)
  est.CO.tran = estTran(LV.n, CO.n, la$CO_tran)
  est.AE.inten = estInten(AE.info$AMI.n * AE.info$E.n, AE.info$CO.n, la$AE_dig, la$AE_theta, scratch.joint.inten)
  est.ASS.inten = estInten(ASS.info$AN.n * ASS.info$SS.n, ASS.info$CO.n, la$ASS_dig, la$ASS_theta, scratch.ASS.inten)

  # 4. fixed test data
  P = testDat$P
  LC = testDat$LC
  LV = testDat$LV
  CO = testDat$CO
  AE = testDat$AE
  ASS = testDat$ASS

  statP = computeStatD(P$P, P$P.AMI, P$P.n, P$AMI.n)
  statLC = computeStatD(LC$LC, LC$LC.P, LC$LC.n, LC$P.n)
  statLV = computeStatD(LV$LV, LV$LV.LC, LV$LV.n, LV$LC.n)
  statCO = computeStatD(CO$CO, CO$CO.LV, CO$CO.n, CO$LV.n)
  AE.stat = computeStatC(AE$AE, AE$AE.CO, AE$AMI.n * AE$E.n,  AE$CO.n, AE$AE.traj)
  ASS.stat = computeStatC(ASS$ASS, ASS$ASS.CO, ASS$AN.n * ASS$SS.n,  ASS$CO.n, ASS$ASS.traj)

  est.AE.LL = computeLL(est.AE.inten, AE.stat$M.num,  AE.stat$M.den)
  est.ASS.LL = computeLL(est.ASS.inten, ASS.stat$M.num,  ASS.stat$M.den)
  AE.LL = computeLL(AE.info$AE.inten, AE.stat$M.num,  AE.stat$M.den)
  ASS.LL = computeLL(ASS.info$ASS.inten, ASS.stat$M.num,  ASS.stat$M.den)

  # NOTE(review): unlike computeDis, the ASS terms (est.ASS.LL, ASS.LL) are
  # computed but not added to either total -- confirm this is intended.
  est.ll =  sum(statP$M.Num*log(est.P.tran))+
    sum(statLC$M.Num*log(est.LC.tran))+
    sum(statLV$M.Num*log(est.LV.tran))+
    sum(statCO$M.Num*log(est.CO.tran)) + est.AE.LL
  ll =  sum(statP$M.Num*log(P.info$P.tran))+
    sum(statLC$M.Num*log(LC.info$LC.tran))+
    sum(statLV$M.Num*log(LV.info$LV.tran))+
    sum(statCO$M.Num*log(CO.info$CO.tran)) + AE.LL

  return(list(est.ll = est.ll, true.ll = ll, est.AE.inten = est.AE.inten, est.ASS.inten = est.ASS.inten, est.P.tran = est.P.tran))
}
# true model
# State-space sizes of the eight variables: AMI/E and AN/SS are the two
# components of the continuous-time joint nodes AE and ASS; P, LC, LV and CO
# are discrete-time nodes. NOTE(review): generateDat reads these sizes (and
# LV.ini below) from the global environment rather than from its arguments.
AMI.n = 2
E.n = 2
AN.n = 2
SS.n = 2
P.n = 2
LC.n = 2
LV.n = 2
CO.n = 2
# intensity matrix for AMI and E, AN and SS: one joint matrix per state of CO,
# in joint.inten's parameterisation (diagonal = -exit rate,
# off-diagonal = jump probabilities)
AE.inten = list()
ASS.inten = list()
for(i in 1:CO.n)
{
  AE.inten[[i]] = joint.inten(AMI.n, E.n)
  ASS.inten[[i]] = joint.inten(AN.n, SS.n)
}
# CPT for P, LC, LV, CO (rows drawn from a flat Dirichlet by getTran)
P.tran = getTran(AMI.n, P.n)
LC.tran = getTran(P.n, LC.n)
LV.tran = getTran(LC.n, LV.n)
CO.tran = getTran(LV.n, CO.n)
# initial distribution of every node, each a flat-Dirichlet draw
AE.ini = as.vector(rdirichlet(1, alpha = rep(1, AMI.n * E.n)))
ASS.ini = as.vector(rdirichlet(1, alpha = rep(1, AN.n * SS.n)))
P.ini = as.vector(rdirichlet(1, alpha = rep(1, P.n)))
LC.ini = as.vector(rdirichlet(1, alpha = rep(1, LC.n)))
LV.ini = as.vector(rdirichlet(1, alpha = rep(1, LV.n)))
CO.ini = as.vector(rdirichlet(1, alpha = rep(1, CO.n)))
# fixed held-out test set (10000 slices) passed to every run() call below
testDat = generateDat(10000, AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)

# Incomplete-data experiment: for each observation rate, fit the model on
# training sets of increasing length and save log-likelihoods and estimates.
N.slice = c(10, 20, 40, 80, 160, 320, 640, 1024, 2048, 4096)
rate.list = c(0.5, 1, 4, 10)
st1 = Sys.time()
# NOTE: only the first rate is run; use seq_along(rate.list) to run them all
for(i in 1:1)
{
  est.ll = c()
  true.ll = c()
  est.AE.inten = list()
  est.ASS.inten = list()
  est.P.tran = list()

  rate = rate.list[i]
  # per-rate timer, started once before the sweep over sequence lengths
  st = Sys.time()
  for(j in seq_along(N.slice))
  {
    cat("Now the length of sequence is ", N.slice[j], ", rate is ", rate, "\n")
    ll =  run(N.slice[j], rate, testDat,  AE.inten, ASS.inten, P.tran, LC.tran, LV.tran, CO.tran, AE.ini, ASS.ini, P.ini, LC.ini, CO.ini)
    est.ll = c(est.ll, ll$est.ll)
    true.ll = c(true.ll, ll$true.ll)
    est.AE.inten[[length(est.AE.inten) + 1]] = ll$est.AE.inten
    est.ASS.inten[[length(est.ASS.inten) + 1]] = ll$est.ASS.inten
    est.P.tran[[length(est.P.tran) + 1]] = ll$est.P.tran
  }
  prefix = paste0("Var8(mis)_rate", rate)
  data = data.frame(N.slice = log10(N.slice), ll = est.ll, true.ll = true.ll, diff = abs(true.ll-est.ll))
  write.csv(data, paste0(prefix, ".LL.diff.csv"), row.names = FALSE, quote = FALSE)
  write.list(est.AE.inten, paste0(prefix, ".est.AE.inten.csv"), quote = FALSE, row.names = FALSE, eol = "\n")
  write.list(est.ASS.inten, paste0(prefix, ".est.ASS.inten.csv"), quote = FALSE, row.names = FALSE,  eol = "\n")
  write.list(est.P.tran, paste0(prefix, ".est.P.tran.csv"), quote = FALSE, row.names = FALSE,  eol = "\n")
  write.list(AE.inten, paste0(prefix, ".true.AE.inten.csv"), quote = FALSE, eol = "\n")
  write.list(ASS.inten, paste0(prefix, ".true.ASS.inten.csv"), quote = FALSE, eol = "\n")

  write.list(list(P.tran), paste0(prefix, ".true.P.tran.csv"), quote = FALSE, eol = "\n")
  cat("It takes", format(Sys.time() - st), "to finish running for ", rate, "\n" )
}
cat("Total it takes", format(Sys.time() - st1), "to finish.\n")
