#include <stdio.h>
#include "../simulation.c"
#include <map>
#include <vector>
#include <deque>
#include <bitset>
#include <string>
#include <list>
#include <iostream>
#include <sstream>

// To change what is analysed, you may wish to (un)define:
// #define getSinkCounters    // we wish to know the number of packets ending in sinks 
// #define getSourceCounters  // we wish to know the number of packets coming from sources

// we still have to implement facilities for these:
// #define getQueueHeaders    // try to find out whether queues have the same packet at their heads

// To disable assert-checks, replace this comment line with #define NDEBUG
#include <assert.h>
// Defined elsewhere (presumably in simulation.c); not referenced in the code visible here.
extern bool* type_specific_queues;
using std::vector;
using std::list;
using std::map;
using std::pair;
using std::string;
using std::deque;
using std::stringstream;

// Sparse coefficient vector: index -> coefficient (absent entries are zero).
// Inside a Line the index is a header; inside a Row it is a flattened variable index.
typedef map<int,float> Vector;
// One linear equation, read as "sum of coefficient * variable = 0":
// element id (index of a variable-vector) -> per-header coefficients.
typedef map<int,Vector > Line;
// One row of the sparse elimination matrix.
struct Row{
  int pivotVar; // variable this row was normalised for in reduceFor, or -1 if it is not a pivot row yet
  Vector vars;  // flattened variable index (element*NUM_OF_HEADERS + header) -> nonzero coefficient
};
deque<Line > eqns; // all equalities (line = 0); filled by set_eqns, drained by eqnsToMatrix
list<Row > eqMatrix; // the same equalities as sparse rows; reduced in place by reduceFor
list<Row > leMatrix; // NOTE(review): not used in the code visible here

///
// Build an empty equation in which every coefficient is (implicitly) zero.
// A Line is a sparse map of maps, so the zero equation is simply the empty map;
// callers fill in the nonzero coefficients they need.
const Line zeroVector(){
  return Line();
}

// > data Cref = Cref{component::Int, port::Int}
// >             deriving (Eq,Ord)
// Reference to a channel. A channel is identified by the component that drives it
// and the output port of that component.
struct cref {
  int component; // index of the driving component
  int port;      // output port of that component
};
// Two references denote the same channel when both fields agree.
bool operator==(cref a,cref b){
  return !(a.component != b.component || a.port != b.port);
}
// Lexicographic order on (component, port); lets cref serve as a std::map key.
bool operator<(cref a,cref b){
  if (a.component != b.component)
    return a.component < b.component;
  return a.port < b.port;
}

// the following variables are initially set by init_chanId (varName and numVars are set by eqnsToMatrix)
map<cref,cref> eqcmap; // equal channels: channel -> a channel carrying the same packets (after fixPoint: its representative)
map<cref,int> chanId; // num of variable-vector for channel (channel ids start at numBufs)
map<int,int> bufId;   // component index of a queue or buffer -> num of its variable-vector (0..numBufs-1)
vector<string> varName; // flattened variable index -> printable name (emptied in main for eliminated variables)
vector<string> elemName; // variable-vector index -> printable element name
vector<bool> interestingElem; // NOTE(review): not used in the code visible here
int numChans; // number of distinct channel variable-vectors
int numBufs; // number of queues and buffers
int numTot ; // total of all elements stored as variable-vectors (numChans + numBufs)
int numVars; // total of all variables, (length of the concatenation of all variable-vectors)



// > outChan i j = Cref i j
// Reference to the channel driven by output port j of component i.
cref outChan(int i,int j){
  // the component must exist, and j must be one of its output ports
  assert(i>=0 && i<xMas_network_size);
  assert(j>=0 && j<xMas_network[i].num_of_outs);
  cref ref = {i, j};
  return ref;
}
// Reference to the channel feeding input port j of component i. That channel is
// identified by the upstream component and the output port wired to this input.
cref inChan(int i,int j){
  // the component must exist, and j must be one of its input ports
  assert(i>=0 && i<xMas_network_size);
  assert(j>=0 && j<xMas_network[i].num_of_ins);
  const int upstream = xMas_network[i].in[j];
  const int upstreamPort = xMas_network[i].out_port_connected_to_in[j];
  return outChan(upstream, upstreamPort);
}


