# Forward sampling for a continuous-time variable D -> C
rm(list = ls())
library(LaplacesDemon)
library(rstan)
library(erer)
# Forward-sample one complete trajectory of the hybrid network D -> C.
#
# D is a binary discrete-time variable that is re-sampled at every integer
# time slice 1..N from the row D.tran[current state, ]. C is a binary
# continuous-time variable whose exit rate depends on the current D and C:
#   D = 1: q1 (leaving C = 1), q2 (leaving C = 2)
#   D = 2: q3 (leaving C = 1), q4 (leaving C = 2)
#
# Arguments:
#   N               number of discrete time slices to simulate
#   q1, q2, q3, q4  exit rates of C as listed above
#   D.tran          2 x 2 transition table of D (matrix or data frame),
#                   one row per current state
#
# Returns a list; an "event" is either a jump of C or the arrival of the
# next time slice:
#   C.D.states  state of D at time 0 and right after every event
#   C.states    state of C at time 0 and right after every event
#   D.states    state of D at time 0 and at every time slice
#   B.traj      time stamp of time 0 and of every event
#   time.traj   time elapsed between consecutive events
generateCompDC <- function(N, q1, q2, q3, q4, D.tran)
{
  # Initial distributions of D and C (no arc between them at time 0).
  D.init <- c(0.2, 0.8)
  C.init <- c(0.4, 0.6)

  # Exit rate of C, indexed as [state of D, state of C].
  leave.rate <- matrix(c(q1, q2, q3, q4), nrow = 2, byrow = TRUE)

  # Draw the configuration at time 0.
  d.now <- rcat(1, D.init)
  c.now <- rcat(1, C.init)
  clock <- 0

  D.path <- c(d.now)
  C.path <- c(c.now)
  parent.path <- c(d.now)
  event.clock <- c(0)
  holding <- c()

  slice <- 1
  while (slice <= N)
  {
    rate <- leave.rate[d.now, c.now]

    # Probability that C jumps before the next time slice is reached.
    p.jump <- 1 - exp(-rate * (slice - clock))
    jumped <- rcat(1, c(p.jump, 1 - p.jump)) == 1

    if (jumped)
    {
      # Draw the waiting time conditioned on the jump happening before the
      # next slice: redraw until the jump lands strictly inside the slice.
      repeat
      {
        wait <- rexp(1, rate = rate)
        if ((wait + clock) < slice)
          break
      }
      holding <- c(holding, wait)

      # C is binary, so a jump always moves to the other state.
      c.now <- setdiff(c(1, 2), c.now)
      C.path <- c(C.path, c.now)
      parent.path <- c(parent.path, d.now)

      clock <- clock + wait
      event.clock <- c(event.clock, clock)
    }
    else
    {
      # No jump: advance to the time slice and re-sample D there.
      holding <- c(holding, slice - clock)
      clock <- slice

      d.now <- rcat(1, as.numeric(as.vector(D.tran[d.now, ])))
      D.path <- c(D.path, d.now)

      # C keeps its state across the slice boundary; record it together
      # with the freshly sampled D.
      C.path <- c(C.path, c.now)
      parent.path <- c(parent.path, d.now)
      event.clock <- c(event.clock, clock)

      slice <- slice + 1
    }
  }

  return(list(C.D.states = parent.path, C.states = C.path, D.states = D.path,
              B.traj = event.clock, time.traj = holding))
}

# Sufficient statistics of one complete D -> C trajectory in the format
# returned by generateCompDC().
#
# Returns a list of 2 x 2 matrices:
#   D.count[a, b]  number of slice-to-slice transitions of D from a to b
#   C.jumps[d, c]  number of times C left state c while D was in state d
#   C.dur[d, c]    total time C spent in state c while D was in state d
.dcSuffStats <- function(traj)
{
  D.states <- traj$D.states
  D.count <- matrix(0, nrow = 2, ncol = 2)
  for (i in seq_len(max(length(D.states) - 1, 0)))
  {
    D.count[D.states[i], D.states[i + 1]] <-
      D.count[D.states[i], D.states[i + 1]] + 1
  }

  C.D.states <- traj$C.D.states
  C.states <- traj$C.states
  time.traj <- traj$time.traj
  C.jumps <- matrix(0, nrow = 2, ncol = 2)
  C.dur <- matrix(0, nrow = 2, ncol = 2)
  for (i in seq_len(max(length(C.states) - 1, 0)))
  {
    d <- C.D.states[i]
    from <- C.states[i]
    # Events where C keeps its state are slice boundaries, not jumps.
    if (from != C.states[i + 1])
    {
      C.jumps[d, from] <- C.jumps[d, from] + 1
    }
    C.dur[d, from] <- C.dur[d, from] + time.traj[i]
  }

  list(D.count = D.count, C.jumps = C.jumps, C.dur = C.dur)
}

