forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgen_autograd.py
113 lines (93 loc) · 4.03 KB
/
gen_autograd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
To run this file by hand from the root of the PyTorch
repository, run:
python -m tools.autograd.gen_autograd \
build/aten/src/ATen/Declarations.yaml \
aten/src/ATen/native/native_functions.yaml \
$OUTPUT_DIR \
tools/autograd
Where $OUTPUT_DIR is where you would like the files to be
generated. In the full build system, OUTPUT_DIR is
torch/csrc/autograd/generated/
"""
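# gen_autograd.py generates C++ autograd functions and Python bindings.
#
# It delegates to the following scripts:
#
#  load_derivatives.py: parses derivatives.yaml
#  gen_autograd_functions.py: generates subclasses of torch::autograd::Node
#  gen_variable_type.py: generates VariableType.h, which contains all tensor methods
#  gen_inplace_or_view_type.py, gen_trace_type.py: generate further sources from
#      the native function list (see those modules for their exact outputs)
#  gen_variable_factories.py: generates variable_factories.h
#  gen_python_functions.py: generates Python bindings to THPVariable
#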
import argparse
import os
from typing import List, Optional

from tools.codegen.api import cpp
from tools.codegen.api.autograd import (
    match_differentiability_info, NativeFunctionWithDifferentiabilityInfo,
)
from tools.codegen.gen import parse_native_yaml
from tools.codegen.selective_build.selector import SelectiveBuilder

from . import gen_python_functions
from .gen_autograd_functions import gen_autograd_functions_lib, gen_autograd_functions_python
from .gen_trace_type import gen_trace_type
from .gen_variable_type import gen_variable_type
from .gen_inplace_or_view_type import gen_inplace_or_view_type
from .gen_variable_factories import gen_variable_factories
from .load_derivatives import load_derivatives
def gen_autograd(
    native_functions_path: str,
    out: str,
    autograd_dir: str,
    operator_selector: SelectiveBuilder,
    disable_autograd: bool = False,
) -> None:
    """Generate the C++ autograd sources into ``out``.

    Args:
        native_functions_path: path to ``native_functions.yaml``.
        out: output directory for the generated files.
        autograd_dir: directory holding ``derivatives.yaml`` and the
            ``templates`` subdirectory used for code generation.
        operator_selector: selective-build selector. Only native functions
            for which ``is_native_function_selected_for_training`` is true
            are matched against the derivative infos and passed to the
            VariableType / inplace-or-view generators.
        disable_autograd: when True, skip the VariableType, inplace-or-view
            and trace-type generators. Functions.h/cpp and
            variable_factories.h are generated either way.
    """
    # Parse and load derivatives.yaml
    differentiability_infos = load_derivatives(
        os.path.join(autograd_dir, 'derivatives.yaml'), native_functions_path)

    template_path = os.path.join(autograd_dir, 'templates')

    native_funcs = parse_native_yaml(native_functions_path).native_functions
    selected_funcs = filter(
        operator_selector.is_native_function_selected_for_training,
        native_funcs)
    # Order by C++ name (presumably to keep generated output stable across
    # runs). `sorted` already returns a list, so no list() wrapper is needed.
    fns = sorted(selected_funcs, key=lambda f: cpp.name(f.func))
    fns_with_diff_infos: List[NativeFunctionWithDifferentiabilityInfo] = (
        match_differentiability_info(fns, differentiability_infos))

    if not disable_autograd:
        # Generate VariableType.h/cpp
        gen_variable_type(out, native_functions_path, fns_with_diff_infos, template_path)

        gen_inplace_or_view_type(out, native_functions_path, fns_with_diff_infos, template_path)

        # operator filter not applied as tracing sources are excluded in selective build
        gen_trace_type(out, native_funcs, template_path)

    # Generate Functions.h/cpp
    gen_autograd_functions_lib(out, differentiability_infos, template_path)

    # Generate variable_factories.h
    gen_variable_factories(out, native_functions_path, template_path)
def gen_autograd_python(
    native_functions_path: str,
    out: str,
    autograd_dir: str,
) -> None:
    """Generate the Python-facing autograd sources into ``out``.

    Reads ``derivatives.yaml`` and ``deprecated.yaml`` from ``autograd_dir``
    and renders templates from ``autograd_dir/templates``.
    """
    derivatives_yaml = os.path.join(autograd_dir, 'derivatives.yaml')
    infos = load_derivatives(derivatives_yaml, native_functions_path)
    templates = os.path.join(autograd_dir, 'templates')

    # NOTE(review): presumably the Python wrappers for the generated autograd
    # Node subclasses (inferred from the function name). The comment that used
    # to sit here said "Functions.h/cpp", which is what
    # gen_autograd_functions_lib produces. Confirm the actual output files.
    gen_autograd_functions_python(out, infos, templates)

    # Python bindings to THPVariable; deprecated.yaml is passed alongside the
    # native function list.
    gen_python_functions.gen(
        out,
        native_functions_path,
        os.path.join(autograd_dir, 'deprecated.yaml'),
        templates,
    )
def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point that runs :func:`gen_autograd`.

    Args:
        argv: argument list to parse in place of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.
    """
    parser = argparse.ArgumentParser(
        description='Generate autograd C++ files script')
    parser.add_argument('native_functions', metavar='NATIVE',
                        help='path to native_functions.yaml')
    parser.add_argument('out', metavar='OUT',
                        help='path to output directory')
    parser.add_argument('autograd', metavar='AUTOGRAD',
                        help='path to autograd directory')
    # Exposes gen_autograd's existing disable_autograd parameter. The default
    # (False) matches what the CLI always passed before.
    parser.add_argument('--disable-autograd', action='store_true',
                        help='skip the autograd kernel (VariableType etc.) '
                             'generation steps')
    args = parser.parse_args(argv)
    # NOTE(review): the "nop" selector presumably disables selective-build
    # filtering so that every operator is generated. TODO confirm.
    gen_autograd(args.native_functions,
                 args.out, args.autograd,
                 SelectiveBuilder.get_nop_selector(),
                 disable_autograd=args.disable_autograd)


if __name__ == '__main__':
    main()