Skip to content

Commit

Permalink
Merge pull request #18 from andfoy/add_global_registries
Browse files Browse the repository at this point in the history
Define per-process dtype and device registries
  • Loading branch information
andfoy authored Sep 7, 2023
2 parents 9e59ece + 2587017 commit 167caa4
Show file tree
Hide file tree
Showing 27 changed files with 1,133 additions and 175 deletions.
2 changes: 2 additions & 0 deletions lib/extorch.ex
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ defmodule ExTorch do
extends(ExTorch.Native.Tensor.Ops.Indexing)
extends(ExTorch.Native.Tensor.Ops.PointWise)
extends(ExTorch.Native.Tensor.Ops.Other)
extends(ExTorch.Registry.DType)
extends(ExTorch.Registry.Device)

# extends(ExTorch.Native.Tensor.Info)
end
12 changes: 12 additions & 0 deletions lib/extorch/application.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
defmodule ExTorch.Application do
@moduledoc false
use Application

def start(_type, _args) do
children = [
{Registry, [name: ExTorch.Registry.DType, keys: :duplicate]},
{Registry, [name: ExTorch.Registry.Device, keys: :duplicate]},
]
Supervisor.start_link(children, strategy: :one_for_one)
end
end
3 changes: 3 additions & 0 deletions lib/extorch/native/macros.ex
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,9 @@ defmodule ExTorch.Native.Macros do
quote do
unquote(Macro.var(kwarg, nil)) =
struct(unquote(struct_name), unquote(Macro.var(:kwargs, nil)))

unquote(Macro.var(kwarg, nil)) =
ExTorch.Protocol.DefaultStruct.replace_defaults(unquote(Macro.var(kwarg, nil)))
end
end)

Expand Down
505 changes: 394 additions & 111 deletions lib/extorch/native/tensor/creation.ex

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions lib/extorch/native/tensor/info.ex
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ defmodule ExTorch.Native.Tensor.Info do
@spec requires_grad(ExTorch.Tensor.t()) :: boolean()
defbinding(requires_grad(tensor))

@doc """
Get the `memory_format` of a tensor.
## Arguments
- tensor (`ExTorch.Tensor`): Input tensor
"""
@spec memory_format(ExTorch.Tensor.t()) :: ExTorch.MemoryFormat.memory_format()
defbinding(memory_format(tensor))

@doc """
Get the `layout` of a tensor.
## Arguments
- tensor (`ExTorch.Tensor`): Input tensor
"""
@spec layout(ExTorch.Tensor.t()) :: ExTorch.Layout.layout()
defbinding(layout(tensor))

