Skip to content

Commit

Permalink
Commit before refactoring with HF trait
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinRJDagleish committed Jan 14, 2024
1 parent 6139319 commit 1838331
Show file tree
Hide file tree
Showing 3 changed files with 288 additions and 6 deletions.
16 changes: 16 additions & 0 deletions src/calc_type/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ pub(crate) enum HF_Ref {
ROHF_ref,

Check warning on line 16 in src/calc_type/mod.rs

View workflow job for this annotation

GitHub Actions / Rust project

variant `ROHF_ref` is never constructed
}

pub(crate) trait HF {
fn new(
basis: &crate::basisset::BasisSet,
calc_sett: &CalcSettings,
ref_type: HF_Ref,
) -> Self;

fn run_scf(
&mut self,
calc_sett: &CalcSettings,
exec_times: &mut crate::print_utils::ExecTimes,
basis: &crate::basisset::BasisSet,
mol: &crate::molecule::Molecule,
) -> SCF;
}

#[derive(Debug)]
pub struct CalcSettings {
pub max_scf_iter: usize,
Expand Down
274 changes: 270 additions & 4 deletions src/calc_type/uhf.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{CalcSettings, EriArr1, SCF};
use super::{CalcSettings, EriArr1, SCF, HFMatrices, HF_Ref};
use crate::{
basisset::BasisSet,
calc_type::{
Expand All @@ -14,9 +14,275 @@ use ndarray::{linalg::general_mat_mul, s, Array1, Array2, Zip};
use ndarray_linalg::{Eigh, UPLO};

pub(crate) struct UHF {
// Temp. matrices
P_matr_old: Array2<f64>,
E_scf_old: f64,
// Matrices needed for the SCF calculation
hf_matrs: HFMatrices,

// f64 values for SCF calculation
E_scf_prev: f64,
E_scf_curr: f64,
E_tot_prev: f64,
E_tot_curr: f64,
}

// impl HF for UHF {
// pub fn run_scf(
// &mut self,
// calc_sett: &CalcSettings,
// exec_times: &mut crate::print_utils::ExecTimes,
// basis: &crate::basisset::BasisSet,
// mol: &crate::molecule::Molecule,
// ) -> SCF {
// print_scf_header_and_settings(calc_sett, HF_Ref::UHF_ref);
//
// let mut is_scf_conv = false;
// let mut scf = SCF::default();
// let mut diis: Option<DIIS> = if calc_sett.use_diis {
// Some(DIIS::new(
// &calc_sett.diis_sett,
// [basis.no_bf(), basis.no_bf()],
// ))
// } else {
// None
// };
//
// let V_nuc: f64 = if mol.no_atoms() > 100 {
// mol.calc_core_potential_par()
// } else {
// mol.calc_core_potential_ser()
// };
//
// // Calculate 1e ints
// exec_times.start("1e ints");
// self.calc_1e_int_matrs_inp(basis, mol);
// exec_times.stop("1e ints");
//
// // Calculate 2e ints / Schwarz estimates
// exec_times.start("2e ints / Schwarz esti.");
// self.dir_indir_scf_2e_matr(basis, calc_sett);
// exec_times.stop("2e ints / Schwarz esti.");
//
// // Initial guess -> H_core
// }
// }

impl UHF {
pub fn new(basis: &BasisSet, calc_sett: &CalcSettings, ref_type: HF_Ref) -> Self {
let create_beta_vars = match ref_type {
HF_Ref::RHF_ref => false,
HF_Ref::UHF_ref | HF_Ref::ROHF_ref => true,
};

Self {
hf_matrs: HFMatrices::new(basis.no_bf(), calc_sett.use_direct_scf, create_beta_vars),
E_scf_prev: 0.0,
E_scf_curr: 0.0,
E_tot_prev: 0.0,
E_tot_curr: 0.0,
}
}

// pub fn run_scf(
// &mut self,
// calc_sett: &CalcSettings,
// exec_times: &mut crate::print_utils::ExecTimes,
// basis: &BasisSet,
// mol: &Molecule,
// ) -> SCF {
// print_scf_header_and_settings(calc_sett, HF_Ref::UHF_ref);
//
// let mut is_scf_conv = false;
// let mut scf = SCF::default();
// let mut diis: Option<DIIS> = if calc_sett.use_diis {
// Some(DIIS::new(
// &calc_sett.diis_sett,
// [basis.no_bf(), basis.no_bf()],
// ))
// } else {
// None
// };
//
// let V_nuc: f64 = if mol.no_atoms() > 100 {
// mol.calc_core_potential_par()
// } else {
// mol.calc_core_potential_ser()
// };
//
// // Calculate 1e ints
// exec_times.start("1e ints");
// self.calc_1e_int_matrs_inp(basis, mol);
// exec_times.stop("1e ints");
//
// // Calculate 2e ints / Schwarz estimates
// exec_times.start("2e ints / Schwarz esti.");
// self.dir_indir_scf_2e_matr(basis, calc_sett);
// exec_times.stop("2e ints / Schwarz esti.");
//
// // Initial guess -> H_core
// // TODO: [ ] replace with guess
// self.hf_matrs.F_matr_alpha = self.hf_matrs.H_core_matr.clone();
//
// // Print SCF iteration Header
// println!(
// "{:>3} {:^20} {:^20} {:^20} {:^20}",
// "Iter", "E_scf", "E_tot", "ΔE", "RMS(|FPS - SPF|)"
// );
// let mut diis_str = "";
// for scf_iter in 0..=calc_sett.max_scf_iter {
// if scf_iter == 0 {
// self.hf_matrs.F_matr_pr_alpha = self
// .hf_matrs
// .S_matr_inv_sqrt
// .dot(&self.hf_matrs.F_matr_alpha)
// .dot(&self.hf_matrs.S_matr_inv_sqrt);
//
// (self.hf_matrs.orb_ener_alpha, self.hf_matrs.C_matr_MO_alpha) =
// self.hf_matrs.F_matr_pr_alpha.eigh(UPLO::Upper).unwrap();
// self.hf_matrs.C_matr_AO_alpha = self
// .hf_matrs
// .S_matr_inv_sqrt
// .dot(&self.hf_matrs.C_matr_MO_alpha);
//
// Self::calc_P_matr_rhf(
// &mut self.hf_matrs.P_matr_alpha,
// &self.hf_matrs.C_matr_AO_alpha,
// basis.no_occ(),
// );
// if calc_sett.use_direct_scf {
// self.hf_matrs.delta_P_matr_alpha = Some(self.hf_matrs.P_matr_alpha.clone());
// }
// } else {
// /// direct or indirect scf
// match self.hf_matrs.eri_opt {
// Some(ref eri) => {
// Self::calc_new_F_matr_ind_scf_rhf(
// &mut self.hf_matrs.F_matr_alpha,
// &self.hf_matrs.H_core_matr,
// &self.hf_matrs.P_matr_alpha,
// eri,
// );
// }
// None => {
// Self::calc_new_F_matr_dir_scf_rhf(
// &mut self.hf_matrs.F_matr_alpha,
// self.hf_matrs.delta_P_matr_alpha.as_ref().unwrap(),
// self.hf_matrs.schwarz_est.as_ref().unwrap(),
// basis,
// );
// }
// }
// self.E_scf_curr = Self::calc_E_scf_rhf(
// &self.hf_matrs.P_matr_alpha,
// &self.hf_matrs.H_core_matr,
// &self.hf_matrs.F_matr_alpha,
// );
// self.E_tot_curr = self.E_scf_curr + V_nuc;
// // FPS - SPF
// let fps_comm = DIIS::calc_FPS_comm(
// &self.hf_matrs.F_matr_alpha,
// &self.hf_matrs.P_matr_alpha,
// &self.hf_matrs.S_matr,
// );
//
// // F' = S^(-1/2) * F * S^(-1/2)
// self.hf_matrs.F_matr_pr_alpha = self
// .hf_matrs
// .S_matr_inv_sqrt
// .dot(&self.hf_matrs.F_matr_alpha)
// .dot(&self.hf_matrs.S_matr_inv_sqrt);
//
// if calc_sett.use_diis {
// let repl_idx = (scf_iter - 1) % calc_sett.diis_sett.diis_max; // always start with 0
// let err_matr = self
// .hf_matrs
// .S_matr_inv_sqrt
// .dot(&fps_comm)
// .dot(&self.hf_matrs.S_matr_inv_sqrt);
// diis.as_mut().unwrap().push_to_ring_buf(
// &self.hf_matrs.F_matr_pr_alpha,
// &err_matr,
// repl_idx,
// );
//
// if scf_iter >= calc_sett.diis_sett.diis_min {
// let err_set_len = std::cmp::min(calc_sett.diis_sett.diis_max, scf_iter);
// self.hf_matrs.F_matr_pr_alpha =
// diis.as_ref().unwrap().run_DIIS(err_set_len);
// diis_str = "DIIS";
// }
// }
//
// (self.hf_matrs.orb_ener_alpha, self.hf_matrs.C_matr_MO_alpha) =
// self.hf_matrs.F_matr_pr_alpha.eigh(UPLO::Upper).unwrap();
// self.hf_matrs.C_matr_AO_alpha = self
// .hf_matrs
// .S_matr_inv_sqrt
// .dot(&self.hf_matrs.C_matr_MO_alpha);
//
// let delta_E = self.E_scf_curr - self.E_scf_prev;
// let rms_comm_val = (fps_comm.par_iter().map(|x| x * x).sum::<f64>()
// / fps_comm.len() as f64)
// .sqrt();
// println!(
// "{:>3} {:>20.12} {:>20.12} {} {} {:>10} ",
// scf_iter,
// self.E_scf_curr,
// self.E_tot_curr,
// fmt_f64(delta_E, 20, 8, 2),
// fmt_f64(rms_comm_val, 20, 8, 2),
// diis_str
// );
// diis_str = "";
//
// if (delta_E.abs() < calc_sett.e_diff_thrsh)
// && (rms_comm_val < calc_sett.commu_conv_thrsh)
// {
// scf.tot_scf_iter = scf_iter;
// scf.E_scf_conv = self.E_scf_curr;
// scf.E_tot_conv = self.E_tot_curr;
// scf.C_matr_conv_alph = self.hf_matrs.C_matr_AO_alpha.clone();
// scf.P_matr_conv_alph = self.hf_matrs.P_matr_alpha.clone();
// println!("P_matr conv:\n{:>12.8}", &self.hf_matrs.P_matr_alpha);
// scf.C_matr_conv_beta = None;
// scf.P_matr_conv_beta = None;
// scf.orb_E_conv_alph = self.hf_matrs.orb_ener_alpha.clone();
// println!("\nSCF CONVERGED!\n");
// is_scf_conv = true;
// break;
// } else if scf_iter == calc_sett.max_scf_iter {
// println!("\nSCF DID NOT CONVERGE!\n");
// break;
// }
// self.E_scf_prev = self.E_scf_curr;
// self.hf_matrs.P_matr_prev_alpha = self.hf_matrs.P_matr_alpha.clone();
// Self::calc_P_matr_rhf(
// &mut self.hf_matrs.P_matr_alpha,
// &self.hf_matrs.C_matr_AO_alpha,
// basis.no_occ(),
// );
// if calc_sett.use_direct_scf {
// Zip::from(&self.hf_matrs.P_matr_alpha.view())
// .and(&self.hf_matrs.P_matr_prev_alpha.view())
// .par_map_assign_into(
// self.hf_matrs.delta_P_matr_alpha.as_mut().unwrap(),
// |&P, &P_prev| P - P_prev,
// );
// }
// }
// }
//
// if is_scf_conv {
// println!("{:*<55}", "");
// println!("* {:^51} *", "FINAL RESULTS");
// println!("{:*<55}", "");
// println!(" {:^50}", "RHF SCF (in a.u.)");
// println!(" {:=^50} ", "");
// println!(" {:<25}{:>25}", "Total SCF iterations:", scf.tot_scf_iter);
// println!(" {:<25}{:>25.18}", "Final SCF energy:", scf.E_scf_conv);
// println!(" {:<25}{:>25.18}", "Final tot. energy:", scf.E_tot_conv);
// println!("{:*<55}", "");
// }
// scf
// }
}

#[allow(unused)]
Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ fn main() {
exec_times.stop("Molecule");

exec_times.start("BasisSet");
// let basis = BasisSet::new("STO-3G", &mol);
let basis = BasisSet::new("6-311++G**", &mol);
let basis = BasisSet::new("STO-3G", &mol);
// let basis = BasisSet::new("6-311++G**", &mol);
exec_times.stop("BasisSet");

//##################################
Expand Down

0 comments on commit 1838331

Please sign in to comment.