diff --git a/examples/kernel_function.py b/examples/kernel_function.py index 0ac0528..413edb6 100755 --- a/examples/kernel_function.py +++ b/examples/kernel_function.py @@ -216,6 +216,21 @@ def forward(self, x): x = self.fc3(x) # Hidden layer 2 to output layer return x + def set_weights_biases(self, weights, biases): + weights = [torch.Tensor(entry) for entry in weights] + biases = [torch.Tensor(entry) for entry in biases] + with torch.no_grad(): + model.fc1.weight = nn.Parameter(weights[0]) + model.fc1.bias = nn.Parameter(biases[0]) + + model.fc2.weight = nn.Parameter(weights[1]) + model.fc2.bias = nn.Parameter(biases[1]) + + model.fc3.weight = nn.Parameter(weights[2]) + model.fc3.bias = nn.Parameter(biases[2]) + + + # Initialize the network #model = WarpNet()