@doc """
Get a human readable representation of a tensor.
Expand Down
12 changes: 12 additions & 0 deletions lib/extorch/protocol/default_struct.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
defprotocol ExTorch.Protocol.DefaultStruct do
@doc """
Given a struct filled with default values, return a struct with valid values.
"""
def replace_defaults(struct_defaults)
end

defimpl ExTorch.Protocol.DefaultStruct, for: Any do
def replace_defaults(in_struct) do
in_struct
end
end
69 changes: 69 additions & 0 deletions lib/extorch/registry/device.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
defmodule ExTorch.Registry.Device do
@moduledoc false
require ExTorch.Device

@after_compile __MODULE__
def __after_compile__(_env, bytecode) do
__MODULE__
|> :code.which()
|> to_string()
|> File.write(bytecode)
end

@doc """
Get the default device on which `ExTorch.Tensor` structs are allocated.
## Notes
By default, `ExTorch` will set `:cpu` as the default device.
"""
@doc kind: :process_values
@spec get_default_device() :: ExTorch.Device.device()
def get_default_device() do
case Registry.values(__MODULE__, :device, self()) do
[] ->
Registry.register(__MODULE__, :device, :cpu)
:cpu
[device] -> device
end
end

@doc """
Sets the default device on which `ExTorch.Tensor` structs are allocated.
This function does not affect factory function calls which are called with an
explicit `device` argument. Factory calls will be performed as if they were
passed `device` as an argument.
The default device is initially `:cpu`. If you set the default tensor device to
another device (e.g., `:cuda`) without a device index, tensors will be
allocated on whatever the current device for the device type.
## Examples
# By default, the device will be :cpu
iex> a = ExTorch.tensor([1.2, 3])
iex> a.device
:cpu
# Change the default device to :cuda
iex> ExTorch.set_default_device(:cuda)
# Check that tensors are now being created on gpu
iex> a = ExTorch.tensor([1.2, 3])
iex> a.device
{:cuda, 0}
"""
@doc kind: :process_values
@spec set_default_device(ExTorch.Device.device()) :: ExTorch.Device.device()
def set_default_device(device) when ExTorch.Device.is_device(device) do
case Registry.values(__MODULE__, :device, self()) do
[] ->
Registry.register(__MODULE__, :device, device)
_ ->
Registry.unregister(__MODULE__, :device)
Registry.register(__MODULE__, :device, device)
end

device
end
end
90 changes: 90 additions & 0 deletions lib/extorch/registry/dtype.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
defmodule ExTorch.Registry.DType do
@moduledoc false

@after_compile __MODULE__
def __after_compile__(_env, bytecode) do
__MODULE__
|> :code.which()
|> to_string()
|> File.write(bytecode)
end

require ExTorch.DType

@doc """
Get the current default floating point dtype for the current process.
## Notes
By default, `ExTorch` will set the default dtype to `:float32`.
"""
@doc kind: :process_values
@spec get_default_dtype() :: ExTorch.DType.dtype()
def get_default_dtype() do
case Registry.select(__MODULE__, [{{:dtype, self(), :"$1"}, [], [:"$1"]}]) do
[] ->
Registry.register(__MODULE__, :dtype, :float32)
:float32
[value] -> value
end
end

@doc """
Sets the default floating point dtype of the current process to `dtype`.
Supports `:float32` and `:float64` as inputs. Other dtypes may be accepted
without complaint but are not supported and are unlikely to work as expected.
When PyTorch is initialized its default floating point dtype is `:float32`,
and the intent of `set_default_dtype(:float64)` is to facilitate NumPy-like
type inference. The default floating point dtype is used to:
1. Implicitly determine the default complex dtype. When the default floating
point type is `:float32` the default complex dtype is `:complex64`, and
when the default floating point type is `:float64` the default complex
type is `:complex128`.
2. Infer the dtype for tensors constructed using Elixir floats or
`ExTorch.Complex` numbers. See examples below.
3. Determine the result of type promotion between bool and integer tensors and
Elixir floats and `ExTorch.Complex` numbers.
## Examples
# Initial default for floating point is :float32
iex> a = ExTorch.tensor([1.2, 3])
iex> a.dtype
:float
# Initial default for floating point complex numbers is :complex64
iex> b = ExTorch.tensor([1.2, ExTorch.Complex.complex(0, 3)])
iex> b.dtype
:complex_float
# Changing the default dtype to :float64
iex> ExTorch.set_default_dtype(:float64)
# Floats are now interpreted as float64
iex> a = ExTorch.tensor([1.2, 3])
iex> a.dtype
:double
# Complex numbers are now interpreted as :complex128
iex> b = ExTorch.tensor([1.2, ExTorch.Complex.complex(0, 3)])
iex> b.dtype
:complex_double
"""
@doc kind: :process_values
@spec set_default_dtype(ExTorch.DType.dtype()) :: ExTorch.DType.dtype()
def set_default_dtype(dtype) when ExTorch.DType.is_dtype(dtype) do
case Registry.values(__MODULE__, :dtype, self()) do
[] ->
Registry.register(__MODULE__, :dtype, dtype)
_ ->
Registry.unregister(__MODULE__, :dtype)
Registry.register(__MODULE__, :dtype, dtype)
end

dtype
end
end
53 changes: 51 additions & 2 deletions lib/extorch/tensor/options.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ defmodule ExTorch.Tensor.Options do

defstruct [
# Type of the tensor reference
dtype: :float64,
dtype: nil,

# The layout format of the tensor
layout: :strided,

# Device where the tensor lives on
device: :cpu,
device: nil,

# true if the tensor will accumulate gradients, false otherwise
requires_grad: false,
Expand All @@ -42,4 +42,53 @@ defmodule ExTorch.Tensor.Options do
# The memory format which the tensor should have
memory_format: :contiguous
]

defimpl ExTorch.Protocol.DefaultStruct, for: __MODULE__ do
@spec replace_defaults(ExTorch.Tensor.Options.t()) :: ExTorch.Tensor.Options.t()
def replace_defaults(%ExTorch.Tensor.Options{dtype: dtype, device: device} = in_struct) do
dtype =
case dtype do
nil -> ExTorch.get_default_dtype()
_ -> dtype
end

device =
case device do
nil -> ExTorch.get_default_device()
_ -> device
end

struct(in_struct, dtype: dtype, device: device)
end
end

@doc false
@spec merge_input(ExTorch.Tensor.t(), ExTorch.Tensor.Options.t()) :: ExTorch.Tensor.Options.t()
def merge_input(%ExTorch.Tensor{} = input, %ExTorch.Tensor.Options{} = options) do
dtype =
case options.dtype do
:auto -> ExTorch.Tensor.dtype(input)
dtype -> dtype
end

device =
case options.device do
:auto -> ExTorch.Tensor.device(input)
device -> device
end

layout =
case options.layout do
nil -> ExTorch.Tensor.layout(input)
layout -> layout
end

mem_fmt =
case options.memory_format do
:preserve -> ExTorch.Tensor.memory_format(input)
mem_fmt -> mem_fmt
end

struct(options, dtype: dtype, device: device, layout: layout, memory_format: mem_fmt)
end
end
1 change: 1 addition & 0 deletions lib/extorch/utils/print_options.ex
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ defmodule ExTorch.Utils.PrintOptions do
sci_mode: boolean() | nil
}

@derive [ExTorch.Protocol.DefaultStruct]
defstruct precision: 4,
threshold: 1000.0,
edgeitems: 3,
Expand Down
25 changes: 22 additions & 3 deletions lib/extorch/utils/types.ex
Original file line number Diff line number Diff line change
Expand Up @@ -111,23 +111,42 @@ defmodule ExTorch.Utils.Types do
end

def collect_types(float, acc) when is_float(float) and float <= 3.4_028_235e38 do
MapSet.put(acc, :float32)
dtype =
case ExTorch.get_default_dtype() do
:float32 -> :float32
:float64 -> :float64
_ -> :float32
end

MapSet.put(acc, dtype)
end

def collect_types(float, acc) when is_float(float) do
MapSet.put(acc, :float64)
end

def collect_types(%ExTorch.Complex{real: re, imaginary: im}, acc) do
default_dtype =
case ExTorch.get_default_dtype() do
:float32 -> :complex64
:float64 -> :complex128
_ -> :complex64
end

choose_complex_dtype = fn
:complex64, :complex64 -> :complex64
_, :complex128 -> :complex128
end

re_dtype =
case re <= 3.4_028_235e38 do
true -> :complex64
true -> choose_complex_dtype.(:complex64, default_dtype)
false -> :complex128
end

im_dtype =
case im <= 3.4_028_235e38 do
true -> :complex64
true -> choose_complex_dtype.(:complex64, default_dtype)
false -> :complex128
end

Expand Down
2 changes: 2 additions & 0 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ defmodule ExTorch.MixProject do
# Run "mix help compile.app" to learn about applications.
def application do
[
mod: {ExTorch.Application, []},
extra_applications: [:logger, :ssl, :inets]
]
end
Expand Down Expand Up @@ -83,6 +84,7 @@ defmodule ExTorch.MixProject do
# This is useful for custom JS/CSS files you want to include.
before_closing_body_tag: &before_closing_body_tag/1,
groups_for_functions: [
{:"Per-process settings", &(&1[:kind] == :process_values)},
{:"Tensor information", &(&1[:kind] == :tensor_info)},
{:"Tensor creation", &(&1[:kind] == :tensor_creation)},
{:"Tensor manipulation", &(&1[:kind] == :tensor_manipulation)},
Expand Down
Loading

0 comments on commit 167caa4

Please sign in to comment.