#include "RcppArmadillo.h"
#include <Rcpp.h>
// [[Rcpp::depends(RcppArmadillo)]]

using namespace Rcpp;
using namespace std;

// Returns true when `basis` matches (exactly, element-wise) any column of IH.
// IH holds one previously visited basis per column: rq_simplex_cpp appends
// each accepted basis with join_rows as a new column.
static bool rq_basis_already_visited(const arma::mat& IH, const arma::uvec& basis) {
  const arma::vec candidate = arma::conv_to<arma::vec>::from(basis);
  for (arma::uword col = 0; col < IH.n_cols; ++col) {
    if (arma::accu(arma::abs(IH.col(col) - candidate)) == 0) {
      return true;
    }
  }
  return false;
}

// One pivot step of a simplex-type update with quantile weights `tau`.
//
// The step runs in four stages:
//  - Reduced costs: compute them for the 2K edge directions (+/- each basic
//    row), with costs cC = (tau,...,tau, 1-tau,...,1-tau).
//  - Candidates: keep only the directions with negative reduced cost.
//  - Ratio test: for each candidate, find the step length `alpha` and the
//    blocking non-basic row `q`.
//  - Pivot choice: try candidates in increasing order of gain = md * alpha.
//    Accept the first one whose new basis X[IhMid, ] has condition
//    number <= 1e6 and does not already appear as a column of IH.
//
// Parameters (indices are 0-based Armadillo indices):
//   X    - design matrix; X.rows(Ih) must be square and invertible (inv()).
//   Ih   - indices of the K rows of X currently in the basis.
//   Ihc  - indices of the remaining (non-basic) rows of X.
//   IH   - matrix whose columns are previously visited bases (K rows).
//   K    - basis size.
//   n    - not used by this function (kept for interface compatibility).
//   xB   - vector whose entries K..end are paired with the non-basic rows of
//          the direction matrix h (negative entries are clamped to 0).
//   P    - one value per row of Ihc; presumably the current residual/sign of
//          the non-basic observations -- TODO confirm against the R caller.
//   tau  - quantile weight used in the cost vectors (presumably in (0,1)).
//
// Returns a List with:
//   CON   - condition number of the accepted basis (1e30 if none is
//           evaluated or the basis was already visited).
//   s     - basic position to leave.
//   g     - IB2' * cB.
//   q     - non-basic position (into Ihc) to enter.
//   gain  - gain of the chosen pivot, or 1 when there is no candidate.
//   md    - reduced cost of the chosen pivot.
//   alpha - step length of the chosen pivot.
//   h     - direction column of the chosen pivot.
//   IH    - input IH with the accepted basis appended as a new column.
//   cq    - direction sign of the chosen pivot.
// When no direction has negative reduced cost, s/q/md/alpha/cq are empty and
// h has zero columns.
//
// NOTE: a NaN in X (and hence in the reduced costs) makes sort_index() fail;
// rows with NA must be removed before calling this function.
// [[Rcpp::export]]
Rcpp::List rq_simplex_cpp(arma::mat X, arma::uvec Ih, arma::uvec Ihc, arma::mat IH, unsigned int K, unsigned int n, arma::vec xB, arma::vec P, double tau){
  const double kPivotTol = 1.e-12;   // minimum positive direction entry in the ratio test
  const double kBig      = 1.e30;    // stands in for +Inf
  const double kMaxCond  = 1.e6;     // largest acceptable basis condition number

  const arma::mat invXh = arma::inv(X.rows(Ih));
  const arma::vec cB = arma::conv_to<arma::vec>::from(P < 0) + P * tau;
  const arma::vec cC = arma::join_cols(arma::ones<arma::vec>(K) * tau,
                                       arma::ones<arma::vec>(K) * (1 - tau));

  // IB2 = -(diag(P) * X[Ihc, ]) * inv(X[Ih, ])
  arma::mat XcP = X.rows(Ihc);
  XcP.each_col() %= P;
  const arma::mat IB2 = -XcP * invXh;

  const arma::vec g = IB2.t() * cB;

  // Reduced costs for the +/- directions; snap round-off noise to zero.
  arma::vec d = cC - arma::join_cols(g, -g);
  d.elem(arma::find(arma::abs(d) < 1.e-15)).fill(0);

  // Candidate directions: negative reduced costs, in ascending order.
  const arma::uvec order = arma::sort_index(d);
  const arma::vec dSorted = d.elem(order);
  const arma::uvec negative = arma::find(dSorted < 0);
  arma::uvec s = order.elem(negative);
  arma::vec md = dSorted.elem(negative);

  // Indices >= K denote the "minus" copy of a basic row: map back, sign -1.
  arma::vec c = arma::ones<arma::vec>(s.n_elem);
  for (arma::uword i = 0; i < s.n_elem; ++i) {
    if (s(i) >= K) {
      s(i) -= K;
      c(i) = -1;
    }
  }

  arma::mat h = arma::join_cols(invXh.cols(s), IB2.cols(s)) * arma::diagmat(c);

  // Rows K..end of xB / h belong to the non-basic observations.
  const arma::uvec tailRows = arma::linspace<arma::uvec>(K, xB.n_elem - 1, xB.n_elem - K);
  arma::vec xm = xB.elem(tailRows);
  xm.elem(arma::find(xm < 0)).fill(0);
  const arma::mat hm = h.rows(tailRows);

  // Ratio test for every candidate direction.
  const arma::uword nCand = s.n_elem;
  arma::vec alpha(nCand);
  arma::uvec q(nCand);
  arma::vec cq(nCand);
  for (arma::uword k = 0; k < nCand; ++k) {
    const arma::vec hk = hm.col(k);
    arma::vec sigma = xm;
    const arma::uvec blocking = arma::find(hk > kPivotTol);
    sigma.elem(blocking) = xm.elem(blocking) / hk.elem(blocking);
    sigma.elem(arma::find(hk <= kPivotTol)).fill(kBig);
    alpha(k) = sigma.min();
    q(k) = sigma.index_min();
    cq(k) = c(k);
  }

  arma::vec gain = md % alpha;
  arma::uvec IhMid;
  double CON = kBig;

  if (gain.n_elem < 1) {
    gain.resize(1);
    gain(0) = 1;
  } else {
    // Try pivots from most negative gain upward until one yields a
    // well-conditioned, not-yet-visited basis (or candidates run out).
    // The candidate vectors must stay intact while searching.
    const arma::uvec byGain = arma::sort_index(gain);
    arma::uword pick = byGain(0);
    for (arma::uword j = 0; j < byGain.n_elem && CON > kMaxCond; ++j) {
      pick = byGain(j);
      IhMid = Ih;
      IhMid(s(pick)) = Ihc(q(pick));
      IhMid = arma::sort(IhMid);
      CON = rq_basis_already_visited(IH, IhMid) ? kBig : arma::cond(X.rows(IhMid));
    }

    // Reduce everything to the accepted pivot.
    const arma::vec hPick = h.col(pick);
    h = hPick;
    s = s(pick);
    q = q(pick);
    cq = cq(pick);
    alpha = alpha(pick);
    gain = gain(pick);
    md = md(pick);
    IH = arma::join_rows(IH, arma::conv_to<arma::vec>::from(IhMid));
  }

  return List::create(Named("CON") = CON, Named("s") = s, Named("g") = g,
                      Named("q") = q, Named("gain") = gain, Named("md") = md,
                      Named("alpha") = alpha, Named("h") = h, Named("IH") = IH,
                      Named("cq") = cq);
}