-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from andfoy/add_global_registries
Define per-process dtype and device registries
- Loading branch information
Showing
27 changed files
with
1,133 additions
and
175 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
defmodule ExTorch.Application do | ||
@moduledoc false | ||
use Application | ||
|
||
def start(_type, _args) do | ||
children = [ | ||
{Registry, [name: ExTorch.Registry.DType, keys: :duplicate]}, | ||
{Registry, [name: ExTorch.Registry.Device, keys: :duplicate]}, | ||
] | ||
Supervisor.start_link(children, strategy: :one_for_one) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
defprotocol ExTorch.Protocol.DefaultStruct do | ||
@doc """ | ||
Given a struct filled with default values, return a struct with valid values. | ||
""" | ||
def replace_defaults(struct_defaults) | ||
end | ||
|
||
defimpl ExTorch.Protocol.DefaultStruct, for: Any do | ||
def replace_defaults(in_struct) do | ||
in_struct | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
defmodule ExTorch.Registry.Device do | ||
@moduledoc false | ||
require ExTorch.Device | ||
|
||
@after_compile __MODULE__ | ||
def __after_compile__(_env, bytecode) do | ||
__MODULE__ | ||
|> :code.which() | ||
|> to_string() | ||
|> File.write(bytecode) | ||
end | ||
|
||
@doc """ | ||
Get the default device on which `ExTorch.Tensor` structs are allocated. | ||
## Notes | ||
By default, `ExTorch` will set `:cpu` as the default device. | ||
""" | ||
@doc kind: :process_values | ||
@spec get_default_device() :: ExTorch.Device.device() | ||
def get_default_device() do | ||
case Registry.values(__MODULE__, :device, self()) do | ||
[] -> | ||
Registry.register(__MODULE__, :device, :cpu) | ||
:cpu | ||
[device] -> device | ||
end | ||
end | ||
|
||
@doc """ | ||
Sets the default device on which `ExTorch.Tensor` structs are allocated. | ||
This function does not affect factory function calls which are called with an | ||
explicit `device` argument. Factory calls will be performed as if they were | ||
passed `device` as an argument. | ||
The default device is initially `:cpu`. If you set the default tensor device to | ||
another device (e.g., `:cuda`) without a device index, tensors will be | ||
allocated on whatever the current device for the device type. | ||
## Examples | ||
# By default, the device will be :cpu | ||
iex> a = ExTorch.tensor([1.2, 3]) | ||
iex> a.device | ||
:cpu | ||
# Change the default device to :cuda | ||
iex> ExTorch.set_default_device(:cuda) | ||
# Check that tensors are now being created on gpu | ||
iex> a = ExTorch.tensor([1.2, 3]) | ||
iex> a.device | ||
{:cuda, 0} | ||
""" | ||
@doc kind: :process_values | ||
@spec set_default_device(ExTorch.Device.device()) :: ExTorch.Device.device() | ||
def set_default_device(device) when ExTorch.Device.is_device(device) do | ||
case Registry.values(__MODULE__, :device, self()) do | ||
[] -> | ||
Registry.register(__MODULE__, :device, device) | ||
_ -> | ||
Registry.unregister(__MODULE__, :device) | ||
Registry.register(__MODULE__, :device, device) | ||
end | ||
|
||
device | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
defmodule ExTorch.Registry.DType do | ||
@moduledoc false | ||
|
||
@after_compile __MODULE__ | ||
def __after_compile__(_env, bytecode) do | ||
__MODULE__ | ||
|> :code.which() | ||
|> to_string() | ||
|> File.write(bytecode) | ||
end | ||
|
||
require ExTorch.DType | ||
|
||
@doc """ | ||
Get the current default floating point dtype for the current process. | ||
## Notes | ||
By default, `ExTorch` will set the default dtype to `:float32`. | ||
""" | ||
@doc kind: :process_values | ||
@spec get_default_dtype() :: ExTorch.DType.dtype() | ||
def get_default_dtype() do | ||
case Registry.select(__MODULE__, [{{:dtype, self(), :"$1"}, [], [:"$1"]}]) do | ||
[] -> | ||
Registry.register(__MODULE__, :dtype, :float32) | ||
:float32 | ||
[value] -> value | ||
end | ||
end | ||
|
||
@doc """ | ||
Sets the default floating point dtype of the current process to `dtype`. | ||
Supports `:float32` and `:float64` as inputs. Other dtypes may be accepted | ||
without complaint but are not supported and are unlikely to work as expected. | ||
When PyTorch is initialized its default floating point dtype is `:float32`, | ||
and the intent of `set_default_dtype(:float64)` is to facilitate NumPy-like | ||
type inference. The default floating point dtype is used to: | ||
1. Implicitly determine the default complex dtype. When the default floating | ||
point type is `:float32` the default complex dtype is `:complex64`, and | ||
when the default floating point type is `:float64` the default complex | ||
type is `:complex128`. | ||
2. Infer the dtype for tensors constructed using Elixir floats or | ||
`ExTorch.Complex` numbers. See examples below. | ||
3. Determine the result of type promotion between bool and integer tensors and | ||
Elixir floats and `ExTorch.Complex` numbers. | ||
## Examples | ||
# Initial default for floating point is :float32 | ||
iex> a = ExTorch.tensor([1.2, 3]) | ||
iex> a.dtype | ||
:float | ||
# Initial default for floating point complex numbers is :complex64 | ||
iex> b = ExTorch.tensor([1.2, ExTorch.Complex.complex(0, 3)]) | ||
iex> b.dtype | ||
:complex_float | ||
# Changing the default dtype to :float64 | ||
iex> ExTorch.set_default_dtype(:float64) | ||
# Floats are now interpreted as float64 | ||
iex> a = ExTorch.tensor([1.2, 3]) | ||
iex> a.dtype | ||
:double | ||
# Complex numbers are now interpreted as :complex128 | ||
iex> b = ExTorch.tensor([1.2, ExTorch.Complex.complex(0, 3)]) | ||
iex> b.dtype | ||
:complex_double | ||
""" | ||
@doc kind: :process_values | ||
@spec set_default_dtype(ExTorch.DType.dtype()) :: ExTorch.DType.dtype() | ||
def set_default_dtype(dtype) when ExTorch.DType.is_dtype(dtype) do | ||
case Registry.values(__MODULE__, :dtype, self()) do | ||
[] -> | ||
Registry.register(__MODULE__, :dtype, dtype) | ||
_ -> | ||
Registry.unregister(__MODULE__, :dtype) | ||
Registry.register(__MODULE__, :dtype, dtype) | ||
end | ||
|
||
dtype | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.