forked from zxjzxj9/PyTorchIntroduction
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
35 lines (27 loc) · 844 Bytes
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Option 1 (static loading): import the `gelu` extension module that was built ahead of time.
import torch
import gelu
# Usage: call GELU.apply(x), or bind it to a name OTHER than `gelu`
# (e.g. `gelu_fn = GELU.apply`). Rebinding `gelu` itself would shadow the
# extension module that forward/backward look up at call time.
class GELU(torch.autograd.Function):
    """GELU activation whose passes are delegated to the `gelu` extension module.

    ``gelu.forward(input)`` produces the activation and
    ``gelu.backward(grad_output, input)`` produces the gradient w.r.t. ``input``.
    """

    @staticmethod
    def forward(ctx, input):
        """Run the extension's forward pass and save ``input`` for backward.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            input: Tensor passed to ``GELU.apply``.

        Returns:
            The result of ``gelu.forward(input)``.
        """
        # save_for_backward (instead of a plain ctx attribute) lets autograd
        # verify `input` was not modified in place before backward runs, and
        # lets it release the saved tensor once backward has consumed it.
        ctx.save_for_backward(input)
        return gelu.forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient w.r.t. ``input`` via the extension.

        Args:
            ctx: Context populated by ``forward``.
            grad_output: Gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            The result of ``gelu.backward(grad_output, input)``.
        """
        (saved_input,) = ctx.saved_tensors
        return gelu.backward(grad_output, saved_input)
# Option 2 (dynamic loading): build the `gelu` extension at runtime with torch.utils.cpp_extension.load.
import torch
from torch.utils.cpp_extension import load
# PyTorch compiles the listed C++ sources automatically when this line runs
# and returns the resulting extension as a Python module.
# NOTE: "gelu/gelu.cc" is a relative path, so it is resolved against the
# current working directory, not this file's directory.
gelu = load(
    sources=["gelu/gelu.cc"],
    name="gelu",
)
class GELU(torch.autograd.Function):
    """GELU activation whose passes are delegated to the JIT-built `gelu` module.

    ``gelu.forward(input)`` produces the activation and
    ``gelu.backward(grad_output, input)`` produces the gradient w.r.t. ``input``.

    Usage: call ``GELU.apply(x)``, or bind it to a name other than ``gelu``
    (e.g. ``gelu_fn = GELU.apply``). Rebinding ``gelu`` would shadow the
    extension module that forward/backward look up at call time.
    """

    @staticmethod
    def forward(ctx, input):
        """Run the extension's forward pass and save ``input`` for backward.

        Args:
            ctx: Autograd context supplied by ``Function.apply``.
            input: Tensor passed to ``GELU.apply``.

        Returns:
            The result of ``gelu.forward(input)``.
        """
        # save_for_backward (instead of a plain ctx attribute) lets autograd
        # verify `input` was not modified in place before backward runs, and
        # lets it release the saved tensor once backward has consumed it.
        ctx.save_for_backward(input)
        return gelu.forward(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient w.r.t. ``input`` via the extension.

        Args:
            ctx: Context populated by ``forward``.
            grad_output: Gradient of the loss w.r.t. the output of ``forward``.

        Returns:
            The result of ``gelu.backward(grad_output, input)``.
        """
        (saved_input,) = ctx.saved_tensors
        return gelu.backward(grad_output, saved_input)