Make functions consistently return Pytorch tensors and require tensors as input #365
facebook-github-bot pushed a commit that referenced this issue on Oct 15, 2024:
…mark (#399)

Summary:

### PR Description

This PR addresses the first step in making AEPsych's functions consistently return PyTorch tensors and expect tensors as input, improving compatibility with GPUs and reducing redundant conversions between NumPy arrays and PyTorch tensors (partially solving #365).

#### Key changes include:

1. **Conversion of `np.arrays` to tensors** in the following files:
   - **`aepsych/models/base.py`**:
     - Refactored the `p_below_threshold` method to operate fully with PyTorch tensors.
     - Replaced `norm.cdf()` with `torch.distributions.Normal(0, 1).cdf()` for better GPU compatibility.
   - **`aepsych/benchmark/problem.py`**:
     - Significant changes made to ensure consistent use of tensors across the pipeline.
     - The result of `f_threshold()` now directly returns a PyTorch tensor, ensuring consistency.
     - Additionally, used `detach().cpu().numpy()` in places where the `super().evaluate()` method returns float values, ensuring compatibility.
2. **Updates in `aepsych/tests/test_benchmark.py`**:
   - Migrated all operations from NumPy to PyTorch.
   - This includes calculations for Brier score and misclassification error, now utilizing `torch.mean()`, `torch.square()`, `torch.isclose()`, and `torch.all()` to fully align with tensor operations.

#### Stability:

All test cases have passed successfully in the workflow.

Pull Request resolved: #399
Reviewed By: crasanders
Differential Revision: D64245698
Pulled By: JasonKChow
fbshipit-source-id: 3ed3d7b627f488ec61da5b9013a46cafc8b83556
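To illustrate the conversion pattern this PR describes (replacing `scipy.stats.norm.cdf()` with `torch.distributions.Normal(0, 1).cdf()` so the whole computation stays in torch), here is a minimal sketch. The function name and arguments are illustrative assumptions, not AEPsych's actual `p_below_threshold` signature:

```python
import torch

def p_below_threshold_sketch(f_mean: torch.Tensor, f_var: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    # Hypothetical inputs: posterior mean/variance and a threshold, all tensors.
    # The z-score and the standard-normal CDF are computed entirely in torch,
    # so no intermediate numpy conversion is needed.
    z = (threshold - f_mean) / f_var.sqrt()
    return torch.distributions.Normal(0.0, 1.0).cdf(z)
```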
facebook-github-bot pushed a commit that referenced this issue on Oct 17, 2024:
…rmalization (#403)

Summary:

This PR addresses the second part of issue #365, focusing on the `Strategy` class and how data is added and normalized, transitioning the process to use tensors instead of NumPy operations.

The changes were made specifically within the `normalize_inputs` method of the `Strategy` class. Previously, this method had mismatched docstrings indicating `np.array` usage. Now, it consistently accepts and returns tensors, performing all operations within tensors.

The `normalize_inputs` method is called in `add_data()` (where the confusion arises), as the data passed can vary (either tensors or `np.array`). To resolve this, the method now acts as the first step, accepting both formats and then converting everything to tensors for consistent operations (model fitting later on). It's also crucial to ensure the data type is `float64`, as `gpytorch` does not support other data types.

Additionally, a detailed docstring was added to clarify the method's expectations and ensure its proper use going forward.

Pull Request resolved: #403
Reviewed By: crasanders
Differential Revision: D64343236
Pulled By: JasonKChow
fbshipit-source-id: 413077605f4fa46b82405897c713cbc62b58a3f3
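A minimal sketch of the "convert at the door" pattern described above: accept either NumPy arrays or tensors, cast to `float64` for gpytorch, and keep everything downstream in torch. The signature, the `bounds` layout, and the unit-range scaling are assumptions for illustration, not the actual `Strategy.normalize_inputs` implementation:

```python
import torch

def normalize_inputs_sketch(x, bounds: torch.Tensor) -> torch.Tensor:
    """Accept np.ndarray or torch.Tensor and return a float64 tensor scaled by `bounds`.

    Assumes `bounds` is a 2 x d tensor of stacked [lower, upper] box bounds.
    """
    if not isinstance(x, torch.Tensor):
        x = torch.tensor(x)          # np.ndarray (or list) -> tensor
    x = x.to(torch.float64)          # gpytorch expects double precision
    lb, ub = bounds[0], bounds[1]
    return (x - lb) / (ub - lb)      # all subsequent ops stay in torch
```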
facebook-github-bot pushed a commit that referenced this issue on Oct 18, 2024:
…Torch Tensors (#406)

Summary:

This PR partially solves #365, with changes as follows:

- **Grid Generation and Meshgrid Handling (`dim_grid` and `get_lse_interval`)**: Transitioned from `np.mgrid` to `torch.meshgrid` and `torch.linspace`, simplifying setup and ensuring full compatibility with PyTorch, reducing conversion steps.
- **Interpolation (`interpolate_monotonic`)**: Switched from `np.searchsorted` to `torch.searchsorted` and used `torch.where` for interpolation, enabling efficient, single-pass processing and maintaining overall consistency.
- **Probability and Quantile Calculations (`get_lse_interval`)**: Updated to use `torch.distributions.Normal`, `torch.median`, and `torch.quantile`.
- **Generalized Vectorization (`get_jnd_1d` and `get_jnd_multid`)**: Functions are now fully vectorized using PyTorch's capabilities, avoiding element-wise iteration.

Pull Request resolved: #406
Reviewed By: crasanders
Differential Revision: D64563850
Pulled By: JasonKChow
fbshipit-source-id: 867b86b6822eb3380a2f1e0535849b3ea44a5a05
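The grid and interpolation changes follow the same mechanical pattern. A rough sketch, assuming box bounds and a sorted 1-D interpolation grid; the names and signatures below are illustrative, not the actual AEPsych utilities:

```python
import torch

def dim_grid_sketch(lower: torch.Tensor, upper: torch.Tensor, gridsize: int = 30) -> torch.Tensor:
    # torch.linspace + torch.meshgrid in place of np.mgrid: one axis per dimension,
    # stacked into a (gridsize**d, d) tensor of evaluation points.
    axes = [torch.linspace(float(lb), float(ub), gridsize) for lb, ub in zip(lower, upper)]
    mesh = torch.meshgrid(*axes, indexing="ij")
    return torch.stack([m.reshape(-1) for m in mesh], dim=-1)

def interpolate_monotonic_sketch(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Piecewise-linear interpolation of y(x) at query points z via torch.searchsorted,
    # assuming x is sorted ascending; no element-wise Python loop.
    idx = torch.searchsorted(x, z).clamp(1, len(x) - 1)
    x0, x1 = x[idx - 1], x[idx]
    y0, y1 = y[idx - 1], y[idx]
    return y0 + (z - x0) * (y1 - y0) / (x1 - x0)
```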
facebook-github-bot pushed a commit that referenced this issue on Oct 22, 2024:
…nd update related tests (#411)

Summary:

Partially solving #365:

- Converted all functions in `test_functions.py` to use PyTorch tensors for both inputs and outputs, ensuring compatibility with PyTorch workflows.
- Retained NumPy for specific cases like interpolation in functions such as `make_songetal_threshfun` and `make_songetal_testfun`, as there is no direct PyTorch implementation for `CubicSpline` and `interp1d`, so SciPy is still used with NumPy arrays internally.
- Added type hints to improve overall code robustness and support static linting.
- Made minor changes to `test_pairwise_probit.py` and `test_mi.py` to adapt to the new tensor structures in `test_functions.py` and ensure smooth integration.

Pull Request resolved: #411
Reviewed By: crasanders
Differential Revision: D64721427
Pulled By: JasonKChow
fbshipit-source-id: b3b6a78af7686397f4132df139c5290a478fee6c
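Where SciPy has to stay (e.g. `CubicSpline`), the conversion can be confined to the function boundary so callers only ever see tensors. A hedged sketch of that wrapper pattern; `make_tensor_threshfun_sketch` is a hypothetical name, not the actual `make_songetal_threshfun`:

```python
import torch
from scipy.interpolate import CubicSpline

def make_tensor_threshfun_sketch(x_knots: torch.Tensor, y_knots: torch.Tensor):
    # SciPy has no torch equivalent for CubicSpline, so the spline is fit on NumPy
    # arrays; the returned closure keeps the numpy<->tensor conversion at the
    # boundary so the rest of the pipeline works purely with tensors.
    spline = CubicSpline(x_knots.detach().cpu().numpy(), y_knots.detach().cpu().numpy())

    def threshfun(x: torch.Tensor) -> torch.Tensor:
        y = spline(x.detach().cpu().numpy())
        return torch.as_tensor(y, dtype=x.dtype)

    return threshfun
```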
AEPsych is primarily built on top of GPytorch and Botorch, whose methods all depend on Pytorch tensors. However, AEPsych also uses some routines from scipy, which depend on numpy arrays. This has led to inconsistent return types throughout the code base and a lot of conversion between arrays and tensors. To make things consistent, all public methods and functions should return tensors and expect tensors as input. This would also allow us to enable GPU compatibility; we should use tensor.cpu().numpy() whenever we need to pass a tensor into a numpy function, and only convert back and forth between arrays where absolutely necessary. Note that there may be certain functions that have to accept and return arrays (possibly the ask/tell message handlers?), but these exceptions should be tested and verified.
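A minimal example of the conversion discipline proposed here, assuming a SciPy routine with no torch equivalent; the wrapper name is illustrative:

```python
import torch
from scipy.stats import norm

def scipy_cdf_from_tensor(x: torch.Tensor) -> torch.Tensor:
    # Cross the torch/numpy boundary only at the call site:
    # detach().cpu().numpy() on the way in, torch.as_tensor on the way out,
    # restoring the caller's device and dtype.
    result = norm.cdf(x.detach().cpu().numpy())
    return torch.as_tensor(result, dtype=x.dtype, device=x.device)
```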