Skip to content

Commit

Permalink
--glm debug build 4
Browse files Browse the repository at this point in the history
  • Loading branch information
chrchang committed Nov 21, 2023
1 parent 41a1237 commit fbfdc19
Showing 1 changed file with 38 additions and 54 deletions.
92 changes: 38 additions & 54 deletions 2.0/plink2_glm_logistic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2875,6 +2875,36 @@ void ComputeHessianD(const double* mm, const double* vv, uint32_t col_ct, uint32
}
}

void ComputeHessianDDebug(const double* mm, const double* vv, uint32_t col_ct, uint32_t row_ct, uint32_t debug_print, double* __restrict dest) {
const uintptr_t col_ctav = RoundUpPow2(col_ct, kDoublePerDVec);
const uintptr_t row_ctav = RoundUpPow2(row_ct, kDoublePerDVec);
const uintptr_t row_ctavp1 = row_ctav + 1;
if (row_ct > 3) {
const uint32_t row_ctm3 = row_ct - 3;
for (uint32_t row_idx = 0; row_idx < row_ctm3; row_idx += 3) {
const double* mm_cur = &(mm[row_idx * col_ctav]);
ComputeTwoDiagTripleProductD(mm_cur, &(mm_cur[col_ctav]), vv, col_ct, &(dest[row_idx * row_ctavp1]), &(dest[(row_idx + 1) * row_ctavp1 - 1]), &(dest[(row_idx + 1) * row_ctavp1]));
if (debug_print) {
logprintf("ComputeHessianDDebug: dest[0]: %g\n", dest[0]);
}
ComputeTwoPlusOneTripleProductD(&(mm_cur[2 * col_ctav]), &(mm_cur[col_ctav]), mm_cur, vv, col_ct, &(dest[(row_idx + 2) * row_ctavp1]), &(dest[(row_idx + 2) * row_ctavp1 - 1]), &(dest[(row_idx + 2) * row_ctavp1 - 2]));
for (uint32_t row_idx2 = row_idx + 3; row_idx2 != row_ct; ++row_idx2) {
ComputeThreeTripleProductD(&(mm[row_idx2 * col_ctav]), mm_cur, &(mm_cur[col_ctav]), &(mm_cur[2 * col_ctav]), vv, col_ct, &(dest[row_idx2 * row_ctav + row_idx]), &(dest[row_idx2 * row_ctav + row_idx + 1]), &(dest[row_idx2 * row_ctav + row_idx + 2]));
}
}
}
switch (row_ct % 3) {
case 0:
ComputeTwoPlusOneTripleProductD(&(mm[(row_ct - 3) * col_ctav]), &(mm[(row_ct - 2) * col_ctav]), &(mm[(row_ct - 1) * col_ctav]), vv, col_ct, &(dest[(row_ct - 3) * row_ctavp1]), &(dest[(row_ct - 2) * row_ctavp1 - 1]), &(dest[(row_ct - 1) * row_ctavp1 - 2]));
// fall through
case 2:
ComputeTwoDiagTripleProductD(&(mm[(row_ct - 2) * col_ctav]), &(mm[(row_ct - 1) * col_ctav]), vv, col_ct, &(dest[(row_ct - 2) * row_ctavp1]), &(dest[(row_ct - 1) * row_ctavp1 - 1]), &(dest[(row_ct - 1) * row_ctavp1]));
break;
case 1:
dest[(row_ct - 1) * row_ctavp1] = TripleProductD(&(mm[(row_ct - 1) * col_ctav]), &(mm[(row_ct - 1) * col_ctav]), vv, col_ct);
}
}

