diff --git a/src/roodmus/simulation/run_parakeet.py b/src/roodmus/simulation/run_parakeet.py index f04f820..75df5bf 100644 --- a/src/roodmus/simulation/run_parakeet.py +++ b/src/roodmus/simulation/run_parakeet.py @@ -707,11 +707,17 @@ def add_arguments( ) options_sample_motion = options_sample.add_argument_group("motion") options_sample_motion.add_argument( - "--global_drift", - help="global drift vector. Must specify x- and y-components", + "--global_drift_magnitude", + help="magnitude of global drift in Angstrom", + type=float, + default=0, + ) + options_sample_motion.add_argument( + "--global_drift_std", + help="std of the direction of global" + + " drift for each micrograph in radians", + default=0, type=float, - default=[0, 0], - nargs=2, ) options_sample_motion.add_argument( "--interaction_range", @@ -1011,6 +1017,23 @@ def sample_defocus(c_10: float, c_10_stddev: float) -> float: return np.random.normal(c_10, c_10_stddev) +def sample_drift( + global_drift_magnitude: float, + global_drift_std: float, + global_drift_direction: float, +): + """From the base direction sample a new vector with magnitude equal + to global_drift_magnitude and direction given by global_drift_direction + +- a random value with std equal to global_drift_std + """ + + angle = np.random.normal(global_drift_direction, global_drift_std) + global_drift_vec = global_drift_magnitude * np.array( + [np.cos(angle), np.sin(angle)] + ) + return global_drift_vec + + def get_pdb_files(pdb_dir: str) -> List[str]: """Grab a list of molecule/structure definition files (such as PDBs) to add to micrographs @@ -1232,6 +1255,9 @@ def main(args): ) frames = get_pdb_files(args.pdb_dir) + # sample the global drift base direction + global_drift_direction = np.random.uniform(0, 2 * np.pi) + # loop over the number of images to generate config files progressbar = tqdm( range(args.n_images), @@ -1247,6 +1273,13 @@ def main(args): args.mrc_dir, f"{n_image}".zfill(args.leading_zeros) + ".yaml" ) + # determine the glbal drift vector + args.global_drift = sample_drift( + args.global_drift_magnitude, + args.global_drift_std, + global_drift_direction, + ) + # initialise the configuration # the .h5 files will be saved to _sample.h5 in the mrc_dir config = Configuration(