# Parameter learning for an illustrative example network: Pump <- AMI -> Enz
rm(list = ls())
library(rstan)
source("HF_ForSamp.R")
# Simulate one trajectory with forwardHF() and package it as the data list
# consumed by the rest of this script.
#
# Args:
#   N, q1..q6, Pump.tran: forwarded unchanged to forwardHF().
#
# Returns a list with:
#   D_N          - length of the Pump state sequence.
#   C_N          - length of the AMI state sequence.
#   Joint_states - joint (AMI, Enz) state per entry, encoded as
#                  1 = (1, 1), 2 = (1, 2), 3 = (2, 1), 4 = anything else.
#   Pump_states, Pump_AMI_val, B_traj - taken as-is from forwardHF().
generateComp <- function(N, q1, q2, q3, q4, q5, q6, Pump.tran)
{
  learnDat <- forwardHF(N, q1, q2, q3, q4, q5, q6, Pump.tran)

  # continuous-time variables
  AMI_states <- as.integer(learnDat$AMI.states)
  # Only the entries paired with an AMI state are used.
  Enz_states <- learnDat$Enz.states[seq_along(AMI_states)]
  # discrete-time variables
  Pump_states <- learnDat$Pump.states

  # A missing or unpaired state cannot be mapped to a joint state.
  if (length(Enz_states) != length(AMI_states) ||
      anyNA(AMI_states) || anyNA(Enz_states)) {
    stop("AMI and Enz state trajectories must be complete and of equal length",
         call. = FALSE)
  }

  # Vectorised encoding; the three conditions are mutually exclusive and every
  # remaining combination keeps the default code 4.
  Joint_states <- rep(4, length(AMI_states))
  Joint_states[AMI_states == 1 & Enz_states == 1] <- 1
  Joint_states[AMI_states == 1 & Enz_states == 2] <- 2
  Joint_states[AMI_states == 2 & Enz_states == 1] <- 3

  list(D_N = length(Pump_states),
       C_N = length(AMI_states),
       Joint_states = Joint_states,
       Pump_states = Pump_states,
       Pump_AMI_val = learnDat$Pump.AMI.val,
       B_traj = learnDat$B.traj)
}


