#include "RcppArmadillo.h"
#include <Rcpp.h>
// [[Rcpp::depends(RcppArmadillo)]]
using namespace Rcpp;
using namespace std; 

// rq_simplex_cpp: one pivoting step of a simplex-style solver.
//
// NOTE(review): the name and the tau / (1 - tau) weights in the cost vector cC
// suggest a simplex step for quantile regression at quantile level `tau` --
// confirm against the R caller.
//
// All index vectors are used directly as 0-based Armadillo indices.
//   X    - design matrix; X.rows(Ih) must be square and invertible (passed to inv()).
//   Ih   - row indices of X forming the current basis (K entries).
//   Ihc  - remaining (non-basic) row indices of X.
//   IH   - matrix whose columns are previously visited bases (sorted index
//          vectors stored as doubles); the chosen new basis is appended to it.
//   K    - basis size; the 2K reduced costs are split into a "+" half (first K)
//          and a "-" half (last K).
//   n    - not used by this function (kept for the R-side interface).
//   xB   - current solution vector; entries K .. xB.n_elem-1 (negatives clamped
//          to 0) bound the step length in the ratio test.
//   P    - one entry per row of Ihc; the basic cost vector is (P < 0) + tau * P.
//          Presumably residual signs (+-1) -- TODO confirm.
//   tau  - weight used in both cost vectors.
//
// Returns a list with:
//   CON   - condition number of X.rows(new basis). It is 1e30 when no reduced
//           cost is negative (no pivot attempted) or when the new basis is
//           already a column of IH. Any value > 1e6 means no acceptable pivot
//           was found; the last candidate tried is then returned.
//   s     - position in Ih of the leaving element (length 1 after a pivot).
//   g     - IB2' * cB.
//   q     - position in Ihc of the entering row (length 1 after a pivot).
//   gain  - md * alpha of the chosen pivot; set to 1 when nothing is pivoted.
//   md    - the chosen (negative) reduced cost.
//   alpha - ratio-test step length of the chosen pivot.
//   h     - direction column of the chosen pivot.
//   IH    - input IH with the new basis appended (unchanged if no pivot).
//   cq    - +1 / -1 direction sign of the chosen pivot.
// [[Rcpp::export]]
Rcpp::List rq_simplex_cpp(arma::mat X,
					   	arma::uvec Ih,
					   	arma::uvec Ihc,
					   	arma::mat IH,
					   	unsigned int K, 
					   	unsigned int n,
					   	arma::vec xB,
					   	arma::vec P,
					   	double tau){

	// Inverse of the basis rows, plus the basic and non-basic cost vectors.
	arma::mat invXh = inv(X.rows(Ih));
	arma::vec cB = arma::conv_to<arma::vec>::from(P < 0) + P * tau;
	arma::vec cC = arma::join_cols(arma::ones<arma::vec>(K) * tau, arma::ones<arma::vec>(K) * (1-tau));

	// NOTE: NA/NaN values in X propagate into the reduced costs d below, and
	// sort()/sort_index() on d will then fail. Rows containing NA must be
	// removed (or never indexed) before calling this function.
	arma::mat IB2 = -(P * arma::ones<arma::vec>(K).t() % X.rows(Ihc)) * invXh;
	arma::vec g = IB2.t() * cB;

	// Reduced costs for moving each basic coordinate in the "+" (first K) and
	// "-" (last K) direction. Round-off sized values are snapped to 0 so they
	// are not treated as improving directions.
	arma::vec d = cC - arma::join_cols(g, -g);
	d.elem(find(abs(d) < 1.e-15)).fill(0);

	// Keep only strictly negative reduced costs, most negative first.
	arma::uvec s = sort_index(d);
	arma::vec md = sort(d);
	const arma::uvec improving = find(md < 0);
	s = s.elem(improving);
	md = md.elem(improving);

	// Fold the "-" half back onto basis positions and record the sign in c.
	arma::vec c = arma::ones<arma::vec>(s.n_elem);
	for (arma::uword i = 0; i < s.n_elem; i++) {
		if (s(i) >= K) {
			s(i) -= K;
			c(i) = -1;
		}
	}

	arma::mat C(c.n_elem, c.n_elem, arma::fill::eye);
	C.diag() = c;
	arma::mat h = arma::join_cols(invXh.cols(s), IB2.cols(s)) * C;

	// Rows K .. xB.n_elem-1 belong to the non-basic observations.
	const arma::uvec nonbasic = arma::linspace<arma::uvec>(K, xB.n_elem - 1, xB.n_elem - K);
	arma::vec xm = xB.elem(nonbasic);
	xm.elem(find(xm < 0)).fill(0);
	arma::mat hm = h.rows(nonbasic);

	// Ratio test for each candidate direction k. The step length is the
	// smallest xm / h over rows with h > 1e-12; other rows get 1e30.
	arma::vec alpha(s.n_elem);
	arma::uvec q(s.n_elem);
	for (arma::uword k = 0; k < s.n_elem; k++) {
		arma::vec hk = hm.col(k);
		arma::vec sigma = xm;
		arma::uvec pos = find(hk > 1.e-12);
		sigma.elem(pos) = xm.elem(pos) / hk.elem(pos);
		sigma.elem(find(hk <= 1.e-12)).fill(1.e30);
		alpha(k) = sigma.min();
		q(k) = sigma.index_min();
	}
	arma::vec cq = c;

	arma::vec gain = md % alpha;
	arma::uvec IMgain = sort_index(gain);

	double CON = 1.e30;

	if (gain.n_elem < 1) {
		// No negative reduced cost: nothing to pivot. gain = 1 signals this.
		gain.resize(1);
		gain(0) = 1;
	} else {
		// True when `basis` equals some column of IH (a previously visited basis).
		auto already_visited = [&IH](const arma::uvec& basis) -> bool {
			if (IH.n_rows != basis.n_elem) return false;
			const arma::vec b = arma::conv_to<arma::vec>::from(basis);
			for (arma::uword col = 0; col < IH.n_cols; col++) {
				if (arma::accu(arma::abs(IH.col(col) - b)) == 0) return true;
			}
			return false;
		};

		// Try candidates from the largest improvement (most negative gain)
		// onward. Stop at the first new basis that has not been visited and
		// has cond(X.rows(basis)) <= 1e6.
		// The per-candidate vectors must NOT be reduced inside this loop:
		// doing so shrinks s to one element and ends the search after a
		// single try.
		arma::uword chosen = IMgain(0);
		arma::uvec IhMid;
		for (arma::uword j = 0; CON > 1.e6 && j < IMgain.n_elem; j++) {
			chosen = IMgain(j);
			IhMid = Ih;
			IhMid(s(chosen)) = Ihc(q(chosen));
			IhMid = sort(IhMid);
			CON = already_visited(IhMid) ? 1.e30 : cond(X.rows(IhMid));
		}

		// Reduce every per-candidate quantity to the chosen pivot.
		s = s(chosen);
		q = q(chosen);
		cq = cq(chosen);
		alpha = alpha(chosen);
		IH = arma::join_rows(IH, arma::conv_to<arma::vec>::from(IhMid));
		h = h.col(chosen);
		gain = gain(chosen);
		md = md(chosen);
	}

	return List::create(Named("CON") = CON,
						Named("s") = s,
						Named("g") = g,
						Named("q") = q,
						Named("gain") = gain,
						Named("md") = md,
						Named("alpha") = alpha,
						Named("h") = h,
						Named("IH") = IH,
						Named("cq") = cq);
}