///
// Find the representative of channel `it`: follow eqcmap until reaching a channel
// that maps to itself. A channel without an entry is its own representative and
// receives a self-entry. This function is used by init_chanId().
//
// Not pure: afterwards every channel on the followed path maps directly to the
// representative (path compression), which keeps later lookups short.
// Never returns if eqcmap contains a directed cycle (other than a self-entry).
cref fixPoint(cref it){
    // Phase 1: walk the alias chain up to the representative.
    cref root = it;
    for (;;) {
        map<cref,cref>::iterator entry = eqcmap.find(root);
        if (entry == eqcmap.end()) {
            eqcmap[root] = root; // unseen channel: it represents itself
            break;
        }
        if (entry->second == root) break;
        root = entry->second;
    }
    // Phase 2: point every channel on the path straight at the representative.
    cref cur = it;
    while (!(cur == root)) {
        cref next = eqcmap[cur];
        eqcmap[cur] = root;
        cur = next;
    }
    return root;
}

///
// Initialise the channels, setting most global variables that will be used as (constant) lookup-tables later on.
//
// Sets:
//  - eqcmap: channel -> representative channel (outputs of synch/fork are aliased to inputs),
//  - bufId / numBufs: one variable-vector per queue or buffer, numbered 0..numBufs-1,
//  - chanId / numChans: one variable-vector per distinct channel, numbered numBufs..numTot-1
//    in eqcmap key order (component, port),
//  - numTot = numChans + numBufs,
//  - elemName: a printable name for every variable-vector.
void init_chanId(){
    numBufs=0;
    // we will build a list of channels, such that two channels always get mapped to the same number
    int j;
    int i; // used for looping over the in-ports or out-ports in nearly all of the below
    for (j=0;j<xMas_network_size;j++) {
        /*
          // display the name of the component to which the next bit belongs:
          printf("/%s component %s: %s/\n","*",xMas_network[j].id,"*");
        */
        
        // We recognise linear invariants.
        // the value of a channel is the number of packets that went through it
        // the value of a memory-component (think: queue or buffer) is the number of packets that entered it minus the number of packets that left it
        // (applying this definition to non-memory components, its value is 0 and we do not model it)
        
        // we will generate variable-names later (these are kept in a vector, and we need to count the number of channels first)
        switch (xMas_network[j].type)
        {
          case synch:
               assert(xMas_network[j].num_of_ins==xMas_network[j].num_of_outs); // must be a valid Synch
               // output i carries exactly the packets of input i: alias the two channels
               for(i=0;i<xMas_network[j].num_of_ins;i++)
                 eqcmap[outChan(j,i)]=inChan(j,i);
               break;
          case xfork:
               assert(xMas_network[j].num_of_ins==1); // must be a valid fork: one input port!
               // every output carries exactly the packets of the single input: alias them all
               for(i=0;i<xMas_network[j].num_of_outs;i++)
                 eqcmap[outChan(j,i)]=inChan(j,0);
               break;
          case queue:case buffer:
               // memory components get a variable-vector of their own
               bufId[j]=numBufs;
               numBufs++;
               // no break here (intentional fall-through): we have to create the ports
          case sink:case source:
          case merge:case join:case function:case xswitch:case aswitch:
               // we cannot map ports to ports, so we simply create them as self-aliases.
               // insert() leaves an alias already set by an upstream synch/fork untouched;
               // a synch/fork processed later overwrites the self-alias via operator[].
               for(i=0;i<xMas_network[j].num_of_ins;i++)
                 eqcmap.insert(pair<cref,cref>(inChan(j,i),inChan(j,i)));
               for(i=0;i<xMas_network[j].num_of_outs;i++)
                 eqcmap.insert(pair<cref,cref>(outChan(j,i),outChan(j,i)));
               break;
          default:
               assert(false); // xMas_network[j].type does not exist: unknown component
        }
    }
    
    // Number the representatives: a channel that is its own fixPoint gets the next
    // channel id (ids start after the buffer ids). fixPoint also path-compresses eqcmap.
    i=0;
    map<cref,cref>::const_iterator end = eqcmap.end(); 
    for (map<cref,cref>::const_iterator it = eqcmap.begin(); it != end; ++it) {
        if(it->first==fixPoint(it->first)){
          chanId[it->first]=numBufs+i++;
        }
    }
    numChans=i;
    numTot = numChans+numBufs;
    // The loop above ran fixPoint on every key, so each key now maps directly to its
    // representative: give every channel the id of its representative.
    for (map<cref,cref>::const_iterator it = eqcmap.begin(); it != end; ++it) {
        chanId[it->first]=chanId[it->second];
    }
    
    elemName.resize(numTot);
    // we generate variable-names for all relevant channels:
    // one channel for each in- or output of the current component.
    // A channel touched by several ports (aliased, or seen from both ends) collects
    // all of their names, newest first, separated by "/".
    for (j=0;j<xMas_network_size;j++) {
        switch (xMas_network[j].type)
        {
            case queue:
            case buffer:
                 // the buffer's own variable-vector is named after the component
                 elemName[bufId[j]] = xMas_network[j].id;
                 // no break (intentional fall-through): its ports are named as well
            default:
                for(i=0;i<xMas_network[j].num_of_ins;i++){
                  stringstream ss;
                  ss<<xMas_network[j].id<<"_in"<<i;
                  if(!elemName[chanId[inChan(j,i)]].empty()) ss << "/" << elemName[chanId[inChan(j,i)]];
                  elemName[chanId[inChan(j,i)]]=ss.str();
                }
                for(i=0;i<xMas_network[j].num_of_outs;i++){
                  stringstream ss;
                  ss<<xMas_network[j].id<<"_out"<<i;
                  if(!elemName[chanId[outChan(j,i)]].empty()) ss << "/" << elemName[chanId[outChan(j,i)]];
                  elemName[chanId[outChan(j,i)]]=ss.str();
                }
                break;
        }
    }
    // no return command: the function is void
}