# Learn the parameters of D -> C from a fresh complete trajectory of N time
# slices and score both the learned and the true model on a fixed test set.
#
# Arguments:
#   N               number of time slices of the learning trajectory
#   q1, q2, q3, q4  true exit rates of C (see generateCompDC)
#   D.tran          true 2 x 2 transition table of D
#   dataTest        test trajectory in the format returned by generateCompDC()
#
# Returns a list:
#   est.ll      log likelihood of the test set under the estimated model
#   true.ll     log likelihood of the test set under the true model
#   est.q       estimated exit rates c(q1, q2, q3, q4)
#   est.D.tran  estimated transition table of D (columns "d1", "d2")
#
# The test-set statistics of D are taken over the whole test trajectory.
# They were previously taken over its first N slices only, so D was scored
# on a different stretch of data than C.
computeMAP <- function(N, q1, q2, q3, q4, D.tran, dataTest)
{
  # Statistics of the data used for learning.
  learn <- .dcSuffStats(generateCompDC(N, q1, q2, q3, q4, D.tran))

  # Exit rates of C with pseudo-counts 1 (jumps) and 2 (duration); this is
  # the posterior mode under a Gamma(shape = 2, rate = 2) prior.
  est.inten <- (1 + learn$C.jumps) / (learn$C.dur + 2)
  estq <- as.vector(t(est.inten))

  # Transition table of D with add-one smoothing. The row totals are
  # recycled down the columns, so row i is divided by the total of row i.
  est.D.tran <- (1 + learn$D.count) / (rowSums(learn$D.count) + 2)
  colnames(est.D.tran) <- c("d1", "d2")

  # Statistics of the fixed test data.
  test <- .dcSuffStats(dataTest)
  true.inten <- matrix(c(q1, q2, q3, q4), ncol = 2, byrow = TRUE)

  # Log likelihood of the C part under the estimated and the true rates.
  log_con <- sum(test$C.jumps * log(est.inten) - est.inten * test$C.dur)
  log_con_org <- sum(test$C.jumps * log(true.inten) - true.inten * test$C.dur)

  estll <- sum(test$D.count * log(est.D.tran)) + log_con
  truell <- sum(test$D.count * log(as.matrix(D.tran))) + log_con_org

  list(est.ll = estll, true.ll = truell, est.q = estq, est.D.tran = est.D.tran)
}
# ---- True model ------------------------------------------------------------
# Exit rates of C are drawn from Gamma(shape = 2, rate = 2).
q1 <- rgamma(1, shape = 2, rate = 2)
q2 <- rgamma(1, shape = 2, rate = 2)
q3 <- rgamma(1, shape = 2, rate = 2)
q4 <- rgamma(1, shape = 2, rate = 2)
trueq <- c(q1, q2, q3, q4)

# Each of the two rows of the transition table of D is drawn from
# Dirichlet(1, 1).
# A fixed alternative: matrix(c(0.9,0.1,0.2,0.8), ncol = 2, byrow = TRUE)
D.tran <- rdirichlet(2, c(1, 1))

# Fixed test trajectory shared by all experiments below.
dataTest <- generateCompDC(10000, q1, q2, q3, q4, D.tran)
N.slice <- c(10, 20, 40, 80, 160, 320, 640, 1024, 2048, 4096)

# ---- Complete data, structure D -> C --------------------------------------
# One learning run per trajectory length; results are stored by position.
est.ll <- numeric(length(N.slice))
true.ll <- numeric(length(N.slice))
est.q <- vector("list", length(N.slice))
est.D.tran <- vector("list", length(N.slice))
for (k in seq_along(N.slice))
{
  N <- N.slice[k]
  ll <- computeMAP(N, q1, q2, q3, q4, D.tran, dataTest)
  est.ll[k] <- ll$est.ll
  true.ll[k] <- ll$true.ll
  est.q[[k]] <- ll$est.q
  est.D.tran[[k]] <- ll$est.D.tran
}

