Skip to content

Commit

Permalink
Merge pull request #486 from jiang-yuha0/mpi_bands_plot
Browse files Browse the repository at this point in the history
Add mpi to `bands_plot` fixes #452
  • Loading branch information
JeromeCCP9 authored Jun 24, 2024
2 parents 6d433b6 + 7a0a2aa commit e3ee42e
Show file tree
Hide file tree
Showing 11 changed files with 1,506 additions and 83 deletions.
213 changes: 130 additions & 83 deletions src/plot.F90
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module w90_plot_mod

!! This module handles various plots

use w90_comms, only: comms_reduce, w90_comm_type, mpisize, mpirank
use w90_comms, only: comms_reduce, w90_comm_type, mpisize, mpirank, comms_array_split

implicit none

Expand Down Expand Up @@ -49,7 +49,6 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm
use w90_constants, only: eps6, dp
use w90_hamiltonian, only: hamiltonian_get_hr, hamiltonian_write_hr, hamiltonian_setup, &
hamiltonian_write_rmn, hamiltonian_write_tb
use w90_hamiltonian, only: hamiltonian_setup, hamiltonian_get_hr
use w90_io, only: io_stopwatch_start, io_stopwatch_stop
use w90_types, only: kmesh_info_type, wannier_data_type, atom_data_type, dis_manifold_type, &
kpoint_path_type, print_output_type, ws_region_type, ws_distance_type, timer_list_type, &
Expand Down Expand Up @@ -137,10 +136,41 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm
if (allocated(error)) return
endif

call utility_recip_lattice_base(real_lattice, recip_lattice, volume)

if (w90_calculation%bands_plot .or. w90_calculation%fermi_surface_plot .or. &
output_file%write_hr .or. output_file%write_tb) then
! Check if the kmesh includes the gamma point
have_gamma = .false.
do nkp = 1, num_kpts
if (all(abs(kpt_latt(:, nkp)) < eps6)) have_gamma = .true.
end do
if (.not. have_gamma) &
write (stdout, '(1x,a)') '!!!! Kpoint grid does not include Gamma. '// &
& ' Interpolation may be incorrect. !!!!'
! Transform Hamiltonian to WF basis

call hamiltonian_setup(ham_logical, print_output, ws_region, w90_calculation, ham_k, ham_r, &
real_lattice, wannier_centres_translated, irvec, mp_grid, ndegen, &
num_kpts, num_wann, nrpts, rpt_origin, band_plot%mode, stdout, &
timer, error, transport_mode, comm)
if (allocated(error)) return

call hamiltonian_get_hr(atom_data, dis_manifold, ham_logical, real_space_ham, print_output, &
ham_k, ham_r, u_matrix, u_matrix_opt, eigval, kpt_latt, &
real_lattice, wannier_data%centres, wannier_centres_translated, &
irvec, shift_vec, nrpts, num_bands, num_kpts, num_wann, &
have_disentangled, stdout, timer, error, lsitesymmetry, comm)
if (allocated(error)) return

bands_num_spec_points = 0

if (allocated(kpoint_path%labels)) bands_num_spec_points = size(kpoint_path%labels)
endif

if (on_root) then
if (print_output%timing_level > 0) call io_stopwatch_start('plot: main', timer)