///
// Set the equations that hold in the network.
// The resulting list of equations is stored in eqns, and will be gauss-eliminated later.
// The order of these equations is arbitrary.
void set_eqns(){
    using std::bitset;
    int j;int i;
    for (j=0;j<xMas_network_size;j++) {
        switch (xMas_network[j].type)
        {
          case source:
               {int oChan = chanId[outChan(j,0)];
               for(i=0;i<NUM_OF_HEADERS;i++){
			 int num_of_packets = ((int*)(xMas_network[j].field[0]))[0];
			 bool found = false;
			 int k;
  			 for (k=1;k<num_of_packets+1&&!found;k++)
				found = ((((int*) (xMas_network[j].field[0]))[k]) == i);

                 if(!found){
                  Line zv=zeroVector();
                  zv[oChan][i]=-1;
                  eqns.push_back(zv);
                 }
               }

               break;}
          case synch:assert(false); // not tested
               // we do not need to assign sources to targets since they are represented by the same variable
               // just ensure that all input-channels get an equal sum
               {int oChan = chanId[outChan(j,0)];
               for (i=1;i<xMas_network[j].num_of_ins;i++){
                 int nChan = chanId[outChan(j,i)];
                 Line zv = zeroVector();
                 zv[oChan][NUM_OF_HEADERS]=1;
                 zv[nChan][NUM_OF_HEADERS]=-1;
                 eqns.push_back(zv);
                 oChan=nChan;
               }
               break;}
          case xfork: // fully handled by preparsing: channels are equal!
               break;
          case sink: // nothing to do
               break;
          case join: // for each input: output is its sum.
               // this is the weakest assumption we can make
               // attach a sink to a synch if you have more information
               {int oChan = chanId[outChan(j,0)];
               int n;
               int token_input = ((int*) (xMas_network[j].field[0]))[0];
               int data_input = token_input == 0 ? 1 : 0;
               for (i=0;i<xMas_network[j].num_of_ins;i++){
                 int nChan = chanId[inChan(j,i)];
                 if (i == data_input) {
                   for(n=0;n<NUM_OF_HEADERS;n++){   
                      Line zv = zeroVector();
                     zv[oChan][n]=-1;
                     zv[nChan][n]=1;
                     eqns.push_back(zv);
                      }
                 }else{
                 Line zv = zeroVector();
                 for(n=0;n<NUM_OF_HEADERS;n++){
                   zv[oChan][n]=-1;
                 }
                 for(n=0;n<NUM_OF_HEADERS;n++){
                   zv[nChan][n]=1;
                 }
                 eqns.push_back(zv);
                 }
               }
               break;}
          case function:
               {
               bitset<NUM_OF_HEADERS*NUM_OF_HEADERS> fmap;
               fmap.reset();
               for(i=0;i<NUM_OF_HEADERS;i++){
                 int mapto = ((int (*)(int)) (xMas_network[j].field[0]))(i);
                 assert(mapto>=0);
                 assert(mapto<NUM_OF_HEADERS);
                 fmap.set(mapto+NUM_OF_HEADERS*i); // packet is routed from i to mapto
               }
               int oChan = chanId[outChan(j,0)];
               int iChan = chanId[inChan(j,0)];
               int k;
               for(i=0;i<NUM_OF_HEADERS;i++){
                 Line zv=zeroVector();
                 zv[oChan][i]=-1;
                 for(k=0;k<NUM_OF_HEADERS;k++){
                   if (fmap[i+k*NUM_OF_HEADERS] // packet is routed from k to i
                      ) zv[iChan][k]=1;
                 }
                 eqns.push_back(zv);
               }
               break;
               }
          case merge:{
               // sum of inputs is output:
               int oChan = chanId[outChan(j,0)];
               int k;
               for(i=0;i<NUM_OF_HEADERS;i++){
                 Line zv=zeroVector();
                 zv[oChan][i]=-1;
                 for(k=0;k<xMas_network[j].num_of_ins;k++){
                   zv[chanId[inChan(j,k)]][i]=1;
                 }
                 eqns.push_back(zv);
               }break; }
          case xswitch:{
               int iChan = chanId[inChan(j,0)];
               int k;
               for(i=0;i<NUM_OF_HEADERS;i++){
                 Line zv=zeroVector();
                 zv[iChan][i]=-1;
                 if((bool) (((int (*)(int)) (xMas_network[j].field[0]))(i))){
                     zv[chanId[outChan(j,0)]][i]=1;
                 }else{
                     zv[chanId[outChan(j,1)]][i]=1;
                 }
                 eqns.push_back(zv);
                 if((bool) (((int (*)(int)) (xMas_network[j].field[0]))(i))){
                     zv=zeroVector();
                     zv[chanId[outChan(j,1)]][i]=1;
                     eqns.push_back(zv);
                 }else{
                     zv=zeroVector();
                     zv[chanId[outChan(j,0)]][i]=1;
                     eqns.push_back(zv);
                 }
               }
               break;}
          case aswitch:{assert(false); // not tested
               int iChan = chanId[inChan(j,0)];
               vector<bool> omap(xMas_network[j].num_of_outs*NUM_OF_HEADERS,0);
               int k;
               for(i=0;i<NUM_OF_HEADERS;i++){
                 Line zv=zeroVector();
                 zv[iChan][i]=-1;
                 for(k=0;k<xMas_network[j].num_of_outs;k++){
                   if((bool) ((int (*)(int)) (xMas_network[j].field[k]))(i)){
                     zv[chanId[outChan(j,k)]][i]=1;
                   }else
                   omap[k*NUM_OF_HEADERS+i]=1;
                 }
                 eqns.push_back(zv);
               }
               for(i=0;i<NUM_OF_HEADERS;i++){
                 for(k=0;k<xMas_network[j].num_of_outs;k++){
                   if(omap[k*NUM_OF_HEADERS+i]){
                     Line zv=zeroVector();
                     zv[chanId[outChan(j,k)]][i]=1;
                     eqns.push_back(zv);
                   }
                 }
               }
               break;}
          // queues and buffers (3)
          case queue:case buffer:{
               int oChan = chanId[outChan(j,0)];
               int iChan = chanId[inChan(j,0)];
               int buf = bufId[j];
               for(i=0;i<NUM_OF_HEADERS;i++){
              bool packet_in_queue = false;
              int k;
              for (k=0;k<xMas_network[j].types[0]->num_of_elts&&!packet_in_queue;k++)
                packet_in_queue = (xMas_network[j].types[0]->array[k] == i);
                  if (packet_in_queue){
                    Line zv=zeroVector();
                    zv[iChan][i]=1;
                    zv[oChan][i]=-1;
                    zv[buf][i]=-1;
                    assert(buf!=iChan);
                    assert(buf!=oChan);
                    eqns.push_back(zv);
                  }else{
                    Line zv=zeroVector();
                    zv[oChan][i]=1;
                    eqns.push_back(zv);
                    zv=zeroVector();
                    zv[iChan][i]=1;
                    eqns.push_back(zv);
                  }
               }
               break;
          }
        }
    }
}

