You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I am trying to run BaggingRegressor with a custom built optimizer. But it seems it's not currently supported. So, I am doing the following
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sp
from tqdm import tqdm
from copy import deepcopy
from langevin import LangevinMC
from torch.utils.data import Dataset, DataLoader
from torchensemble.bagging import BaggingRegressor
# Generate data
X = np.random.uniform(0.0, 0.5, 100).reshape(-1, 1)
noise = sp.norm(0.00, 0.02)
target_toy = lambda x: (
x
+ 0.3 * np.sin(2 * np.pi * (x + noise.rvs(1)[0]))
+ 0.3 * np.sin(4 * np.pi * (x + noise.rvs(1)[0]))
+ noise.rvs(1)[0]
- 0.5
)
y = np.array([target_toy(e) for e in X])
x_grid = np.linspace(-5, 5, 1000).reshape(-1, 1)
# Define trainable network
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(1, 16)
self.fc2 = nn.Linear(16, 16)
self.fc3 = nn.Linear(16, 1)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
# convert into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)
# create DataLoader, then take one batch
train_loader = DataLoader(list(zip(X,y)), shuffle=True, batch_size=100)
class CustomBaggingRegressor(BaggingRegressor):
def set_optimizer(self, optimizer_name, **kwargs):
optimizer = LangevinMC(self.parameters(), lr = 1e-1,
beta_inv=1e-1, weight_decay=0)
return optimizer
# Define the ensemble
model = CustomBaggingRegressor(
estimator=Model,
n_estimators=9,
cuda=False,
)
# Set the criterion
criterion = nn.MSELoss()
model.set_criterion(criterion)
# Set the optimizer
model.set_optimizer(optimizer_name = 'LangevinMC', lr=1e-1, weight_decay=0)
# Train and Evaluate
model.fit(
train_loader,
epochs=10000,
)
However, when I try to run this code, I am getting the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
in
22
23 # Set the optimizer
---> 24 model.set_optimizer(optimizer_name = 'LangevinMC', lr=1e-1, weight_decay=0)
25
26 # Train and Evaluate
in set_optimizer(self, optimizer_name, **kwargs)
4
5 def set_optimizer(self, optimizer_name, **kwargs):
----> 6 optimizer = LangevinMC(self.parameters(), lr = 1e-1,
7 beta_inv=1e-1, weight_decay=0)
8
[~/Documents/Code/randomized_prior_function/langevin.py](https://file+.vscode-resource.vscode-cdn.net/Users/hmishfaq/Documents/Code/randomized_prior_function/~/Documents/Code/randomized_prior_function/langevin.py) in __init__(self, params, lr, beta_inv, sigma, weight_decay, device)
42 self.curr_step = 0
43 defaults = dict(weight_decay=weight_decay)
---> 44 super(LangevinMC, self).__init__(params, defaults)
45
46 def init_map(self):
[/usr/local/lib/python3.8/site-packages/torch/optim/optimizer.py](https://file+.vscode-resource.vscode-cdn.net/usr/local/lib/python3.8/site-packages/torch/optim/optimizer.py) in __init__(self, params, defaults)
45 param_groups = list(params)
...
---> 47 raise ValueError("optimizer got an empty parameter list")
48 if not isinstance(param_groups[0], dict):
49 param_groups = [{'params': param_groups}]
ValueError: optimizer got an empty parameter list
I am not sure how I can go around this. Any help would be appreciated. Also here is the custom optimizer:
import math
import torch
from torch.optim import Optimizer
from torch import Tensor
from typing import List
def lmc(params: List[Tensor],
d_p_list: List[Tensor],
weight_decay: float,
lr: float):
r"""Functional API that performs Langevine MC algorithm computation.
"""
for i, param in enumerate(params):
d_p = d_p_list[i]
if weight_decay != 0:
d_p = d_p.add_(param, alpha=weight_decay)
param.add_(d_p, alpha=-lr)
class LangevinMC(Optimizer):
def __init__(self,
params, # parameters of the model
lr=0.01, # learning rate
beta_inv=0.01, # inverse temperature parameter
sigma=1.0, # variance of the Gaussian noise
weight_decay=1.0,
device=None): # l2 penalty
if lr < 0:
raise ValueError('lr must be positive')
if device:
self.device = device
else:
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
self.beta_inv = beta_inv
self.lr = lr
self.sigma = sigma
self.temp = - math.sqrt(2 * beta_inv / lr) * sigma
self.curr_step = 0
defaults = dict(weight_decay=weight_decay)
super(LangevinMC, self).__init__(params, defaults)
def init_map(self):
self.mapping = dict()
index = 0
for group in self.param_groups:
for p in group['params']:
if p.grad is not None:
num_param = p.numel()
self.mapping[p] = [index, num_param]
index += num_param
self.total_size = index
@torch.no_grad()
def step(self):
self.curr_step += 1
if self.curr_step == 1:
self.init_map()
lr = self.lr
temp = self.temp
noise = temp * torch.randn(self.total_size, device=self.device)
for group in self.param_groups:
weight_decay = group['weight_decay']
params_with_grad = []
d_p_list = []
for p in group['params']:
if p.grad is not None:
params_with_grad.append(p)
start, length = self.mapping[p]
add_noise = noise[start: start + length].reshape(p.shape)
delta_p = p.grad
delta_p = delta_p.add_(add_noise)
d_p_list.append(delta_p)
# p.add_(delta_p)
lmc(params_with_grad, d_p_list, weight_decay, lr)
The text was updated successfully, but these errors were encountered:
Hi @hmishfaq, sorry for the late response. See if adding the line super().set_optimizer(optimizer_name, **kwargs) as the first line to the function set_optimizer in the class CustomBaggingRegressor solves the problem.
I am trying to run
BaggingRegressor
with a custom built optimizer. But it seems it's not currently supported. So, I am doing the followingHowever, when I try to run this code, I am getting the following error:
I am not sure how I can go around this. Any help would be appreciated. Also here is the custom optimizer:
The text was updated successfully, but these errors were encountered: