{-# OPTIONS_GHC -Wall -XFlexibleInstances #-}
module ToSMT (showSMT)  where
import ToINV (getINV)
import Terms (Term(..))
import qualified Data.Set as Set
import qualified Data.Map.Lazy as Map
import qualified Data.Array as Array
import EParser (QueueAssignment(..),FlopAssignment(..),InputAssignment(..),OutputAssignment(..)
               ,OutputWire(..))
-- import RewriteStructure (rewrite,rewrite')
import Debug.Trace (trace)
import Prelude hiding (init)
import Data.Maybe (catMaybes)
import Formula

-- | A complete SMT problem: the variable declarations, the single formula
-- that gets asserted, and comment lines appended after the commands.
data SMTProblem = SMTProblem Declarations (Formula TimedLPEq) [String]
-- | For every timed variable: the name it is printed as, and its SMT sort.
-- A sort of Nothing marks a constant (such as "1") that is not declared.
type Declarations = Map.Map VariableInTime (String,Maybe String)
-- | A linear relation between two sides: EQUATION is "=", LEQ is "<=".
data LPEq a = EQUATION a a
            | LEQ a a
-- | A linear expression: every term is mapped to its integer coefficient.
type LPMap a = Map.Map a Int
-- | A linear expression as an association list of (term, coefficient).
type LPList a =[(a,Int)]
-- | A conjunction of primitive signals, identified by name.
-- The empty set is the empty conjunction, i.e. the constant True.
type UntimedLPTerm = Set.Set String
-- | A relation over untimed conjunctions.
type LPEqWithSets = LPEq (LPMap UntimedLPTerm)
-- | A relation over timed variables, ready to be printed.
type TimedLPEq    = LPEq (LPList VariableInTime)

-- | A conjunction of signals, tagged with the point in time it is observed at.
-- 'trmToDecl' declares Initially and Currently variables with sort Int, and
-- Finally variables with sort Real. NOTE(review): presumably Initially is a
-- count over the run so far and Finally a long-run average -- confirm.
-- CurrentQueuePackets is a queue's packet count (declared in 'obtainQueues').
data VariableInTime = Initially UntimedLPTerm
                    | Currently UntimedLPTerm
                    | Finally   UntimedLPTerm
                    | CurrentQueuePackets String
                    deriving (Eq,Ord)

-- | Scale every coefficient of a linear expression by a constant.
-- Entries whose scaled coefficient is 0 (in particular all of them when the
-- factor is 0) are dropped. This keeps the invariant that '.+.' maintains:
-- a linear expression never stores a zero coefficient.
(.*.) :: Int -> Map.Map a Int -> Map.Map a Int
(.*.) x y = Map.mapMaybe (\c -> case x * c of {0 -> Nothing; s -> Just s}) y
-- | Difference and sum of two linear expressions.
-- Terms whose coefficients cancel out are removed from the result.
(.-.),(.+.) :: Ord a => Map.Map a Int -> Map.Map a Int -> Map.Map a Int
(.-.) x y = (.+.) x ((-1) .*. y)
(.+.) = Map.mergeWithKey (\_ a b->case a+b of {0->Nothing;s->Just s}) id id
-- | The linear expressions for the constants: True is the empty conjunction
-- with coefficient 1, False is the expression without any terms.
one,zero :: Map.Map (Set.Set x) Int
zero = Map.empty
one  = Map.singleton Set.empty 1
-- | A linear expression consisting of a single term with coefficient 1.
lpItm :: Num a => k -> Map.Map k a
lpItm term = Map.fromList [(term, 1)]



-- | A conjunct of a 'Term' that 'translateTerm' cannot absorb into a plain
-- conjunction of primitive signals. Each one is translated separately by
-- 'translateTranslatableTerm'. The fields are the operands of the
-- corresponding Term constructor (T_OR, T_XOR, T_ITE, T_NOT, T_VALUE, T_RESET).
data TranslatableTerm
 = TR_OR    Term Term
 | TR_XOR   Term Term
 | TR_ITE   Term Term Term
 | TR_NOT   Term
 | TR_VALUE Bool
 | TR_RESET deriving Show

-- | Translate a list of terms into linear expressions.
-- The returned set is the given set of terms extended with every
-- conjunction that occurs in any of the translations.
translateTerms :: Set.Set UntimedLPTerm -- previous set of terms to be declared
                  -> [Term] -- terms to be translated
                  -> (Set.Set UntimedLPTerm, -- a new (accumulated) set of terms to be declared
                      [LPMap UntimedLPTerm] -- translations in the same order
                      )
translateTerms known terms
 = (Set.unions (known : map Map.keysSet translations), translations)
 where translations = map translateTerm terms

-- | Conjoin two conjunctions: the result contains the primitives of both.
extendLPTerm :: UntimedLPTerm -> UntimedLPTerm -> UntimedLPTerm
extendLPTerm lhs rhs = lhs `Set.union` rhs

-- | Compute the linear expression for (x && y). Here x is a conjunction of
-- primitives, given as a set of names, and y is a linear expression.
-- Every term of y is extended with the primitives of x. Because conjunction
-- is idempotent, distinct terms of y can collapse onto the same key, and
-- their coefficients are summed. For example, x && (1 - x) becomes x - x.
-- Keys whose summed coefficient is 0 are dropped (as '.+.' does), so no
-- spurious zero-coefficient terms get declared and related later on.
extendLPMap :: Set.Set String -> LPMap UntimedLPTerm -> LPMap UntimedLPTerm
extendLPMap s = Map.filter (/= 0) . Map.mapKeysWith (+) (extendLPTerm s)

-- "translateTranslatableTerm" would be a nice candidate to prove correct in ACL2
-- | Translate the conjunction of a list of translatable terms into a single
-- linear expression. Each element is translated on its own, and the results
-- are combined with 'combineAnd'. The empty list is the empty conjunction,
-- i.e. True.
translateTranslatableTerm :: [TranslatableTerm]
                          -> LPMap UntimedLPTerm -- the resulting LP variable
translateTranslatableTerm = foldr (combineAnd . translateOne) (Map.singleton Set.empty 1)
 where
  translateOne :: TranslatableTerm -> LPMap UntimedLPTerm
  translateOne tr = case tr of
    TR_RESET       -> Map.empty
    TR_VALUE True  -> Map.singleton Set.empty 1
    TR_VALUE False -> Map.empty
    -- ite(a,b,c) = (a AND b) + (NOT a AND c). This is more efficient than
    -- translating T_ITE to a T_OR: the two summands are mutually exclusive,
    -- so no correction term is needed. We could still save work here,
    -- since "a" is translated twice.
    TR_ITE a b c   -> translateTerm (T_AND a b) .+. translateTerm (T_AND (T_NOT a) c)
    -- y1 OR y2 = 1 - (NOT y1 AND NOT y2).
    -- This used to be (on positive terms t1 t2) t1 .+. t2 .-. combineAnd t1 t2.
    -- That may be slower because it creates t1 and t2 as separate terms
    -- (it is just as fast if t1 and t2 are positive terms, however).
    TR_OR y1 y2    -> let n1 = translateTerm (tnot y1)
                          n2 = translateTerm (tnot y2)
                      in Map.singleton Set.empty 1 .-. combineAnd n1 n2
    -- y1 XOR y2 = y1 + y2 - 2 * (y1 AND y2)
    TR_XOR y1 y2   -> let p1 = translateTerm y1
                          p2 = translateTerm y2
                      in p1 .+. p2 .-. (2 .*. combineAnd p1 p2)
    TR_NOT y1      -> Map.singleton Set.empty 1 .-. translateTerm y1

-- "combineAnd" would be a nice candidate to prove correct in ACL2
combineAnd :: LPMap UntimedLPTerm -> LPMap UntimedLPTerm -- two LPs for which the AND must be calculated
           -> LPMap UntimedLPTerm -- result
-- | Compute the linear expression for the AND of two linear expressions.
-- This is a Cartesian product: every pair of terms is conjoined (set union)
-- and the coefficients are multiplied.
-- examples:
--   combineAnd (W1 + W2 - W3) W4 = (W1-AND-W4 + W2-AND-W4 - W3-AND-W4)
--   combineAnd W1-AND-W2 W3 = W1-AND-W2-AND-W3
--   combineAnd W1 True = W1 -- True is represented as one number containing no conjuncts
--   combineAnd W1 False = False -- False is represented as a zero
-- Because conjunction is idempotent (W1-AND-W1 = W1), different products can
-- land on the same key and cancel: combineAnd (1 - W1) W1 = W1 - W1 = 0.
-- Such zero coefficients are removed, consistent with '.+.', so they are
-- not declared as variables later on.
combineAnd a b
 = Map.filter (/= 0) $
   Map.fromListWith (+) [ (Set.union a' b', ai*bi)
                        | (a',ai) <- Map.toList a, (b',bi) <- Map.toList b ]

-- | Translate a term into a linear expression over conjunctions of primitive
-- signals. Top-level conjuncts that are primitive signals are collected by
-- name. Every other conjunct is translated with 'translateTranslatableTerm',
-- and the collected primitives are then conjoined onto each term of that
-- translation using 'extendLPMap'.
translateTerm :: Term -> LPMap UntimedLPTerm
translateTerm t = extendLPMap (Set.fromList prims) (translateTranslatableTerm rest)
 where (prims, rest) = conjuncts t
       -- conjuncts walks the top-level T_AND structure and returns a tuple:
       --   ( names of the primitive conjuncts, also used for declarations
       --   , the remaining translatable conjuncts (not T_AND, not primitive);
       --     these may still contain T_AND in a nested place,
       --     such as T_NOT (T_AND ...) )
       conjuncts :: Term -> ([String],[TranslatableTerm])
       conjuncts term = case term of
         T_QTRDY x   -> named  "QTRDY" x
         T_QIRDY x   -> named  "QIRDY" x
         T_QDATA x y -> named' "QDATA" x y
         T_OTRDY x   -> named  "OTRDY" x
         T_IIRDY x   -> named  "IIRDY" x
         T_IDATA x y -> named' "IDATA" x y
         T_FLOPV x   -> named  "FLOPV" x
         T_UNKNOWN x -> named  "UNKNO" (show x)
         T_INPUT x   -> named  "INPUT" x
         T_AND x y   -> let (nx, ox) = conjuncts x
                            (ny, oy) = conjuncts y
                        in (nx ++ ny, ox ++ oy)
         T_OR  x y   -> other (TR_OR    x y)
         T_XOR x y   -> other (TR_XOR   x y)
         T_ITE x y z -> other (TR_ITE   x y z)
         T_NOT x     -> other (TR_NOT   x)
         T_VALUE x   -> other (TR_VALUE x)
         T_RESET     -> other TR_RESET
       named kind nm = ([kind ++ nm], [])
       named' kind nm idx = ([kind ++ nm ++ "-" ++ show idx], [])
       other tr = ([], [tr])

-- | Negate a term, pushing the negation inwards where that is cheap.
-- A double negation is removed, the negation of an XOR moves into its left
-- operand, and De Morgan's law is applied to an OR. Any other term is
-- wrapped in T_NOT.
tnot :: Term -> Term
tnot term = case term of
  T_NOT inner -> inner
  T_XOR l r   -> T_XOR (tnot l) r
  T_OR  l r   -> T_AND (tnot l) (tnot r)
  _           -> T_NOT term

-- | The deadlock condition of a queue: the queue's QIRDY holds while its last
-- field (the output-side trdy term, named ouTrdy in 'obtainQueues') is false.
getDeadlock :: QueueAssignment -> Term
getDeadlock q = case q of
  QueueAssignment nm _ _ _ outTrdy -> T_QIRDY nm `T_AND` T_NOT outTrdy

-- | Build the SMT problem for a circuit and print it on standard output.
showSMT :: ([QueueAssignment],[FlopAssignment],[InputAssignment],[OutputAssignment],[OutputWire]) -> IO ()
showSMT = putStrLn . problemToString . genProblem

-- | Render an SMT problem as SMT-LIB text. The output consists of:
-- one declare-fun per declared variable (entries with sort Nothing are
-- constants and are skipped); one assert of the whole formula; the
-- check-sat / get-model / exit commands; and every comment line, each
-- prefixed with "\n; ".
-- The comment trailer is built with concatMap. The previous left fold with
-- (++) re-copied the accumulated string for every comment, which is
-- quadratic in the number of comments.
problemToString :: SMTProblem -> String
problemToString (SMTProblem ds fs cm)
 = -- "(set-logic QF_LRA)\n" ++ -- maybe we don't have to set a logic?
   concatMap showDecl (Map.elems ds) ++ "\n\n" ++ showfm fs ++
   "\n(check-sat)\n(get-model)\n(exit)\n" ++
   concatMap ("\n; " ++) cm
 where showDecl (nm, Just tp) = "(declare-fun "++nm ++ " () "++tp++")\n"
       showDecl (_ , Nothing) = "" -- a constant (such as "1") needs no declaration
       showfm f = "(assert " ++ tosmt 1 (showLPEqWithSets ds) f ++ ")\n"

-- | Render one linear relation in SMT-LIB prefix notation, as "= lhs rhs" or
-- "<= lhs rhs". NOTE(review): the enclosing parentheses are presumably added
-- by 'tosmt' -- confirm.
-- Terms are moved across the relation so that both sides only contain
-- positive coefficients. Every variable is printed under its name from the
-- declarations map; a variable missing from the map is a Map.! error.
-- If every declared variable in the relation has sort Int, all numbers are
-- printed as Ints. Otherwise Int variables are wrapped in to_real, and
-- constants and coefficients are printed as real literals.
showLPEqWithSets :: Map.Map VariableInTime (String, Maybe String) -> TimedLPEq -> String
showLPEqWithSets ds x
 = case x of
    (EQUATION s t) -> "= "++showsum2 (pos s ++ neg t) (pos t ++ neg s)
    (LEQ s t)      -> "<= "++showsum2 (pos s ++ neg t) (pos t ++ neg s)
 where
  -- terms with a positive coefficient stay on their own side
  pos ((a,v):rst) = if v>0 then (a, v):pos rst else pos rst
  pos [] = []
  -- terms with a negative coefficient move to the other side, negated
  neg ((a,v):rst) = if v<0 then (a,-v):neg rst else neg rst
  neg [] = []
  showsum2 a b = showsum ta ++" "++showsum tb
    where
     ta = map (\(e,f) -> ((Map.!) ds e,f)) a
     tb = map (\(e,f) -> ((Map.!) ds e,f)) b
     -- maybeReal will convert all numbers to Real iff one of them actually is a "Real".
     -- in case all numbers are Int, we leave all numbers as Int (to_real slows down z3 significantly)
     maybeReal (v,t) = if allInt then v else case t of {Just "Int" -> "(to_real "++v++")";Nothing -> v++".0";_->v}
     allInt = and [v == "Int" | ((_,Just v),_) <- ta++tb]
     showsum s
      = case s of
         [] -> if allInt then "0" else "0.0"
         [e] -> showElm e
         lst -> "(+"++concat [' ':showElm elm | elm<-lst]++")"
     showElm (e,1) = maybeReal e
     showElm (("1",Nothing),f) = maybeReal (show f,Nothing)
     -- The coefficient also goes through maybeReal, so it becomes a real
     -- literal ("2.0") exactly when the variable it multiplies is printed as
     -- a real. "(* 2 x)" with a Real x mixes the Int and Real sorts, which
     -- SMT-LIB's Reals_Ints theory does not define. z3 tolerates it;
     -- stricter solvers need not.
     showElm (e@(_,Just _),f) = "(* "++maybeReal (show f,Nothing)++" "++maybeReal e ++")"
     showElm ((s,Nothing),_) = error$ "ToSMT.hs error 190: only constants may be undeclared, but \""++s++"\" was not recognised as one"
{-
mapfind :: (Ord a1, Show a1) => Map.Map a1 a -> a1 -> a
mapfind a b = case Map.lookup b a of
                Nothing -> error ("Cannot find: "++show b)
                Just c -> c
-}

-- | The property "the Finally value of this expression equals 1".
-- NOTE(review): presumably this means it holds (on average) all the time in
-- the limit -- confirm.
isAVG1 :: LPMap UntimedLPTerm -> (Formula TimedLPEq)
isAVG1 = finally . BASIC . EQUATION one

-- | The set of all keys occurring in any of the given maps.
getTerms :: Ord a => [Map.Map a a1] -> Set.Set a
getTerms = Set.unions . map Map.keysSet

-- | Translate every queue into a formula describing its behaviour, and
-- collect the terms those formulas use. Queues are processed in order.
-- Returns:
--   * the given set of terms, extended with all terms used by the queue formulas;
--   * one declaration per queue for its packet count ("Packets_in_<name>", sort Int);
--   * one formula per queue, in input order.
obtainQueues :: Set.Set UntimedLPTerm                 -- and-terms found so far
             -> [QueueAssignment]                     -- to be translated into (Formula TimedLPEq)
             -> (Set.Set UntimedLPTerm                -- used untimed terms (to be related and declared)
                ,[(VariableInTime, (String, Maybe String))]           -- other declarations (which may exist in CUR only)
                ,[(Formula TimedLPEq)])            -- one formula per queue
obtainQueues t [] = (t, [], [])
obtainQueues t (QueueAssignment nm size inIrdy _data ouTrdy : lst)
 = (Set.union t' (getTerms [ enter  
                           , exit   
                           , qouIrdy
                           , qinTrdy
                           , qinIrdy
                           , qouTrdy
                           -- , fullONt
                           ])
   , [ (currPack,("Packets_in_"++nm,Just "Int"))
     ] ++ moreDecs
   , itm:itms)
 where enter   = translateTerm (T_AND inIrdy (T_QTRDY nm)) -- transfer into the queue: inIrdy and the queue's QTRDY
       exit    = translateTerm (T_AND ouTrdy (T_QIRDY nm)) -- transfer out of the queue: ouTrdy and the queue's QIRDY
       qouIrdy = translateTerm (T_QIRDY nm)
       qinTrdy = translateTerm (T_QTRDY nm)
       qinIrdy = translateTerm inIrdy
       qouTrdy = translateTerm ouTrdy
       -- fullONt = translateTerm (T_OR (T_QIRDY nm) (T_QTRDY nm))
       currPack = CurrentQueuePackets ("PcksIn"++nm)
       pcks    = [(currPack,1)] -- the queue's current packet count, as a timed linear expression
       itm = AND ("Queue "++ nm)
                 [ finally$ BASIC (EQUATION enter exit) -- what goes in, must come out (eventually)
                 -- , BASIC (EQUATION [(outp,1)] (fina qouTrdy))
                 , BASIC (EQUATION (init (enter .-. exit))
                                   (pcks ++ curr (enter .-. exit)))
                 -- , globally$ BASIC (EQUATION one fullONt) -- queue is full or not
                  -- queue is full <=> there is no space
                 , EQUIV (BASIC (EQUATION (curr (size .*. one)) pcks))
                         (currently (BASIC (EQUATION qinTrdy zero)))
                  -- queue is empty <=> nothing's available
                 , EQUIV (BASIC (EQUATION (curr zero) pcks))
                         (currently (BASIC (EQUATION qouIrdy zero)))
                 , AND ("Size restrictions")
                       [BASIC $ LEQ (curr zero) pcks
                       ,BASIC $ LEQ pcks (curr (size .*. one))]
                  -- nothing enters (eventually) => emptyness is constant
                  -- proof:
                  --   (1) nothing enters => qouIrdy is non-decreasing 
                  --   (2) nothing enters => qouIrdy is non-increasing
                  -- (1) follows from the fact that qouIrdy indicates queue availability
                  --     since no packets enter the queue, it will remain available.
                  -- (2) by the first equation: "what goes in, must come out (eventually)"
                  --     hence nothing enters <=> nothing leaves
                  --     using the reasoning as in (1), this means that qouIrdy is non-increasing
                 , IMPL (finally$ BASIC (EQUATION enter zero))
                        (BASIC (EQUATION (curr qouIrdy) (fina qouIrdy)))
                  -- nothing enters (eventually) => fullness is constant
                 , IMPL (finally$ BASIC (EQUATION enter zero))
                        (BASIC (EQUATION (curr qinTrdy) (fina qinTrdy)))
                 
                 , finally$ IMPL (BASIC$ one `EQUATION` qouTrdy)
                                 (OR [ BASIC (one  `EQUATION` qouIrdy)
                                     , BASIC (zero `EQUATION` qouIrdy)])
                 , finally$ IMPL (NOT (BASIC (zero `EQUATION` qouTrdy)))
                                 (NOT (BASIC (zero `EQUATION` qinTrdy)))
                 , finally$ IMPL (BASIC (one `EQUATION` qouTrdy))
                                 (BASIC (one `EQUATION` qinTrdy))
                 , finally$ IMPL (NOT (BASIC (zero `EQUATION` qinIrdy)))
                                 (NOT (BASIC (zero `EQUATION` qouIrdy)))
                 , finally$ IMPL (BASIC (one `EQUATION` qinIrdy))
                                 (BASIC (one `EQUATION` qouIrdy))
                 , finally$ IMPL (AND "" [BASIC (zero `EQUATION` qinIrdy)
                                      , NOT$BASIC (zero `EQUATION` qouTrdy)])
                                 (AND "" [BASIC (one  `EQUATION` qinTrdy)
                                      , BASIC (zero `EQUATION` qouIrdy)])
                 
                 ]
       ~(t', moreDecs, itms ) = obtainQueues t lst -- translate the remaining queues

-- | A property that must hold initially, currently and finally.
globally :: Formula LPEqWithSets -> Formula TimedLPEq
globally v = AND "Global property" [ at v | at <- [initially, currently, finally] ]

-- | Attach one point in time to every term of a formula. This goes through
-- formulaChangeCore from the Formula module; NOTE(review): presumably that
-- maps over the atomic relations -- confirm.
initially,currently,finally ::  Formula LPEqWithSets -> Formula TimedLPEq
initially = formulaChangeCore $ lpChangeCore Initially
currently = formulaChangeCore $ lpChangeCore Currently
finally   = formulaChangeCore $ lpChangeCore Finally

-- | Convert both sides of a relation from maps to association lists,
-- applying f to every term.
lpChangeCore :: (Ord b) => (a -> b) -> LPEq (LPMap a) -> LPEq (LPList b)
lpChangeCore f (EQUATION lhs rhs) = EQUATION (lPMap2LPList f lhs) (lPMap2LPList f rhs)
lpChangeCore f (LEQ      lhs rhs) = LEQ      (lPMap2LPList f lhs) (lPMap2LPList f rhs)

-- note: these types are narrower than necessary, to get better type errors
-- | Turn a linear expression into an association list (in ascending key
-- order, as produced by Map.toList), applying f to every term.
lPMap2LPList :: (t -> t1) -> LPMap t -> LPList t1
lPMap2LPList f a = [ (f term, coeff) | (term, coeff) <- Map.toList a ]
-- | Tag every term of a linear expression with a point in time.
init,curr,fina :: (LPMap UntimedLPTerm) -> (LPList VariableInTime)
init = lPMap2LPList Initially
curr = lPMap2LPList Currently
fina = lPMap2LPList Finally

-- | Constraints for terms that contain loose wires, i.e. primitives whose
-- name starts with "INPUT" or "UNKNO".
-- A term consisting of exactly one such primitive gets a Finally value
-- that is neither 0 nor 1. Any other term containing at least one such
-- primitive gets a Finally value that is not 1.
getLooseWires :: (Set.Set (Set.Set String)) -> [Formula TimedLPEq]
getLooseWires lst = concatMap constraintsFor (Set.toList lst)
 where
   isLoose v = take 5 v == "INPUT" || take 5 v == "UNKNO"
   constraintsFor st
     | looseCount == 0                     = []
     | looseCount == 1 && Set.size st == 1 = [finally$ (AND "wire is loose" [NOT (BASIC (EQUATION zero (lpItm st)))
                                                                          ,NOT (BASIC (EQUATION one  (lpItm st)))])]
     | otherwise                           = [finally$ (AND "conjunct containing loose wire" [NOT (BASIC (EQUATION one (lpItm st)))])]
     where looseCount = length (filter isLoose (Set.toList st))

-- | Translate every flop into a formula relating its driver to its value,
-- and collect the terms those formulas use.
-- Returns the given set of terms extended with the used terms, together
-- with one formula per flop (in input order).
obtainFlops :: Set.Set UntimedLPTerm                 -- and-terms found so far
            -> [FlopAssignment]                      -- to be translated into (Formula TimedLPEq)
            -> (Set.Set UntimedLPTerm,               -- 
                [(Formula TimedLPEq)])
obtainFlops t flops = (Set.unions (t : map fst translated), map snd translated)
 where
  translated = map translateFlop flops
  translateFlop (FlopAssignment nm drv) = (getTerms [driver, value], itm)
   where driver = translateTerm drv
         value  = translateTerm (T_FLOPV nm)
         reset  = zero -- TODO: get initial value (value after reset) (which might be a variable! Use translateTerm..)
         itm = AND ("flop driver: "++show drv)
                   [ finally$ BASIC (EQUATION driver value) -- on average, the driver value is the current
                   , BASIC (EQUATION (init (driver .-. value)) (curr (driver .-. reset)))
                   ]

-- | For every output: its OTRDY must not be 0 in the Finally state
-- (packets get accepted sometimes). Returns the given set of terms extended
-- with the used terms, together with one formula per output (in input order).
obtainOuts :: Set.Set UntimedLPTerm                -- and-terms found so far
           -> [OutputAssignment]                   -- to be translated into (Formula TimedLPEq)
           -> (Set.Set UntimedLPTerm,              -- 
               [(Formula TimedLPEq)])
obtainOuts t outputs = (Set.unions (t : map fst translated), map snd translated)
 where
  translated = map translateOut outputs
  translateOut (OutputAssignment nm _ _) = (getTerms [trdy], itm)
   where trdy = translateTerm (T_OTRDY nm)
         itm = AND ("Non-blocking output for "++nm)
                   [ finally$ NOT$ BASIC (EQUATION zero trdy) -- packets get accepted sometimes
                   ]

-- | For every output wire: the variable OUT_<name> equals the translation of
-- the wire's term, globally. Returns the given set of terms extended with
-- the used terms, together with one formula per wire (in input order).
obtainWires :: Set.Set UntimedLPTerm                -- and-terms found so far
           -> [OutputWire]                   -- to be translated into (Formula TimedLPEq)
           -> (Set.Set UntimedLPTerm,              -- 
               [(Formula TimedLPEq)])
obtainWires known wireList = (Set.unions (known : map fst translated), map snd translated)
 where
  translated = map translateWire wireList
  translateWire (OutputWire nm term) = (getTerms [trm, w], itm)
   where trm = translateTerm term
         w = Map.singleton (Set.singleton$ "OUT_"++nm) 1
         itm = AND ("Value for "++nm++" ("++show term++")")
                   [ globally$ BASIC (EQUATION w trm) -- the wire always carries the value of its term
                   ]
-- idea of formIntersections:
-- it gets a set of sets (each set of primitive names denotes their conjunction),
-- and returns relevant lattice information.
-- That is: it performs a pairwise comparison. For each pair a and b,
--   if (a /\ b) is nonempty, the following sets are added (in case they do not exist):
--    i: a /\ b
--    u: a \/ b
--    d: u - (a /\ b), the symmetric difference between a and b
--   We know: d isSubset u, a isSubset u, b isSubset u
--   Since a larger set of conjuncts denotes a smaller (or equal) value, from this follows:
--    #u <= #d, #u <= #a, #u <= #b
--   Any two of (#a=1, #b=1, #d=1) together imply #u=1 (and then all of them are 1),
--   so we know: #a + #b + #d <= 2*#u + 1
--   Note that in case a /\ b = a (so a is a subset of b) we get u = b and d = b - a
--   This adds #a + #b + #d <= 2*#b + 1, or #a + #d <= #b + 1 (which is desired, since b means "a AND d")
--   This is repeated until no new sets appear (a fixed point).
-- | Close a set of conjunctions under the lattice operations described above.
-- Returns the final list of conjunctions, together with the lattice
-- constraints for every intersecting pair of that final set.
formIntersections :: Set.Set (Set.Set String) -> ([Set.Set String], [Formula LPEqWithSets])
formIntersections st
 | null news = (stList, comps)
 | otherwise = formIntersections (Set.union st (Set.fromList news))
 where
   stList  = Set.toList st
   -- every unordered pair (a, b) with a before b in stList, in the same
   -- order as a nested loop over the list
   results = [ compareAsLP a b | (a:later) <- suffixes stList, b <- later ]
   news    = concatMap fst results
   comps   = concatMap snd results
   suffixes [] = []
   suffixes l@(_:rest) = l : suffixes rest
   compareAsLP :: (Set.Set String) -> (Set.Set String) -> ([Set.Set String], [Formula LPEqWithSets])
   compareAsLP a b
     | Set.null i = ([],[])
     -- note that when "a and u" are picked, "i" will be generated;
     -- by adding "i" here, we'll reach the fixed-point faster
     | otherwise  = ( [ v | v <- [i,d,u], not (v `Set.member` st) ]
                    , [AND ("LP information for the AND-pairs: "++show (a,b))
                           [ BASIC$LEQ (lpItm u) (lpItm d)
                           , BASIC$LEQ (lpItm u) (lpItm a)
                           , BASIC$LEQ (lpItm u) (lpItm b)
                           , BASIC$LEQ (lpItm a .+. lpItm b .+. lpItm d) ((2 .*. lpItm u) .+. one)
                           ]
                      ]
                    )
     where i = Set.intersection a b
           u = Set.union a b
           d = Set.difference u i

-- | Declarations, and for non-empty conjunctions a relating formula, for one
-- conjunction given together with a unique index.
-- The empty conjunction (True) is "N" (sort Int) initially and the
-- undeclared constant "1" currently and finally; it gets no formula.
-- Any other conjunction gets Init_/Curr_ variables of sort Int and a Fina_
-- variable of sort Real. The name suffix is the primitive itself, "a_AND_b"
-- for two primitives, or "var<index>" for three or more.
-- The relating formula states: current <= initial; current /= 1 implies
-- final /= 1; current /= 0 implies final /= 0; and, at every point in time,
-- 0 <= term <= the empty conjunction at that time (i.e. N initially, 1 otherwise).
trmToDecl :: (Int,Set.Set String) -> ([(VariableInTime, (String,Maybe String))],Maybe (Formula TimedLPEq))
trmToDecl (i,v)
 = case (Set.toList v) of
        []    -> ([(Initially v,("N",Just "Int"))
                  ,(Currently v,("1",Nothing))
                  ,(Finally   v,("1",Nothing))
                  ],Nothing)
        [a]   -> mk3 a
        [a,b] -> mk3 (a++"_AND_"++b)
        -- s     -> mk3 (inter "_AND_" s)
        _     -> mk3 ("var"++show i)
 where -- mk3: the three timed declarations plus the relating formula for name suffix b
       mk3 b = ([ (Initially v,("Init_"++b,Just "Int" ))
                , (Currently v,("Curr_"++b,Just "Int" ))
                , (Finally   v,("Fina_"++b,Just "Real"))
                ],
                Just$ AND ("Relating initial current and final of "++b++" (stemming from "++show v++")")
                [ BASIC$ LEQ [(Currently v,1)] [(Initially v,1)]
                , IMPL (NOT$BASIC$ EQUATION [(Currently v,1)] (curr one))
                       (NOT$BASIC$ EQUATION [(Finally v,1)] (fina one))
                , IMPL (NOT$BASIC$ EQUATION [(Currently v,1)] (curr zero))
                       (NOT$BASIC$ EQUATION [(Finally v,1)] (fina zero))
                , globally (BASIC$ LEQ (Map.singleton v 1) one)
                , globally (BASIC$ LEQ zero (Map.singleton v 1))
                ]
                )
                
-- | Build the SMT problem for a circuit. The asserted formula says that some
-- queue deadlocks (its deadlock condition has Finally value 1). It is
-- conjoined with the information about queues, flops, outputs and loose
-- wires, the per-term relating constraints, the lattice information and the
-- output-wire values.
-- Input assignments are currently not used (TODO).
genProblem ::([QueueAssignment],[FlopAssignment],[InputAssignment],[OutputAssignment],[OutputWire]) -> SMTProblem
genProblem (queues,flops,_ins,outs,wires)
 = SMTProblem declarations formula
              [] -- no comments (TODO: add useful info)
 where
   (trms0, deadlocks)        = translateTerms Set.empty (map getDeadlock queues)
   (trms1, qDcls, queueInfo) = obtainQueues trms0 queues
   (trms2, flopInfo)         = obtainFlops  trms1 flops
   (trms3, outInfo)          = obtainOuts   trms2 outs
   (trms4, outWireInfo)      = obtainWires  trms3 wires
   inWireInfo                = getLooseWires trms4
   (trmList, latticeInfo)    = formIntersections trms4
   -- index 0 is always the empty conjunction (True)
   termDecls     = map trmToDecl (zip [0..] (Set.empty : trmList))
   compositeInfo = queueInfo ++ flopInfo ++ outInfo ++ inWireInfo -- TODO: inputs, maybe output wires?
   declarations  = Map.fromList (concatMap fst termDecls ++ qDcls)
   formula       = AND "SMT problem" (  [OR (map isAVG1 deadlocks)]
                                     ++ compositeInfo
                                     ++ [ f | (_, Just f) <- termDecls ]
                                     ++ map globally latticeInfo
                                     ++ outWireInfo
                                     )
       