Skip to content
Snippets Groups Projects
rq_simplex_cpp.cpp 3.4 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "RcppArmadillo.h"
    #include <Rcpp.h>
    // [[Rcpp::depends(RcppArmadillo)]]
    using namespace Rcpp;
    using namespace std; 
    
    // [[Rcpp::export]]
    Rcpp::List rq_simplex_cpp(arma::mat X,
    					   	arma::uvec Ih,
    					   	arma::uvec Ihc,
    					   	arma::mat IH,
    					   	unsigned int K, 
    					   	unsigned int n,
    					   	arma::vec xB,
    					   	arma::vec P,
    					   	double tau){
    
    
    	arma::mat invXh = inv(X.rows(Ih));
    	arma::vec cB = arma::conv_to<arma::vec>::from(P < 0) + P * tau;
    
    	arma::vec cC = arma::join_cols(arma::ones<arma::vec>(K) * tau, arma::ones<arma::vec>(K) * (1-tau));
    
    
    hgb's avatar
    hgb committed
    	// This here below will give problem if there is NA in the design matrix
    	// The sort_index will fail, and so, somehow in sorting then we need to make sure that
    	// index with NA needs to be removed or not used.
    
    	//Rcout << Xny .rows(Ihc) << "\n";
    
    hgb's avatar
    hgb committed
    	// Rcout << X.rows(Ihc) << "\n";
    	// Rcout << "bingooo" << "\n";
    	// Rcout << Ihc << "\n";
    	// Rcout << "faar" << "\n";
    	// Rcout << invXh << "\n";
    
    	arma::mat IB2 = -(P * arma::ones<arma::vec>(K).t() % X.rows(Ihc))*invXh;	
    
    
    hgb's avatar
    hgb committed
    	// Rcout << cB << "\n";
    	// Rcout << "bluee" << "\n";
    	// Rcout << IB2.t() << "\n";
    
    	arma::vec g = IB2.t() * cB;
    
    hgb's avatar
    hgb committed
    
    	//Rcout << g << "\n";
    	//Rcout << cC << "\n";
    
    	arma::vec d = cC - arma::join_cols(g,-g);
    	d.elem(find(abs(d) < 1.e-15)).fill(0);
    
    hgb's avatar
    hgb committed
    	// Rcout << d << "\n";
    
    	arma::uvec s = sort_index(d);
    	arma::vec md = sort(d);
    
    
    	s = s.elem(find(md < 0));
    	md = md.elem(find(md < 0));
    
    
    
    	arma::vec c = arma::ones<arma::vec>(s.size());
    
    	for(unsigned int i = 0; i < s.size(); i++){
    		if(s(i)  >= K){
    			s(i) = s(i) - K;
    			c(i) = -1;
    		}
    	}
    
    
    	//c.elem(s>K).fill(-1);
    	
    
    
    
    	arma::mat C(c.size(),c.size(),arma::fill::eye);
    	C.diag() = c;
    
    
    	arma::mat h = arma::join_cols(invXh.cols(s), IB2.cols(s)) * C;
    
    
    	arma::vec xm = xB.elem(arma::linspace<arma::uvec>(K,xB.size()-1,xB.size()-K));
    	xm.elem(find(xm < 0)).fill(0); 
    	arma::mat hm = h.rows(arma::linspace<arma::uvec>(K,xB.size()-1,xB.size()-K));
    	
    
    
    	arma::vec alpha(s.size());
    	arma::uvec q(s.size());
    	arma::vec cq(s.size());
    	arma::uvec idx2 = arma::linspace<arma::uvec>(0,s.size()-1,s.size());
    	
    
    
    	for(unsigned int k = 0; k < s.size(); k++){
    
    		arma::vec sigma = xm;
    		arma::uvec idx = find(hm.col(k) > 1.e-12);
    
    		sigma.elem(idx) = xm.elem(idx) / hm(idx, find(idx2 == k));
    
    		sigma(find(hm.col(k) <= 1.e-12)).fill(1.e30);
    
    
    
    		alpha.elem(find(idx2 == k)).fill(sigma.min());
    		q.elem(find(idx2 == k)).fill(sigma.index_min());
    
    		cq.elem(find(idx2 == k)) = c.elem(find(idx2 == k));
    	}
    
    
    	arma::vec gain = md % alpha;
    	arma::vec Mgain = sort(gain);
    	arma::uvec IMgain = sort_index(gain);
    	arma::uvec IhMid;
    
    	double CON = 1.e30;
    	unsigned int j = 0;
    
    
    
    	if(gain.size() < 1){
    		gain.resize(1);
    		gain(0) = 1;
    	} else {
    		while((CON > 1.e6) & (j < s.size())){
    			IhMid = Ih;
    			IhMid(s(IMgain(j))) = Ihc(q(IMgain(j))); 
    			IhMid = sort(IhMid);
    
    			if(sum(abs(IH.t() - IhMid.t())) == 0){
    				CON = 1.e30;
    			} else{
    				CON = cond(X.rows(IhMid));
    			}
    			s = s(IMgain(j));
    			q = q(IMgain(j));
    			cq = cq(IMgain(j));
    			alpha = alpha(IMgain(j));
    			IH = arma::join_rows(IH, arma::conv_to<arma::vec>::from(IhMid));
    			h = h.col(IMgain(j));
    			gain = gain(IMgain(j));
    			md = md(IMgain(j));
    
    
    			j += 1;
    
    		}
    
    	}
    
    
    	return List::create(Named("CON") = CON,
    						Named("s") = s,
    						Named("g") = g,
    						Named("q") = q,
    						Named("gain") = gain,
    						Named("md") = md,
    						Named("alpha") = alpha,
    						Named("h") = h,
    						Named("IH") = IH,
    						Named("cq") = cq);
    }