# Fit the model to a freshly simulated training trajectory and score it on
# held-out data.
#
# A training trajectory is simulated with generateComp(N, ...), the Stan model
# in 'HF.stan' is fitted to it (one chain, 1000 iterations), and each learned
# parameter is summarised by the mode of its posterior draws. The complete-data
# log-likelihood of the evaluation trajectory comp2 is then computed twice:
# under the true parameters and under the learned point estimates.
#
# Args:
#   N:         forwarded to generateComp() for the training trajectory.
#   comp2:     evaluation data; a list providing Joint_states, Pump_states,
#              Pump_AMI_val and B_traj (the shape generateComp() returns).
#   q1..q6:    true intensities; q1/q2 and q3/q4 drive Enz under AMI = 1 and
#              AMI = 2 respectively, q5/q6 drive AMI (see the matrix below).
#   Pump.tran: true Pump transition array, read here as
#              [AMI value, next Pump state, current Pump state].
#
# Returns: list(true.ll = <log-likelihood under the true parameters>,
#               est.ll  = <log-likelihood under the learned parameters>).
getDis <- function(N, comp2, q1, q2, q3, q4, q5, q6, Pump.tran)
{
  # Point estimate from posterior draws: midpoint of the fullest histogram bin
  # (the first one on ties). plot = FALSE because only the counts are needed.
  histMode <- function(draws) {
    h <- hist(draws, breaks = 20, plot = FALSE)
    h$mids[which.max(h$counts)]
  }

  # Complete-data log-likelihood of a jump process with intensity matrix Q,
  # given transition counts n_trans[from, to] and total dwell time per state.
  ctmcLogLik <- function(Q, n_trans, dwell) {
    exit_rate <- -diag(Q)
    n_exit <- rowSums(n_trans)          # number of times each state was left
    jump_prob <- Q / exit_rate          # row i is divided by exit_rate[i]
    # Structurally impossible jumps (zero rate) contribute nothing.
    use <- row(Q) != col(Q) & jump_prob != 0
    sum(n_exit * log(exit_rate) - exit_rate * dwell) +
      sum(n_trans[use] * log(jump_prob[use]))
  }

  # ---- fit the model on a training trajectory -------------------------------
  comp <- generateComp(N, q1, q2, q3, q4, q5, q6, Pump.tran)
  stan_data <- list(K = 2, alpha = c(1, 1), Joint_alpha = c(1, 1, 1, 1),
                    D_N = comp$D_N, C_N = comp$C_N,
                    Joint_states = comp$Joint_states,
                    Pump_states = comp$Pump_states,
                    Pump_AMI_val = comp$Pump_AMI_val,
                    B_traj = comp$B_traj)
  fit <- stan(file = 'HF.stan', data = stan_data, iter = 1000, chains = 1)
  la <- extract(fit, permuted = TRUE)

  # est[i, j]: estimate of the j-th outgoing intensity of joint state i.
  est <- matrix(0, nrow = 4, ncol = 2)
  for (i in 1:4) {
    for (j in 1:2) {
      est[i, j] <- histMode(la$Joint_int[, i, j])
    }
  }
  # Estimated P(next Pump = 1) per parent configuration, in the same order as
  # the count vectors below: (AMI, Pump) = (1, 1), (2, 1), (1, 2), (2, 2).
  cpt <- c(histMode(la$Pump_CPT[, 1, 1, 1]), histMode(la$Pump_CPT[, 2, 1, 1]),
           histMode(la$Pump_CPT[, 1, 2, 1]), histMode(la$Pump_CPT[, 2, 2, 1]))

  # ---- evaluation data (held out, supplied by the caller) -------------------
  Joint_states <- comp2$Joint_states
  Pump_states <- comp2$Pump_states
  Pump_AMI_val <- comp2$Pump_AMI_val
  B_traj <- comp2$B_traj

  # ---- discrete-time part: Pump transitions, ignoring the first time slice --
  n_step <- length(Pump_AMI_val)
  pump_prev <- Pump_states[seq_len(n_step)]
  pump_next <- Pump_states[seq_len(n_step) + 1]
  if (anyNA(pump_prev) || anyNA(pump_next) || anyNA(Pump_AMI_val)) {
    stop("Pump_states must be complete and hold one more entry than Pump_AMI_val",
         call. = FALSE)
  }
  # Parent configuration index: 1 = (Pump 1, AMI 1), 2 = (Pump 1, AMI 2),
  # 3 = (Pump 2, AMI 1), 4 = everything else.
  cfg <- rep(4L, n_step)
  cfg[pump_prev == 1 & Pump_AMI_val == 1] <- 1L
  cfg[pump_prev == 1 & Pump_AMI_val == 2] <- 2L
  cfg[pump_prev == 2 & Pump_AMI_val == 1] <- 3L
  to_one <- pump_next == 1
  M_P1 <- tabulate(cfg[to_one], nbins = 4)    # sufficient statistics, next = 1
  M_P2 <- tabulate(cfg[!to_one], nbins = 4)   # sufficient statistics, next != 1

  # as.vector(Pump.tran[, k, ]) lists [AMI, k, Pump] in the cfg order 1..4.
  dis_log <- sum(M_P1 * log(as.vector(Pump.tran[, 1, ]))) +
    sum(M_P2 * log(as.vector(Pump.tran[, 2, ])))
  dis_log_learned <- sum(M_P1 * log(cpt)) + sum(M_P2 * log(1 - cpt))

  # ---- continuous-time part: joint (AMI, Enz) process -----------------------
  M <- matrix(0, nrow = 4, ncol = 4)   # M[from, to]: observed transition counts
  TotalD <- numeric(4)                 # total time spent in each joint state
  for (i in seq_len(max(length(Joint_states) - 1, 0))) {
    from <- Joint_states[i]
    to <- Joint_states[i + 1]
    M[from, to] <- M[from, to] + 1
    TotalD[from] <- TotalD[from] + B_traj[i]
  }

  # True intensity matrix of the joint process.
  Joint_int <- matrix(c(-(q1 + q5), q1, q5, 0,
                        q2, -(q2 + q5), 0, q5,
                        q6, 0, -(q3 + q6), q3,
                        0, q6, q4, -(q4 + q6)), ncol = 4, byrow = TRUE)
  # Learned intensity matrix: same sparsity pattern, diagonal closes each row.
  Joint_int_learned <- matrix(c(0, est[1, 1], est[1, 2], 0,
                                est[2, 1], 0, 0, est[2, 2],
                                est[3, 1], 0, 0, est[3, 2],
                                0, est[4, 1], est[4, 2], 0),
                              ncol = 4, byrow = TRUE)
  diag(Joint_int_learned) <- -rowSums(Joint_int_learned)

  list(true.ll = ctmcLogLik(Joint_int, M, TotalD) + dis_log,
       est.ll = ctmcLogLik(Joint_int_learned, M, TotalD) + dis_log_learned)
}

# True transition intensities, each drawn independently from
# Gamma(shape = 2, rate = 2).
q1 <- rgamma(1, shape = 2, rate = 2)
q2 <- rgamma(1, shape = 2, rate = 2)
q3 <- rgamma(1, shape = 2, rate = 2)
q4 <- rgamma(1, shape = 2, rate = 2)
q5 <- rgamma(1, shape = 2, rate = 2)
q6 <- rgamma(1, shape = 2, rate = 2)

# True transition probabilities for Pump (2 x 2 x 2 array, filled column-major).
Pump.tran <- array(
  data = c(0.9, 0.2, 0.1, 0.8, 0.3, 0.4, 0.7, 0.6),
  dim = c(2, 2, 2)
)
# intensity matrix for AMI
AMI.inten <- as.data.frame(
  matrix(c(-q5, q5, q6, -q6), ncol = 2, byrow = TRUE)
)
# intensity matrices for Enz (one 2 x 2 slice per value of the third index)
Enz.inten <- array(c(-q1, q2, q1, -q2, -q3, q4, q3, -q4), dim = c(2, 2, 2))

# Training sizes to sweep over, and one shared evaluation trajectory.
N.slices <- c(10, 20, 40, 80, 160, 320, 640, 1024, 2048, 4096)
comp2 <- generateComp(10000, q1, q2, q3, q4, q5, q6, Pump.tran)

# One getDis() run per training size, in order; results are stored by position.
est.ll <- numeric(length(N.slices))
true.ll <- numeric(length(N.slices))
for (k in seq_along(N.slices)) {
  run <- getDis(N.slices[k], comp2, q1, q2, q3, q4, q5, q6, Pump.tran)
  est.ll[k] <- run$est.ll
  true.ll[k] <- run$true.ll
}

# Tabulate the log-likelihoods against log10 of the training size and save.
dataHF <- data.frame(
  N.slice = log10(N.slices),
  true.ll = true.ll,
  est.ll = est.ll,
  diff = abs(est.ll - true.ll)
)
write.csv(dataHF, file = "~/Papers/LearnHTBN/HF(complete)LL.diff.csv",
          row.names = FALSE, quote = FALSE)