# Log likelihoods per (log10) trajectory length and their absolute gap.
dataDCcompLL <- data.frame(N.slice = log10(N.slice),
                           est.ll = est.ll,
                           true.ll = true.ll,
                           diff = abs(true.ll - est.ll))
dataDCcompLL

# Estimated rates: one row per trajectory length, one column per rate.
data.est.q <- matrix(unlist(est.q), ncol = 4, byrow = TRUE)
colnames(data.est.q) <- paste0("q", 1:4)
rownames(data.est.q) <- N.slice

# True rates as a single labelled row.
trueq <- matrix(trueq, nrow = 1)
colnames(trueq) <- paste0("q", 1:4)

write.csv(dataDCcompLL, "~/Papers/LearnHTBN/DC(complete)LL.diff.csv", row.names = FALSE, quote = FALSE)
write.list(est.D.tran, "~/Papers/LearnHTBN/DC(complete).est.CPT.csv", quote = FALSE, row.names = FALSE, t.name = N.slice, eol = "\n")
write.csv(data.est.q, "~/Papers/LearnHTBN/DC(complete).est.q.csv", quote = FALSE)
write.csv(trueq, "~/Papers/LearnHTBN/DC(complete).true.q.csv", quote = FALSE)
write.csv(D.tran, "~/Papers/LearnHTBN/DC(complete).true.CPT.csv", quote = FALSE)



# Mode of a sample, approximated by the midpoint of the fullest histogram
# bin (the first one on ties). Nothing is plotted.
.histMode <- function(draws, breaks = 20)
{
  h <- hist(draws, breaks = breaks, plot = FALSE)
  h$mids[which.max(h$counts)]
}