///
// Remove a certain variable from the list of equations as with gaussian elimination.
// The variable-to-be-removed is assumed to be non-negative,
// so the pivoting-equation from ???òeqMatrix???ô may be moved into leMatrix.
// Equations in leMatrix may also be removed, namely:
// when the variable to be removed occurs in some equation in leMatrix
// and this occurence is negative.
void reduceFor(int var){
  using std::cout;
  list<Row >::iterator it;
  list<Row >::iterator pivotRow;
  Vector::iterator v;
  float x;
  // find a pivoting row
  for (it=eqMatrix.begin(); it!=eqMatrix.end(); it++){
    if(it->vars.count(var)){
      assert(it->vars[var]!=0); // zero's should not be stored!
      break;}
  }
  if(it!=eqMatrix.end()){
    if (it->pivotVar>=0){
      //cout << "# not reducing for "<<varName[var]<<"\n";
      return; // already assigned to some variable.. no need to use this as a pivot-row
    }
    //cout << "# reducing for "<<varName[var]<<"\n";
    it->pivotVar = var;
    // pivot with row ???òit???ô: we can effectively use this row to create zero's in all other rows
    pivotRow = it;
    // normalize line:
    x = pivotRow->vars[var];
    assert (x!=0);
    for ( v=pivotRow->vars.begin() ; v != pivotRow->vars.end(); v++ ) v->second/=x;
    assert(pivotRow->vars[var]==1); // since it was divided by itself.
    
    // perform iteration over the other non-zero's
    it++;
    while(it!=eqMatrix.end()){
      if(!it->vars.empty() && it->vars.count(var)){
        assert(it->vars[var]!=0); // zero's should not be stored!
        // determine the factor with which we multiply the *pivotRow,
        // such that *it.count(var)==0 afterwards
        x = -(it->vars[var]);
        for ( v=pivotRow->vars.begin() ; v != pivotRow->vars.end(); v++ ) {
          // add x*v->second to the corresponding factor in equation-row it,
          // and remove the factor if this turns out to be zero.
          if((it->vars[v->first]+=x*(v->second))==0) it->vars.erase(v->first);
        }
        assert(it->vars.count(var)==0); // variable should have been erased!
        // if the above assertion gives errors, we've float-arithmetics gone wrong
        // hence we should change a Row to holding integers instead.
        if(it->vars.size()==0) eqMatrix.erase(it++); // since eqMatrix is a list: all of the previously obtained iterators and references remain valid after the erasing operation and refer to the same elements they were referring before (except, naturally, for those referring to erased elements).
          else it++;
      }else it++;
    }
    Row pr = *pivotRow;
    eqMatrix.erase(pivotRow);
    eqMatrix.push_back(pr);
  }
}

