{-# OPTIONS_GHC -Wall #-}
module InvToInv (parse, parseInv, invToInv, main) where
import Text.ParserCombinators.Parsec (parse,spaces,many,string,(<|>),Parser,noneOf,oneOf,digit,(<?>),eof)
import qualified Data.Map.Lazy as Map
import Data.List (sortBy)
import System.Environment
-- example input (note that spaces around the + and * are mandatory, and x * y must have x as a number and y as a queue):
{-
    A + B = C + D
    E + 2 * F = 4
    # queue sizes:
    A <= 3
    B <= 3
    C <= 3
    D <= 3
-}


-- | Command-line entry point: reads invariants from stdin, transforms them
-- with 'invToInv' and prints the rendering produced by 'showInv'.
-- Passing @-?@ or @--?@ aborts with a short usage message.
main :: IO ()
main
 = do args <- getArgs
      -- Any of the help flags short-circuits before stdin is touched.
      if any (`elem` ["-?", "--?"]) args
         then fail "Use stdin. Available options: -R and -plain"
         else return ()
      s <- getContents
      -- A parse error is reported via 'fail'; success is rendered and printed.
      either (fail . show)
             (putStrLn . showInv args . invToInv)
             (parse parseInv "(stdin)" s)

-- | Pointwise sum and difference of sparse coefficient maps.
-- A key missing from one side counts as coefficient 0, and keys whose
-- combined coefficient becomes 0 are removed from the result.
(.+.),(.-.) :: (Ord k, Num v, Eq v) => Map.Map k v -> Map.Map k v -> Map.Map k v
(.-.) = Map.mergeWithKey minus id (Map.map negate)
  where
    -- Keys present on both sides: subtract and drop an exact zero.
    minus _ l r = let d = l - r in if d == 0 then Nothing else Just d
(.+.) = Map.mergeWithKey plus id id
  where
    -- Keys present on both sides: add and drop an exact zero.
    plus _ l r = let t = l + r in if t == 0 then Nothing else Just t

-- | Render a list of linear relations as solver assertions.
--
-- The first argument is the list of command-line options:
--
--   * @-R@ switches the declared sort to @Real@ (and appends @.0@ to literals
--     that are printed as constants);
--   * @-plain@ suppresses the prelude (logic selection and declarations).
--
-- Each relation is a coefficient map together with its relation symbol
-- (e.g. @"="@ or @"<="@) and is printed as @(assert (REL 0 (+ terms...)))@.
-- The empty key @""@ denotes the constant term.
showInv :: [String] -> [(Map.Map String Int,String)] -> String
showInv a invs = prelude ++ concatMap showAssertion invs
 where showAssertion (i,s) = "(assert ("++s++" 0"++numPostfix++" (+" ++ concat [" "++showTerm k v | (k,v)<-Map.toList i] ++")))\n"
       -- The constant term must be matched FIRST: otherwise a constant of 1
       -- would be printed as the (empty) key and -1 as "(~ )", silently
       -- dropping the constant from the assertion.
       showTerm "" v
         | v < 0     = "(~ "++show (negate v)++numPostfix++")"
         | otherwise = show v ++ numPostfix
       -- Unit coefficients are printed without a multiplication.
       showTerm k 1 = k
       showTerm k (-1) = "(~ "++k++")"
       -- NOTE(review): coefficients are printed without numPostfix even with -R;
       -- confirm that the target solver accepts integer literals there.
       showTerm k v
         | v < 0     = "(* (~ "++show (negate v)++") "++k++")"
         | otherwise = "(* "++show v++" "++k++")"
       reals = "-R" `elem` a
       numPostfix = if reals then ".0" else ""
       typ = if reals then "Real" else "Int"
       -- Logic selection plus one non-negative declaration per named variable.
       prelude
         | "-plain" `elem` a = []
         | otherwise
             = (if reals then "(set-logic QF_LRA)" else "(set-logic QF_LIA)")
               ++ "\n\n"
               ++ concatMap showP [k | k<- Map.keys (Map.unions (map fst invs)), not (null k)]
       showP x = "(declare-fun "++x++" () "++typ++") (assert (<= 0"++numPostfix++" "++x++"))\n"

-- | Kind of relation found on one input line, selected by the delimiter that
-- 'parseInv' reads between the two term lists:
--
--   * 'Invariant'   for @=@
--   * 'IDefinition' for @:=@
--   * 'ILEQ'        for @<=@
data InvTyp = Invariant | IDefinition | ILEQ

-- | Parser for a complete input: a sequence of relations and @#@ comment
-- lines, up to end of input.
--
-- Every relation @lhs DELIM rhs@ is returned as a coefficient map paired with
-- the kind of its delimiter. The map holds, per symbol, the coefficient of
-- @rhs - lhs@; the constant term is stored under the empty key @""@ and
-- symbols whose coefficient cancels to 0 are removed. Comment lines are
-- parsed but discarded.
parseInv :: Parser [(Map.Map String Int,InvTyp)]
parseInv
 = do spaces
      e<-many invariant
      _<-eof
      return ( normalize e
             )
  where
   -- One input item: a relation (Left) or a comment line (Right).
   invariant
    = do { soc1 <- termList <?> "a new invariant"; spaces; d <- delimiter; spaces; soc2 <- termList; spaces; return (d (soc1,soc2))}
      <|> do { c <- commentline ; spaces; return (Right c)}
   -- The relation symbol; yields a function tagging the two term lists.
   delimiter
    =     do { _ <- string "="; return (\x -> Left (x,Invariant))}
      <|> do { _ <- string ":="; return (\x -> Left (x,IDefinition))}
      <|> do { _ <- string "<="; return (\x -> Left (x,ILEQ))}
   -- A sum of terms separated by "+", or a single "0" for the empty sum.
   termList -- :: Parser [(Int, String)]
    =     do { t <- term; spaces; tl <- termListRemainder; spaces; return (t:tl)}
      <|> do { _ <- string "0"; spaces; return []} <?> "another term or 0"
   termListRemainder -- :: Parser [(Int,String)]
    =     do { _ <- string "+" <?> "+ with another term"; spaces; tl <- termList; spaces; return tl}
      <|> do {return []}
   -- "n * sym", a bare number "n" (constant, stored under ""), or a bare symbol.
   term -- :: Parser (Int,String)
    =     do { f <- factor; spaces; timesString f }
      <|> do { c <- symb; return (1,c) } <?> "term"
   timesString f
    =     do { _ <- string "*"; spaces; c <- symb; return (f,c)}
      <|> return (f,"")
   -- An integer literal. A leading '-' must be followed by at least one
   -- digit: accepting a lone "-" would make 'read' fail at runtime
   -- ("Prelude.read: no parse") instead of producing a parse error.
   factor
    =     do { _ <- string "-"; d <- digit; s <- many digit; return (read ('-':d:s)) }
      <|> do { d <- oneOf "123456789"; s <- many digit; return (read (d:s)) }
   symbRest
    = many (noneOf " \t\r\n()|")
   symb
    = do{c<-noneOf "#*0123456789-()|"; r<-symbRest; return (c:r)} <?> "symbol"
   commentline
    = do{_<-string "#";s<-many (noneOf "\n\r");spaces; return s} <?> "comment line starting with #"
   -- Drop comments and turn each relation into a single coefficient map.
   normalize lst = map normfunc [l | Left l <- lst]
     where
      -- Left-hand terms are subtracted, right-hand terms are added;
      -- coefficients that cancel out are filtered away.
      normfunc ((negs, poss), delim)
       = ( Map.filter (/= 0) $ Map.fromListWith (+)
                                 (map (\(x,y) -> (y,x)) poss ++ map (\(x,y) -> (y,negate x)) negs)
         , delim)

-- | Eliminate the dotted variables introduced by definitions from all
-- invariants and inequalities.
--
-- Invariants and the definitions that are still relevant are returned tagged
-- with @"="@, the inequalities tagged with @"<="@.
invToInv :: [(Map.Map String Int,InvTyp)] -> [(Map.Map String Int,String)]
invToInv invs
 = [ (eq, "=") | eq <- sweptInvs ++ keptDefs ] ++ [ (le, "<=") | le <- sweptLeqs ]
 where
  (plainInvs, definitions, plainLeqs) = partition3 invs
  -- number of dots in a key
  dotCount :: String -> Int
  dotCount = length . filter (== '.')
  -- For every key, the definition it occurs in (a later definition overrides an earlier one).
  definingEq = Map.fromList (concatMap (\d -> [ (k, d) | k <- Map.keys d ]) definitions)
  -- the variables on which a sweep needs to be performed, and the corresponding
  -- linear equation, ordered by ascending number of dots in the key
  sweepOrder :: [(String,Map.Map String Int)]
  sweepOrder = sortBy (\(k1,_) (k2,_) -> compare (dotCount k1) (dotCount k2))
                      (filter ((> 0) . dotCount . fst) (Map.toList definingEq))
  -- Apply every sweep to a list of rows.
  eliminate rows = foldr (\(k,d) acc -> fullSweep (k, (Map.!) d k) d acc) rows sweepOrder
  sweptInvs = eliminate plainInvs
  sweptLeqs = eliminate plainLeqs
  -- Dotted keys that survive in the swept rows.
  remainingDotted = Map.fromList [ (k,()) | k <- Map.keys (Map.unions (sweptInvs ++ sweptLeqs)), dotCount k > 0 ]
  -- Definitions that mention one of those surviving keys are kept.
  keptDefs = filterRelevants definitions remainingDotted
  
-- | Keep those maps from the first argument that share at least one
-- non-empty key with the second argument.
filterRelevants :: Ord a => [Map.Map [a] b] -> Map.Map [a] c  -> [Map.Map [a] b]
filterRelevants ks flt = filter sharesNamedKey ks
 where
  -- True when some common key is not the empty list.
  sharesNamedKey m = any (not . null) (Map.keys (Map.intersection m flt))

-- | Apply 'sweepOn' with the given pivot and pivot row to every row of a list.
fullSweep :: (Integral a, Ord k) => (k,a) -> Map.Map k a -> [Map.Map k a] -> [Map.Map k a]
fullSweep x a vs = [ sweepOn x a row | row <- vs ]

-- | @sweepOn (k,ht) a v@ removes variable @k@ from row @v@ by adding a
-- multiple of row @a@ to a POSITIVE multiple of @v@.
--
-- Keeping the multiplier of @v@ positive matters because rows may stand for
-- inequalities: scaling such a row by a negative number would reverse the
-- direction of the inequality.
--
-- note: it is your responsibility that (Map.lookup k a == Just ht)
-- if you don't, the new `matrix' will still be linearly equivalent to the old, but with more terms introduced
sweepOn :: (Integral a, Ord k) => (k,a) -> Map.Map k a -> Map.Map k a -> Map.Map k a
sweepOn (k,ht) a v
 = case Map.lookup k v of
        Nothing    -> v -- k does not occur in v: nothing to eliminate
        Just f'
          -- even though normalization might be possible, it is not required here.
          -- This reduces runtime with about 25% in some examples
          | f' == ht    -> v .-. a
          | f' == (-ht) -> v .+. a -- same here
          -- general case: |ht| * v - sign(ht) * f' * a, so the coefficient of k cancels
          | otherwise   -> norm (Map.map (* (sgn * ht)) v .+. Map.map (* negate (sgn * f')) a)
 where
  -- sign of the pivot coefficient (a zero pivot is treated as positive)
  sgn = if ht < 0 then -1 else 1

-- | Divide all coefficients by their greatest common divisor.
--
-- A map that is empty or contains only zeros has gcd 0; it is returned
-- unchanged instead of dividing by zero.
norm :: Integral b => Map.Map k b -> Map.Map k b
norm sp
  | factor == 0 = sp
  | otherwise   = Map.map (`quot` factor) sp
  where
    -- gcd of all coefficients (non-negative; 0 only if every coefficient is 0)
    factor = Map.foldr gcd 0 sp

-- | Split a tagged list into three lists, one per 'InvTyp' constructor
-- (invariants, definitions, inequalities), preserving the input order
-- within each list.
partition3 :: [(a, InvTyp)] -> ([a], [a], [a])
{-# INLINE partition3 #-}
partition3 xs = go xs
 where
   go [] = ([],[],[])
   go ((x,tag):rest)
    = let (is, ds, ls) = go rest -- let-bound pattern: matched lazily
      in case tag of
           Invariant   -> (x:is, ds  , ls  )
           IDefinition -> (is  , x:ds, ls  )
           ILEQ        -> (is  , ds  , x:ls)