# Structure D -> C where C is only observed at random time points.
#
# For every trajectory length in N.slice: simulate a complete trajectory,
# keep D at the time slices and C only at the observation times, fit the
# Stan model "D->C(C missing).stan", take histogram modes of the posterior
# draws as point estimates, and score them on the fixed test trajectory.
#
# Arguments:
#   rate            rate of the exponential gaps between observation times
#   q1, q2, q3, q4  true exit rates of C (see generateCompDC)
#   D.tran          true 2 x 2 transition table of D
#   dataTest        complete test trajectory as returned by generateCompDC()
#   N.slice         trajectory lengths to run
#
# Returns a list with one entry per element of N.slice in each component:
#   est.ll, true.ll           test log likelihood of estimated / true model
#   true.D.tran, est.D.tran   true / estimated transition tables of D
#   true.inten, est.inten     true / estimated 2 x 2 rate matrices,
#                             laid out row-wise as (q1, q2; q3, q4)
run2 <- function(rate, q1, q2, q3, q4, D.tran, dataTest,
                 N.slice = c(10, 20, 40, 80, 160, 320, 640, 1024, 2048, 4096))
{
  # Constant tables handed to the Stan model as data.
  # NOTE(review): these look like Pade-approximant coefficients for a
  # matrix exponential -- confirm against the .stan file.
  padeC <- rbind(c(120, 60, 12, 1, 0, 0, 0, 0, 0, 0),
                 c(30240, 15120, 3360, 420, 30, 1, 0, 0, 0, 0),
                 c(17297280, 8648640, 1995840, 277200, 25200, 1512, 56, 1, 0, 0),
                 c(17643225600, 8821612800, 2075673600, 302702400, 30270240, 2162160, 110880, 3960, 90, 1))
  padeCbig <- c(64764752532480000, 32382376266240000, 7771770303897600,
                1187353796428800, 129060195264000, 10559470521600,
                670442572800, 33522128640, 1323241920, 40840800,
                960960, 16380, 182, 1)

  org_inten <- matrix(c(q1, q2, q3, q4), ncol = 2, byrow = TRUE)

  # Sufficient statistics of the fixed test set. They do not depend on the
  # learning run, so they are computed once instead of once per length.
  test.D <- dataTest$D.states
  test.C <- dataTest$C.states
  test.CD <- dataTest$C.D.states
  interval <- diff(dataTest$B.traj)

  # D.no[a, b]: number of transitions of D from a to b.
  D.no <- matrix(0, nrow = 2, ncol = 2)
  for (i in seq_len(max(length(test.D) - 1, 0)))
  {
    D.no[test.D[i], test.D[i + 1]] <- D.no[test.D[i], test.D[i + 1]] + 1
  }

  # C.dur[d, c]: time spent in c under d; C.tran[d, c]: jumps out of c under d.
  C.dur <- matrix(0, nrow = 2, ncol = 2)
  C.tran <- matrix(0, nrow = 2, ncol = 2)
  for (i in seq_len(max(length(test.C) - 1, 0)))
  {
    d <- test.CD[i]
    from <- test.C[i]
    C.dur[d, from] <- C.dur[d, from] + interval[i]
    if (from != test.C[i + 1])
    {
      C.tran[d, from] <- C.tran[d, from] + 1
    }
  }

  # The true model is the same in every run, so is its test log likelihood.
  ll.true <- sum(D.no * log(D.tran)) +
    sum(C.tran * log(org_inten) - org_inten * C.dur)

  est.ll <- c()
  true.ll <- c()
  true.D.tran <- list()
  est.D.tran <- list()
  true.inten <- list()
  est.inten <- list()

  for (k in seq_along(N.slice))
  {
    cat("Now length of sequence is ", N.slice[k], "\n")
    N <- N.slice[k]
    comp <- generateCompDC(N, q1, q2, q3, q4, D.tran)
    D.states <- comp$D.states
    C.states <- comp$C.states
    B.traj <- comp$B.traj

    # Observation times of C: start at 0 and add exponential gaps until the
    # end of the trajectory is passed.
    horizon <- max(B.traj)
    sys_time <- 0
    rtp_sys <- c()
    while (sys_time <= horizon)
    {
      rtp_sys <- c(rtp_sys, sys_time)
      sys_time <- sys_time + rexp(1, rate = rate)
    }

    # State of C at each observation time. C.states[i] holds on
    # [B.traj[i], B.traj[i + 1]); findInterval() assigns every time point to
    # exactly one interval, so rstates always has one entry per time point
    # (closed-closed interval tests would count boundary points twice).
    rstates <- C.states[findInterval(rtp_sys, B.traj, rightmost.closed = TRUE)]

    data <- list(padeC = padeC, padeCbig = padeCbig, D_states = D.states,
                 D_N = length(D.states), C_states = rstates,
                 N = length(rstates), rtp_sys = rtp_sys)
    fit <- stan("D->C(C missing).stan", data = data, iter = 1000, chains = 1)
    la <- extract(fit)

    # Point estimates: histogram mode of the posterior draws.
    CPT <- matrix(0, nrow = 2, ncol = 2)
    for (i in 1:2)
    {
      for (j in 1:2)
      {
        CPT[i, j] <- .histMode(la$D_CPT[, i, j])
      }
    }
    inten <- numeric(4)
    for (i in 1:4)
    {
      inten[i] <- .histMode(la$inten[, i])
    }
    # presumably la$inten is ordered (q1, q2, q3, q4) -- verify against the
    # .stan file.
    inten <- matrix(inten, ncol = 2, byrow = TRUE)

    # Test log likelihood under the true and under the estimated model.
    true.ll <- c(true.ll, ll.true)
    est.ll <- c(est.ll, sum(D.no * log(CPT)) +
                  sum(C.tran * log(inten) - inten * C.dur))

    true.D.tran[[length(true.D.tran) + 1]] <- D.tran
    est.D.tran[[length(est.D.tran) + 1]] <- CPT
    true.inten[[length(true.inten) + 1]] <- org_inten
    est.inten[[length(est.inten) + 1]] <- inten
  }

  return(list(est.ll = est.ll, true.ll = true.ll,
              true.D.tran = true.D.tran, est.D.tran = est.D.tran,
              true.inten = true.inten, est.inten = est.inten))
}



# ---- C partially observed, observation rate 0.5 ---------------------------
st = Sys.time()
DC_rate0.5 = run2(0.5, q1, q2, q3, q4, D.tran, dataTest)
est.ll = DC_rate0.5$est.ll
true.ll = DC_rate0.5$true.ll
# All results of this run live in DC_rate0.5; no object named DC_rate10
# exists at this point of the script.
est.q = DC_rate0.5$est.inten
true.q = DC_rate0.5$true.inten
est.D.tran = DC_rate0.5$est.D.tran
true.D.tran = DC_rate0.5$true.D.tran

LL.mis.rate0.5 = data.frame(N.slice = log10(N.slice), est.ll = est.ll, true.ll = true.ll,  diff = abs(true.ll - est.ll))

