diff --git a/caput/mpiutil.py b/caput/mpiutil.py index e7ab679b..3f7cd0dd 100644 --- a/caput/mpiutil.py +++ b/caput/mpiutil.py @@ -56,6 +56,36 @@ warnings.warn("Warning: mpi4py not installed.", ImportWarning) +def cpu_count(comm: MPI.Intracomm = world, scomm: MPI.Intracomm = None): + """Get the number of CPUs available to each process. + + Parameters + ---------- + comm + MPI communicator + scomm + MPI shared memory communicator + + Returns + ------- + cpu_count + Number of cpus available to each process + """ + if scomm is None: + if comm is world: + scomm = world_scomm + else: + scomm = comm.Split_type(MPI.COMM_TYPE_SHARED) + + try: + nproc_per_node = comm.size // scomm.size + except AttributeError: + # This would happend if the default comm is None + nproc_per_node = 1 + + return int(os.cpu_count() // nproc_per_node) + + def enable_mpi_exception_handler(): """Install an MPI aware exception handler.