void CholeskyDecompositionD(const double* aa, uint32_t predictor_ct, double* __restrict ll) {
const uintptr_t predictor_ctav = RoundUpPow2(predictor_ct, kDoublePerDVec);
const uintptr_t predictor_ctavp1 = predictor_ctav + 1;
Expand Down Expand Up @@ -3018,16 +3048,16 @@ BoolErr LogisticRegressionD(const double* yy, const double* xx, uint32_t sample_
const uint32_t debug_print = g_debug_on && (iteration == 1);
if (debug_print) {
logprintf("sample_ct: %" PRIuPTR " sample_ctav: %" PRIuPTR "\n", sample_ct, sample_ctav);
for (uint32_t uii = 0; uii < 4; ++uii) {
logprintf("yy[%u]: %g pp[%u]: %g vv[%u]: %g\n", uii, yy[uii], uii, pp[uii], uii, vv[uii]);
}
for (uint32_t uii = sample_ctav - 4; uii < sample_ct; ++uii) {
logprintf("yy[%u]: %g pp[%u]: %g vv[%u]: %g\n", uii, yy[uii], uii, pp[uii], uii, vv[uii]);
}
logprintf("\n");
logprintf("yy[%u]: %g pp[%u]: %g vv[%u]: %g\n", sample_ctav - 1, yy[sample_ctav - 1], sample_ctav - 1, pp[sample_ctav - 1], sample_ctav - 1, vv[sample_ctav - 1]);

logprintf("xx[0][0]: %g xx[0][1]: %g\n", xx[0], xx[1]);
logprintf("xx[0][%u]: %g xx[0][%u]: %g\n", sample_ctav - 2, xx[sample_ctav - 2], sample_ctav - 1, xx[sample_ctav - 1]);
logprintf("xx[1][0]: %g xx[1][1]: %g\n", xx[sample_ctav], xx[sample_ctav + 1]);
logprintf("xx[1][%u]: %g xx[1][%u]: %g\n", sample_ctav * 2 - 2, xx[sample_ctav * 2 - 2], sample_ctav * 2 - 1, xx[sample_ctav * 2 - 1]);
}

ComputeHessianD(xx, vv, sample_ct, predictor_ct, hh);
// ComputeHessianD(xx, vv, sample_ct, predictor_ct, hh);
ComputeHessianDDebug(xx, vv, sample_ct, predictor_ct, debug_print, hh);
if (debug_print) {
logprintf("predictor_ct: %u\n", predictor_ct);
for (uint32_t row_idx = 0; row_idx < predictor_ct; ++row_idx) {
Expand All @@ -3041,71 +3071,25 @@ BoolErr LogisticRegressionD(const double* yy, const double* xx, uint32_t sample_
// grad = X^T P
// Separate categorical loop also possible here
ColMajorVectorMatrixMultiplyStrided(pp, xx, sample_ct, sample_ctav, predictor_ct, grad);
if (debug_print) {
for (uint32_t pred_idx = 0; pred_idx != predictor_ct; ++pred_idx) {
logprintf("grad[%u]: %g\n", pred_idx, grad[pred_idx]);
}
logprintf("\n");
}

// maybe this should use a QR decomposition instead?
CholeskyDecompositionD(hh, predictor_ct, ll);
if (debug_print) {
for (uint32_t row_idx = 0; row_idx < predictor_ct; ++row_idx) {
for (uint32_t col_idx = 0; col_idx <= row_idx; ++col_idx) {
logprintf("ll[%u][%u]: %g\n", row_idx, col_idx, ll[row_idx * predictor_ctav + col_idx]);
}
}
logprintf("\n");
}

SolveLinearSystemD(ll, grad, predictor_ct, dcoef);
if (debug_print) {
for (uint32_t pred_idx = 0; pred_idx != predictor_ct; ++pred_idx) {
logprintf("dcoef[%u]: %g\n", pred_idx, dcoef[pred_idx]);
}
logprintf("\n");
}

for (uint32_t pred_idx = 0; pred_idx != predictor_ct; ++pred_idx) {
coef[pred_idx] -= dcoef[pred_idx];
}
if (debug_print) {
for (uint32_t pred_idx = 0; pred_idx != predictor_ct; ++pred_idx) {
logprintf("coef[%u]: %g\n", pred_idx, coef[pred_idx]);
}
logprintf("\n");
}
// P[i] = \sum_j X[i][j] * coef[j];
ColMajorMatrixVectorMultiplyStrided(xx, coef, sample_ct, sample_ctav, predictor_ct, pp);
if (debug_print) {
for (uint32_t uii = 0; uii < 4; ++uii) {
logprintf("pp[%u]: %g\n", uii, pp[uii]);
}
for (uint32_t uii = sample_ctav - 4; uii < sample_ct; ++uii) {
logprintf("pp[%u]: %g\n", uii, pp[uii]);
}
}

// P[i] = 1 / (1 + exp(-P[i]));
logistic_v_unsafe(pp, sample_ctav);
if (debug_print) {
logprintf("after logistic_v_unsafe:\n");
for (uint32_t uii = 0; uii < 4; ++uii) {
logprintf("yy[%u]: %g pp[%u]: %g\n", uii, yy[uii], uii, pp[uii]);
}
for (uint32_t uii = sample_ctav - 4; uii < sample_ct; ++uii) {
logprintf("yy[%u]: %g pp[%u]: %g\n", uii, yy[uii], uii, pp[uii]);
}
}
const double loglik = ComputeLoglikD(yy, pp, sample_ct);
if (loglik != loglik) {
return 1;
}

if (debug_print) {
logprintf("iteration 1 loglik: %g\n", iteration, loglik);
}
// TODO: determine other non-convergence criteria
if (fabs(loglik - loglik_old) < 1e-8 * (0.05 + fabs(loglik))) {
return 0;
Expand Down

0 comments on commit fbfdc19

Please sign in to comment.