# Estimated rates of this run, one row per trajectory length. Every element
# of est.q is a 2 x 2 matrix laid out row-wise as (q1, q2; q3, q4), so it is
# flattened row by row. The table of the complete-data experiment
# (data.est.q) must not be written into the missing-data file.
data.est.q.mis = t(vapply(est.q, function(m) as.vector(t(m)), numeric(4)))
colnames(data.est.q.mis) = paste("q",c(1:4), sep = "")
rownames(data.est.q.mis) = N.slice

trueq = matrix(trueq, nrow = 1)
colnames(trueq) = paste("q",c(1:4), sep = "")
write.csv(LL.mis.rate0.5, "~/Papers/LearnHTBN/DC(mis)_rate0.5.LL.diff.csv", row.names = FALSE, quote = FALSE)
write.list(est.D.tran, "~/Papers/LearnHTBN/DC(mis)_rate0.5.est.CPT.csv", quote = FALSE, row.names = FALSE, t.name = N.slice, eol = "\n")
write.csv(data.est.q.mis, "~/Papers/LearnHTBN/DC(mis)_rate0.5.est.q.csv", quote = FALSE)
write.csv(trueq, "~/Papers/LearnHTBN/DC(mis)_rate0.5.true.q.csv", quote = FALSE)
write.csv(D.tran, "~/Papers/LearnHTBN/DC(mis)_rate0.5.true.CPT.csv", quote = FALSE)

# Run one self-contained missing-data experiment for the given observation
# rate: draw a fresh true model, simulate a test trajectory of 10000 slices,
# run run2() and write the results to "~/Papers/LearnHTBN/DC(mis)_rate<rate>.*.2.csv".
#
# Arguments:
#   rate  rate of the exponential gaps between observation times of C
#
# Returns the value of the final cat() call; the function is called for its
# side effects (files written, elapsed time printed).
compute <- function(rate)
{
  # Fresh true model: four exit rates from Gamma(shape = 2, rate = 2), drawn
  # one at a time, and two Dirichlet(1, 1) rows for D.
  rates <- vapply(1:4, function(i) rgamma(1, shape = 2, rate = 2), numeric(1))
  D.tran <- rdirichlet(2, c(1, 1))
  dataTest <- generateCompDC(10000, rates[1], rates[2], rates[3], rates[4], D.tran)
  N.slice <- c(10, 20, 40, 80, 160, 320, 640, 1024, 2048, 4096)

  # Only the learning runs are timed, not the test-data generation above.
  started <- Sys.time()
  DC.data <- run2(rate, rates[1], rates[2], rates[3], rates[4], D.tran, dataTest)

  LL.mis <- data.frame(N.slice = log10(N.slice),
                       est.ll = DC.data$est.ll,
                       true.ll = DC.data$true.ll,
                       diff = abs(DC.data$true.ll - DC.data$est.ll))

  trueq <- matrix(rates, nrow = 1)
  colnames(trueq) <- paste0("q", 1:4)

  # Every output file shares the same prefix and differs only in its suffix.
  out.file <- function(suffix)
  {
    paste0("~/Papers/LearnHTBN/DC(mis)_rate", rate, suffix)
  }
  write.csv(LL.mis, out.file(".LL.diff.2.csv"), row.names = FALSE, quote = FALSE)
  write.list(DC.data$est.D.tran, out.file(".est.CPT.2.csv"), quote = FALSE, row.names = FALSE, t.name = N.slice, eol = "\n")
  write.list(DC.data$est.inten, out.file(".est.q.2.csv"), quote = FALSE, row.names = FALSE, t.name = N.slice, eol = "\n")
  write.csv(trueq, out.file(".true.q.2.csv"), quote = FALSE)
  write.csv(D.tran, out.file(".true.CPT.2.csv"), quote = FALSE)

  cat("It takes", format(Sys.time() - started), "to finish.")
}

# Missing-data run with observation rate 10 on the shared true model. The
# rate is written out explicitly because no global `rate` is defined at
# this point of the script.
DC.mis.rate10 = run2(10, q1, q2, q3, q4, D.tran, dataTest)

# Independent experiments, one per observation rate. compute() takes the
# rate only: it draws its own true model and test data.
compute(0.5)
compute(4)
compute(1)
compute(10)

est10.ll = DC.mis.rate10$est.ll
true10.ll = DC.mis.rate10$true.ll

