custom built optimizer for set_optimizer #154

Open
hmishfaq opened this issue Jul 9, 2023 · 1 comment

hmishfaq commented Jul 9, 2023

I am trying to run BaggingRegressor with a custom-built optimizer, but it seems this is not currently supported, so I am doing the following:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as sp
from tqdm import tqdm
from copy import deepcopy
from langevin import LangevinMC
from torch.utils.data import Dataset, DataLoader
from torchensemble.bagging import BaggingRegressor


# Generate data
X = np.random.uniform(0.0, 0.5, 100).reshape(-1, 1)
noise = sp.norm(0.00, 0.02)
target_toy = lambda x: (
    x
    + 0.3 * np.sin(2 * np.pi * (x + noise.rvs(1)[0]))
    + 0.3 * np.sin(4 * np.pi * (x + noise.rvs(1)[0]))
    + noise.rvs(1)[0]
    - 0.5
)
y = np.array([target_toy(e) for e in X])

x_grid = np.linspace(-5, 5, 1000).reshape(-1, 1)


# Define trainable network
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 16)
        self.fc2 = nn.Linear(16, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# convert into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32).reshape(-1, 1)

# create a DataLoader; batch_size equals the dataset size, so each epoch is a single batch
train_loader = DataLoader(list(zip(X, y)), shuffle=True, batch_size=100)


class CustomBaggingRegressor(BaggingRegressor):
    
    def set_optimizer(self, optimizer_name, **kwargs):
        optimizer = LangevinMC(self.parameters(), lr = 1e-1,
                       beta_inv=1e-1, weight_decay=0)

        return optimizer
        

# Define the ensemble
model = CustomBaggingRegressor(
    estimator=Model,
    n_estimators=9,
    cuda=False,
)

# Set the criterion
criterion = nn.MSELoss()
model.set_criterion(criterion)

# Set the optimizer
model.set_optimizer(optimizer_name = 'LangevinMC', lr=1e-1, weight_decay=0)

# Train and Evaluate
model.fit(
    train_loader,
    epochs=10000,
)

However, when I run this code, I get the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-...> in <module>
     22 
     23 # Set the optimizer
---> 24 model.set_optimizer(optimizer_name = 'LangevinMC', lr=1e-1, weight_decay=0)
     25 
     26 # Train and Evaluate

<ipython-input-...> in set_optimizer(self, optimizer_name, **kwargs)
      4 
      5     def set_optimizer(self, optimizer_name, **kwargs):
----> 6         optimizer = LangevinMC(self.parameters(), lr = 1e-1,
      7                        beta_inv=1e-1, weight_decay=0)
      8 

~/Documents/Code/randomized_prior_function/langevin.py in __init__(self, params, lr, beta_inv, sigma, weight_decay, device)
     42         self.curr_step = 0
     43         defaults = dict(weight_decay=weight_decay)
---> 44         super(LangevinMC, self).__init__(params, defaults)
     45 
     46     def init_map(self):

/usr/local/lib/python3.8/site-packages/torch/optim/optimizer.py in __init__(self, params, defaults)
     45         param_groups = list(params)
...
---> 47             raise ValueError("optimizer got an empty parameter list")
     48         if not isinstance(param_groups[0], dict):
     49             param_groups = [{'params': param_groups}]

ValueError: optimizer got an empty parameter list

I am not sure how to get around this. Any help would be appreciated. Also, here is the custom optimizer:

import math
import torch
from torch.optim import Optimizer

from torch import Tensor
from typing import List


def lmc(params: List[Tensor],
        d_p_list: List[Tensor],
        weight_decay: float,
        lr: float):
    r"""Functional API that performs Langevine MC algorithm computation.
    """

    for i, param in enumerate(params):
        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add_(param, alpha=weight_decay)

        param.add_(d_p, alpha=-lr)


class LangevinMC(Optimizer):
    def __init__(self,
                 params,              # parameters of the model
                 lr=0.01,             # learning rate
                 beta_inv=0.01,       # inverse temperature parameter
                 sigma=1.0,           # variance of the Gaussian noise
                 weight_decay=1.0,    # l2 penalty
                 device=None):
        if lr < 0:
            raise ValueError('lr must be positive')
        if device:
            self.device = device
        else:
            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.beta_inv = beta_inv
        self.lr = lr
        self.sigma = sigma
        self.temp = - math.sqrt(2 * beta_inv / lr) * sigma
        self.curr_step = 0
        defaults = dict(weight_decay=weight_decay)
        super(LangevinMC, self).__init__(params, defaults)

    def init_map(self):
        self.mapping = dict()
        index = 0
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    num_param = p.numel()
                    self.mapping[p] = [index, num_param]
                    index += num_param
        self.total_size = index

    @torch.no_grad()
    def step(self):
        self.curr_step += 1
        if self.curr_step == 1:
            self.init_map()

        lr = self.lr
        temp = self.temp
        noise = temp * torch.randn(self.total_size, device=self.device)

        for group in self.param_groups:
            weight_decay = group['weight_decay']

            params_with_grad = []
            d_p_list = []
            for p in group['params']:
                if p.grad is not None:
                    params_with_grad.append(p)

                    start, length = self.mapping[p]
                    add_noise = noise[start: start + length].reshape(p.shape)
                    delta_p = p.grad
                    delta_p = delta_p.add_(add_noise)
                    d_p_list.append(delta_p)
                    # p.add_(delta_p)
            lmc(params_with_grad, d_p_list, weight_decay, lr)
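
For reference, the optimizer itself constructs without error when handed a non-empty parameter list. A minimal sketch (untested; it reuses Model, X, and y from above and pins the optimizer to CPU to match cuda=False):

# Standalone check: LangevinMC over a plain Model's parameters.
# The ValueError above appears because the ensemble's self.parameters()
# is empty when set_optimizer runs, presumably before fit() has built
# any estimators.
net = Model()
opt = LangevinMC(net.parameters(), lr=1e-1, beta_inv=1e-1,
                 weight_decay=0, device=torch.device('cpu'))

loss = nn.functional.mse_loss(net(X), y)  # one toy training step
opt.zero_grad()
loss.backward()
opt.step()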

xuyxu (Member) commented Sep 9, 2023

Hi @hmishfaq, sorry for the late response. See if adding the line super().set_optimizer(optimizer_name, **kwargs) as the first line of the function set_optimizer in the class CustomBaggingRegressor solves the problem.
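
A sketch of what that suggestion amounts to, keeping the rest of the method from the snippet above (untested):

class CustomBaggingRegressor(BaggingRegressor):

    def set_optimizer(self, optimizer_name, **kwargs):
        # Let the base class record the optimizer configuration first,
        # as suggested above.
        super().set_optimizer(optimizer_name, **kwargs)
        optimizer = LangevinMC(self.parameters(), lr=1e-1,
                               beta_inv=1e-1, weight_decay=0)
        return optimizer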