call utility_recip_lattice_base(real_lattice, recip_lattice, volume)
! Print the header only if there is something to plot
if (w90_calculation%bands_plot .or. w90_calculation%fermi_surface_plot .or. &
output_file%write_hr .or. w90_calculation%wannier_plot .or. output_file%write_u_matrices &
Expand All @@ -153,42 +183,16 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm

if (w90_calculation%bands_plot .or. w90_calculation%fermi_surface_plot .or. &
output_file%write_hr .or. output_file%write_tb) then
! Check if the kmesh includes the gamma point
have_gamma = .false.
do nkp = 1, num_kpts
if (all(abs(kpt_latt(:, nkp)) < eps6)) have_gamma = .true.
end do
if (.not. have_gamma) &
write (stdout, '(1x,a)') '!!!! Kpoint grid does not include Gamma. '// &
& ' Interpolation may be incorrect. !!!!'
! Transform Hamiltonian to WF basis

call hamiltonian_setup(ham_logical, print_output, ws_region, w90_calculation, ham_k, ham_r, &
real_lattice, wannier_centres_translated, irvec, mp_grid, ndegen, &
num_kpts, num_wann, nrpts, rpt_origin, band_plot%mode, stdout, &
timer, error, transport_mode, comm)
if (allocated(error)) return

call hamiltonian_get_hr(atom_data, dis_manifold, ham_logical, real_space_ham, print_output, &
ham_k, ham_r, u_matrix, u_matrix_opt, eigval, kpt_latt, &
real_lattice, wannier_data%centres, wannier_centres_translated, &
irvec, shift_vec, nrpts, num_bands, num_kpts, num_wann, &
have_disentangled, stdout, timer, error, lsitesymmetry, comm)
if (allocated(error)) return

bands_num_spec_points = 0

if (allocated(kpoint_path%labels)) bands_num_spec_points = size(kpoint_path%labels)

if (w90_calculation%bands_plot) then
call plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path, &
real_space_ham, ws_region, print_output, recip_lattice, &
num_wann, wannier_data, ham_r, irvec, ndegen, nrpts, &
wannier_centres_translated, ws_distance, &
bands_num_spec_points, stdout, seedname, timer, error, &
comm)
if (allocated(error)) return
endif
! if (w90_calculation%bands_plot) then
! call plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path, &
! real_space_ham, ws_region, print_output, recip_lattice, &
! num_wann, wannier_data, ham_r, irvec, ndegen, nrpts, &
! wannier_centres_translated, ws_distance, &
! bands_num_spec_points, stdout, seedname, timer, error, &
! comm)
! if (allocated(error)) return
! endif