///
// Put the rows of eqns into a giant (but sparse) matrix
// This is to keep ???úreduceFor???ù portable (ie: requiring a matrix)
void eqnsToMatrix(){
  Line::iterator it;
  map<int,float>::iterator it2;
  Row er;
  er.pivotVar=-1;
  varName.resize(numVars=NUM_OF_HEADERS*numTot);
  int i;
  int n;
  while(eqns.size()){
    er.vars.clear();
    for(it=eqns.front().begin();it!=eqns.front().end();it++){
      for(it2=(it->second).begin();it2!=(it->second).end();it2++){
        n=it->first*NUM_OF_HEADERS+it2->first;
        er.vars[n]=it2->second;
      }
    }
    eqMatrix.push_back(er);
    eqns.pop_front();
  }
  for(i=0;i<numTot;i++){
    if(!elemName[i].empty()){
      for(n=0;n<NUM_OF_HEADERS;n++){
        stringstream ss;
        ss<<elemName[i]<<"."<<print_header(n);
        varName[n+i*NUM_OF_HEADERS]=ss.str();
      }
    }

  }
}

int main (void){
    using std::cout;
    cout << "# Building network..\n";
    init_xMas_network(); // build the network, fill xMas_network
    cout << "# Loading typing information..\n";
    analyze_typing_information();
    cout << "# Initializing channels..\n";
    init_chanId(); // get the references to the channels right (call init_chanId so we can find it)
    //cout << "// setting linear equations..\n";
    set_eqns(); // view the xMas network and code it as the linear equalities eqns
    cout << "# putting everything in a big matrix..\n";
    eqnsToMatrix();
    cout << "# reducing the matrix ..\n";
    int i,j;
    for(i=numBufs*NUM_OF_HEADERS;i<numVars;i++){
      reduceFor(i);
      varName[i]="";
    }
    cout << "# printing the matrix..\n";

    stringstream newinvs;
    list<Row >::iterator it;
    for (it=eqMatrix.begin(); it!=eqMatrix.end(); it++){
      bool le=true;
      bool ge=true;
      bool geok=false;
      bool leok=false;
      Vector::iterator v;
      for (v=it->vars.begin();v!=it->vars.end();v++){
        if(v->first>=numVars || varName[v->first].empty()){
          if(v->second < 0) ge = false; else le = false;
          if(!ge && !le) break;
        }else{
          if(v->second < 0) geok = true; else leok = true;
       }
      }
      if(le&&leok){
        for (v=it->vars.begin();v!=it->vars.end();v++){
          v->second = -v->second;
        }
      }
      if(ge&&geok){
        ge = le;
        geok = leok;
        le = true;
        leok = true;
      }
      if(le && leok){
        int plus = 0;
        for (v=it->vars.begin();v!=it->vars.end();v++){
          if(!varName[v->first].empty() && v->second>0){
           if(!elemName[floor(v->first/NUM_OF_HEADERS)].empty()){
             // newinvs << elemName[floor(v->first/NUM_OF_HEADERS)] << " >= 0\n";
             elemName[floor(v->first/NUM_OF_HEADERS)].clear();
           }
            if (plus) cout << " + ";
            if (v->second != 1)
              cout << v->second << " * ";
            cout << varName[v->first];
            plus++;
            if (plus%24==0) cout << "\n  ";
          }
        }
        if(!plus) cout << "0";
        if (ge && le)
          cout << " = ";
        else
          cout << " <= ";
        plus=0;
        for (v=it->vars.begin();v!=it->vars.end();v++){
          if(!varName[v->first].empty() && v->second<0){
           if(!elemName[floor(v->first/NUM_OF_HEADERS)].empty()){
             // newinvs << elemName[floor(v->first/NUM_OF_HEADERS)] << " >= 0\n";
             elemName[floor(v->first/NUM_OF_HEADERS)].clear();
           }
            if (plus) cout << " + ";
            if (v->second != -1)
              cout << (-v->second) << " * ";
            cout << varName[v->first];
            plus++;
            if (plus%24==0) cout << "\n  ";
          }
        }
        if(!plus) cout << "0";
        cout << "\n";
      } // else cout << "Not ok found\n";
    }
    cout << newinvs.str(); 


 
    for(i=0;i<xMas_network_size;i++) {
    if (xMas_network[i].type == queue || xMas_network[i].type == buffer) {
        cout << xMas_network[i].id << " := ";
        for (j=0;j<xMas_network[i].types[0]->num_of_elts;j++) {
        if (j != 0) cout << " + ";
        cout << xMas_network[i].id << "." << print_header(xMas_network[i].types[0]->array[j])   ;
        }
        if (j == 0) cout << "0";
        cout << "\n";
    }
    }

/*
    cout << "# queue sizes\n";
    for(i=0;i<xMas_network_size;i++) {
        if (xMas_network[i].type == queue || xMas_network[i].type == buffer) {
            int qsize = ((int*) (xMas_network[i].field[0]))[0];
            cout << xMas_network[i].id << " <= " << qsize << "\n";
        }
    }
*/
    return 0;
}