-
Notifications
You must be signed in to change notification settings - Fork 12
/
simutils.py
30 lines (24 loc) · 1.02 KB
/
simutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
class LinearRegressionModel(nn.Module):
    """Single-output linear regression model, ``y = x @ w.T + b``.

    Args:
        p: Number of input features.
        weights: Optional initial weights: any array-like (sequence, NumPy
            array, tensor, or scalar when ``p == 1``) holding exactly ``p``
            values. When omitted, ``nn.Linear``'s default initialisation is kept.
        bias: Optional initial bias holding exactly one value (scalar,
            one-element sequence, or tensor). When omitted, the default is kept.

    Raises:
        ValueError: If ``weights`` does not hold ``p`` values or ``bias`` does
            not hold exactly one value.
    """

    def __init__(self, p, weights=None, bias=None):
        super().__init__()
        self.linear = nn.Linear(p, 1)
        if weights is not None:
            # nn.Linear(p, 1) stores its weight as (out_features, in_features).
            self.linear.weight = self._as_parameter(weights, (1, p), "weights")
        if bias is not None:
            self.linear.bias = self._as_parameter(bias, (1,), "bias")

    @staticmethod
    def _as_parameter(values, shape, name):
        """Copy *values* into a new float ``Parameter`` of exactly *shape*."""
        # as_tensor accepts sequences, NumPy arrays and tensors alike; flattening
        # first lets scalars and nested inputs land in the layer's exact shape.
        flat = torch.as_tensor(values, dtype=torch.get_default_dtype()).detach().reshape(-1)
        expected = torch.Size(shape).numel()
        if flat.numel() != expected:
            raise ValueError(
                f"{name} must contain {expected} value(s), got {flat.numel()}"
            )
        # clone() so the parameter never aliases (and, when trained, mutates)
        # the caller's buffer.
        return Parameter(flat.reshape(shape).clone())

    def forward(self, x):
        """Return the linear prediction for ``x``, whose last dimension is ``p``."""
        return self.linear(x)
class LogisticRegressionModel(nn.Module):
    """Single-output logistic regression model, ``y = sigmoid(x @ w.T + b)``.

    Args:
        p: Number of input features.
        weights: Optional initial weights: any array-like (sequence, NumPy
            array, tensor, or scalar when ``p == 1``) holding exactly ``p``
            values. When omitted, ``nn.Linear``'s default initialisation is kept.
        bias: Optional initial bias holding exactly one value (scalar,
            one-element sequence, or tensor). When omitted, the default is kept.

    Raises:
        ValueError: If ``weights`` does not hold ``p`` values or ``bias`` does
            not hold exactly one value.
    """

    def __init__(self, p, weights=None, bias=None):
        super().__init__()
        self.linear = nn.Linear(p, 1)
        if weights is not None:
            # nn.Linear(p, 1) stores its weight as (out_features, in_features).
            self.linear.weight = self._as_parameter(weights, (1, p), "weights")
        if bias is not None:
            self.linear.bias = self._as_parameter(bias, (1,), "bias")

    @staticmethod
    def _as_parameter(values, shape, name):
        """Copy *values* into a new float ``Parameter`` of exactly *shape*."""
        # as_tensor accepts sequences, NumPy arrays and tensors alike; flattening
        # first lets scalars and nested inputs land in the layer's exact shape.
        flat = torch.as_tensor(values, dtype=torch.get_default_dtype()).detach().reshape(-1)
        expected = torch.Size(shape).numel()
        if flat.numel() != expected:
            raise ValueError(
                f"{name} must contain {expected} value(s), got {flat.numel()}"
            )
        # clone() so the parameter never aliases (and, when trained, mutates)
        # the caller's buffer.
        return Parameter(flat.reshape(shape).clone())

    def forward(self, x):
        """Return ``sigmoid`` of the linear score for ``x`` (last dimension ``p``)."""
        return torch.sigmoid(self.linear(x))
# model_modules["Logistic"](3, (1,1,1), 0).forward(torch.zeros([1,3]))