if (w90_calculation%fermi_surface_plot) then
call plot_fermi_surface(fermi_energy_list, recip_lattice, fermi_surface_plot, num_wann, &
Expand Down Expand Up @@ -282,6 +286,16 @@ subroutine plot_main(atom_data, band_plot, dis_manifold, fermi_energy_list, ferm
endif
end if !on_root

if (w90_calculation%bands_plot) then
call plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path, &
real_space_ham, ws_region, print_output, recip_lattice, &
num_wann, wannier_data, ham_r, irvec, ndegen, nrpts, &
wannier_centres_translated, ws_distance, &
bands_num_spec_points, stdout, seedname, timer, error, &
comm)
if (allocated(error)) return
endif

if (w90_calculation%wannier_plot) then
call plot_wannier(wannier_plot, wvfn_read, wannier_data, print_output, u_matrix_opt, &
dis_manifold, real_lattice, atom_data, kpt_latt, u_matrix, num_kpts, &
Expand Down Expand Up @@ -396,16 +410,33 @@ subroutine plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path,
character(len=10), allocatable :: xlabel(:)
character(len=10), allocatable :: ctemp(:)

! mpi variables
integer :: my_node_id, num_nodes, size_rdist, size_ndeg
logical :: on_root
integer, allocatable :: counts(:)
integer, allocatable :: displs(:)

num_nodes = mpisize(comm)
my_node_id = mpirank(comm)

on_root = .false.
if (my_node_id == 0) on_root = .true.

allocate (counts(0:num_nodes - 1))
allocate (displs(0:num_nodes - 1))
!
if (print_output%timing_level > 1) then
call io_stopwatch_start('plot: interpolate_bands', timer)
endif
!
time0 = io_time()
if (on_root) then
if (print_output%timing_level > 1) then
call io_stopwatch_start('plot: interpolate_bands', timer)
endif
!
time0 = io_time()

write (stdout, *)
write (stdout, '(1x,a)') 'Calculating interpolated band-structure'
write (stdout, *)
endif ! on_root
call utility_metric(recip_lattice, recip_metric)
write (stdout, *)
write (stdout, '(1x,a)') 'Calculating interpolated band-structure'
write (stdout, *)
!
allocate (ham_pack((num_wann*(num_wann + 1))/2), stat=ierr)
if (ierr /= 0) then
Expand Down Expand Up @@ -582,26 +613,28 @@ subroutine plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path,
!
! Write out the kpoints in the path
!
open (newunit=bndunit, file=trim(seedname)//'_band.kpt', form='formatted')
write (bndunit, *) total_pts
do loop_spts = 1, total_pts
write (bndunit, '(3f12.6,3x,a)') (plot_kpoint(loop_i, loop_spts), loop_i=1, 3), "1.0"
end do
close (bndunit)
!
! Write out information on high-symmetry points in the path
!
open (newunit=bndunit, file=trim(seedname)//'_band.labelinfo.dat', form='formatted')
do loop_spts = 1, bands_num_spec_points
if ((MOD(loop_spts, 2) .eq. 1) .and. &
(kpath_print_first_point((loop_spts + 1)/2) .eqv. .false.)) cycle
write (bndunit, '(a,3x,I10,3x,4f18.10)') &
kpoint_path%labels(loop_spts), &
idx_special_points(loop_spts), &
xval_special_points(loop_spts), &
(plot_kpoint(loop_i, idx_special_points(loop_spts)), loop_i=1, 3)
end do
close (bndunit)
if (on_root) then
open (newunit=bndunit, file=trim(seedname)//'_band.kpt', form='formatted')
write (bndunit, *) total_pts
do loop_spts = 1, total_pts
write (bndunit, '(3f12.6,3x,a)') (plot_kpoint(loop_i, loop_spts), loop_i=1, 3), "1.0"
end do
close (bndunit)
!
! Write out information on high-symmetry points in the path
!
open (newunit=bndunit, file=trim(seedname)//'_band.labelinfo.dat', form='formatted')
do loop_spts = 1, bands_num_spec_points
if ((MOD(loop_spts, 2) .eq. 1) .and. &
(kpath_print_first_point((loop_spts + 1)/2) .eqv. .false.)) cycle
write (bndunit, '(a,3x,I10,3x,4f18.10)') &
kpoint_path%labels(loop_spts), &
idx_special_points(loop_spts), &
xval_special_points(loop_spts), &
(plot_kpoint(loop_i, idx_special_points(loop_spts)), loop_i=1, 3)
end do
close (bndunit)
endif ! on_root
!
! Cut H matrix in real-space
!
Expand Down Expand Up @@ -631,15 +664,20 @@ subroutine plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path,
return
endif
endif

if (on_root .and. print_output%timing_level > 2) then
call io_stopwatch_start('plot: interpolate_bands: loop_kpoints', timer)
endif
! [lp] the s-k and cut codes are very similar when use_ws_distance is used, a complete
! merge after this point is not impossible
do loop_kpt = 1, total_pts
call comms_array_split(total_pts, counts, displs, comm)
! Don't worry about serial run, it is OK to call comms_array_split!
do loop_kpt = displs(my_node_id) + 1, displs(my_node_id) + counts(my_node_id)
! do loop_kpt = 1, total_pts
ham_kprm = cmplx_0
!
if (index(band_plot%mode, 's-k') .ne. 0) then
do irpt = 1, nrpts
! [lp] Shift the WF to have the minimum distance IJ, see also ws_distance.F90
! [lp] Shift the WF to have the minimum distance IJ, see also ws_distance.F90
if (ws_region%use_ws_distance) then
do j = 1, num_wann
do i = 1, num_wann
Expand All @@ -653,7 +691,7 @@ subroutine plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path,
enddo
enddo
else
! [lp] Original code, without IJ-dependent shift:
! [lp] Original code, without IJ-dependent shift:
rdotk = twopi*dot_product(plot_kpoint(:, loop_kpt), irvec(:, irpt))
fac = cmplx(cos(rdotk), sin(rdotk), dp)/real(ndegen(irpt), dp)
ham_kprm = ham_kprm + fac*ham_r(:, :, irpt)
Expand All @@ -662,7 +700,7 @@ subroutine plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path,
! end of s-k mode
elseif (index(band_plot%mode, 'cut') .ne. 0) then
do irpt = 1, nrpts_cut
! [lp] Shift the WF to have the minimum distance IJ, see also ws_distance.F90
! [lp] Shift the WF to have the minimum distance IJ, see also ws_distance.F90
if (ws_region%use_ws_distance) then
do j = 1, num_wann
do i = 1, num_wann
Expand All @@ -674,10 +712,10 @@ subroutine plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path,
enddo
enddo
enddo
! [lp] Original code, without IJ-dependent shift:
! [lp] Original code, without IJ-dependent shift:
else
rdotk = twopi*dot_product(plot_kpoint(:, loop_kpt), irvec_cut(:, irpt))
!~[aam] check divide by ndegen?
!~[aam] check divide by ndegen?
fac = cmplx(cos(rdotk), sin(rdotk), dp)
ham_kprm = ham_kprm + fac*ham_r_cut(:, :, irpt)
endif ! end of use_ws_distance
Expand Down Expand Up @@ -715,22 +753,31 @@ subroutine plot_interpolate_bands(mp_grid, real_lattice, band_plot, kpoint_path,
end if
!
end do

call comms_reduce(eig_int(1, 1), num_wann*total_pts, 'SUM', error, comm)
call comms_reduce(bands_proj(1, 1), num_wann*total_pts, 'SUM', error, comm)

if (on_root .and. print_output%timing_level > 2) then
call io_stopwatch_stop('plot: interpolate_bands: loop_kpoints', timer)
endif
!
! Interpolation Finished!
! Now we write plotting files
!
emin = minval(eig_int) - 1.0_dp
emax = maxval(eig_int) + 1.0_dp
if (on_root) then
emin = minval(eig_int) - 1.0_dp
emax = maxval(eig_int) + 1.0_dp

if (index(band_plot%format, 'gnu') > 0) then
call plot_interpolate_gnuplot(band_plot, kpoint_path, bands_num_spec_points, num_wann)
endif
if (index(band_plot%format, 'xmgr') > 0) then
call plot_interpolate_xmgrace(kpoint_path, bands_num_spec_points, num_wann)
if (index(band_plot%format, 'gnu') > 0) then
call plot_interpolate_gnuplot(band_plot, kpoint_path, bands_num_spec_points, num_wann)
endif
if (index(band_plot%format, 'xmgr') > 0) then
call plot_interpolate_xmgrace(kpoint_path, bands_num_spec_points, num_wann)
endif
write (stdout, '(1x,a,f11.3,a)') &
'Time to calculate interpolated band structure ', io_time() - time0, ' (sec)'
write (stdout, *)
endif
write (stdout, '(1x,a,f11.3,a)') &
'Time to calculate interpolated band structure ', io_time() - time0, ' (sec)'
write (stdout, *)

if (allocated(ham_r_cut)) then
deallocate (ham_r_cut, stat=ierr)
Expand Down
6 changes: 6 additions & 0 deletions test-suite/tests/jobconfig
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ program = WANNIER90_LABELINFO
inputs_args = ('silicon.win', '')
output = silicon_band.labelinfo.dat

# Silicon, 4 valence bands + 4 conduction bands; interpolated bandstructure only, checking the band.dat file
[testw90_example03_bands_plot]
program = WANNIER90_BANDS_PLOT
inputs_args = ('silicon.win', '')
output = silicon_band.dat

# Copper, states around the Fermi level; Fermi surface
[testw90_example04]
program = WANNIER90_WOUT_OK
Expand Down
5 changes: 5 additions & 0 deletions test-suite/tests/testw90_example03_bands_plot/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
silicon_band.dat
silicon_band.gnu
silicon_band.kpt
silicon_band.labelinfo.dat
silicon.wout
14 changes: 14 additions & 0 deletions test-suite/tests/testw90_example03_bands_plot/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
all: $(patsubst %.chk.fmt.bz2, %.chk, $(wildcard *.chk.fmt.bz2)) $(patsubst %.mmn.bz2, %.mmn, $(wildcard *.mmn.bz2))

%.chk: %.chk.fmt.bz2
$(eval SEEDNAME:=$(patsubst %.chk.fmt.bz2, %, $<))
echo $(SEEDNAME)
cat $< | bunzip2 - > $(SEEDNAME).chk.fmt && ../../../w90chk2chk.x -f2u $(SEEDNAME) && rm $(SEEDNAME).chk.fmt

%.mmn: %.mmn.bz2
$(eval SEEDNAME:=$(patsubst %.mmn.bz2, %, $<))
echo $(SEEDNAME)
cat $< | bunzip2 - > $(SEEDNAME).mmn


.PHONY: all
Loading

0 comments on commit e3ee42e

Please sign in to comment.