# D.tran = DC_rate0.5Par$true.D.tran
# est.D.tran = DC_rate0.5Par$est.D.tran
# est.inten = DC_rate0.5Par$est.D.tran
# true.inten = DC_rate0.5Par$true.inten
# est.q1 = unlist(lapply(c(1:10),function(x)est.inten[[x]][1,1]))
# est.q2 = unlist(lapply(c(1:10),function(x)est.inten[[x]][1,2]))
# est.q3 = unlist(lapply(c(1:10),function(x)est.inten[[x]][2,1]))
# est.q4 = unlist(lapply(c(1:10),function(x)est.inten[[x]][2,2]))
# pard1 = unlist(lapply(c(1:10),function(x)est.D.tran[[x]][1,1]))
# true.d1 = D.tran[[1]][1,1]
# pard2 = unlist(lapply(c(1:10),function(x)est.D.tran[[x]][1,2]))
# true.d2 = D.tran[[1]][1,2]
# true.inten = true.inten[[1]]
# data.par = data.frame(y = rep(1, length(pard1)),pard1 = pard1,  pard2 = pard2, 
#                       est.q1 = est.q1, est.q2 = est.q2, est.q3 = est.q3, est.q4 = est.q4)
# data.par.true = data.frame(true.d1 = true.d1, true.d2 = true.d2,
#                            true.q1 = true.inten[1,1], true.q2 = true.inten[1,2], 
#                            true.q3 = true.inten[2,1], true.q4 = true.inten[2,2] )
# 
# write.csv(data.par, "~/Papers/LearnHTBN/dcPar.csv", row.names = FALSE, quote = FALSE)
# write.csv(data.par.true, "~/Papers/LearnHTBN/dcpar.true.csv", row.names = FALSE, quote = FALSE)




# llDC_rate0.5 = run2(0.5,q1,q2,q3,q4)
# llDC_rate1 = run2(1,q1,q2,q3,q4)
# llDC_rate4 = run2(4, q1,q2,q3,q4)


# llCD_rate0.5 = run(0.5)
# llCD_rate1 = run(1)
# llCD_rate4 = run(4)
# dataCD_rate0.5 = data.frame(true.ll = llCD_rate0.5$true.ll, est.ll = llCD_rate0.5$est.ll, diff = llCD_rate0.5$true.ll - llCD_rate0.5$est.ll )
# dataCD_rate1 = data.frame(true.ll = llCD_rate1$true.ll, est.ll = llCD_rate1$est.ll, diff = llCD_rate1$true.ll - llCD_rate1$est.ll)
# dataCD_rate4 = data.frame(true.ll = llCD_rate4$true.ll, est.ll = llCD_rate4$est.ll, diff = llCD_rate4$true.ll - llCD_rate4$est.ll)
# 
# 
# dataDC_rate4 = data.frame(true.ll = llDC_rate4$true.ll, est.ll = llDC_rate4$est.ll, diff = llDC_rate4$true.ll - llDC_rate4$est.ll)
# dataDC_rate1 = data.frame(true.ll = llDC_rate1$true.ll, est.ll = llDC_rate1$est.ll, diff = llDC_rate1$true.ll - llDC_rate1$est.ll)
# dataDC_rate0.5 = data.frame(true.ll = llDC_rate0.5$true.ll, est.ll = llDC_rate0.5$est.ll, diff = llDC_rate0.5$true.ll - llDC_rate0.5$est.ll)
# dataDC = data.frame(N.slice = log10(N.slice),rate0.5.true = llDC_rate0.5$true.ll, rate0.5.est = llDC_rate0.5$est.ll, rate0.5.diff = abs(llDC_rate0.5$true.ll - llDC_rate0.5$est.ll),
#                   rate1.true = llDC_rate1$true.ll, rate1.est = llDC_rate1$est.ll, rate1.diff = abs(llDC_rate1$true.ll - llDC_rate1$est.ll),
#                   rate4.true = llDC_rate4$true.ll, rate4.est = llDC_rate4$est.ll, rate4.diff = abs(llDC_rate4$true.ll - llDC_rate4$est.ll),
#                   complete.est.ll = est.ll, complete.true.ll = true.ll,  complete.diff = abs(true.ll - est.ll))
# 
# 
# 
# write.csv(dataDC, "~/Papers/LearnHTBN/DC.csv", row.names = FALSE, quote = FALSE)
