Skip to content

Commit

Permalink
Fix Minibatch alignment in Bayesian Neural Network example + Pre-comm…
Browse files Browse the repository at this point in the history
…it hooks (pymc-devs#719)

* Fix Minibatch alignment in Bayesian Neural Network example

* Run: pre-commit run all-files

---------

Co-authored-by: Deepak CH <[email protected]>
  • Loading branch information
2 people authored and fonnesbeck committed Dec 20, 2024
1 parent 92fc463 commit f53b581
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@
" }\n",
"\n",
" with pm.Model(coords=coords) as neural_network:\n",
" ann_input = pm.Data(\"ann_input\", X_train, mutable=True)\n",
" ann_output = pm.Data(\"ann_output\", Y_train, mutable=True)\n",
" # Define minibatch variables\n",
" minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)\n",
"\n",
" # Define data variables using minibatches\n",
" ann_input = pm.Data(\"ann_input\", minibatch_x, mutable=True, dims=(\"obs_id\", \"train_cols\"))\n",
" ann_output = pm.Data(\"ann_output\", minibatch_y, mutable=True, dims=\"obs_id\")\n",
"\n",
" # Weights from input to hidden layer\n",
" weights_in_1 = pm.Normal(\n",
Expand All @@ -212,7 +216,8 @@
" \"out\",\n",
" act_out,\n",
" observed=ann_output,\n",
" total_size=Y_train.shape[0], # IMPORTANT for minibatches\n",
" total_size=X_train.shape[0], # IMPORTANT for minibatches\n",
" dims=\"obs_id\",\n",
" )\n",
" return neural_network\n",
"\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,12 @@ def construct_nn():
}
with pm.Model(coords=coords) as neural_network:
ann_input = pm.Data("ann_input", X_train, mutable=True)
ann_output = pm.Data("ann_output", Y_train, mutable=True)
# Define minibatch variables
minibatch_x, minibatch_y = pm.Minibatch(X_train, Y_train, batch_size=50)
# Define data variables using minibatches
ann_input = pm.Data("ann_input", minibatch_x, mutable=True, dims=("obs_id", "train_cols"))
ann_output = pm.Data("ann_output", minibatch_y, mutable=True, dims="obs_id")
# Weights from input to hidden layer
weights_in_1 = pm.Normal(
Expand All @@ -157,7 +161,8 @@ def construct_nn():
"out",
act_out,
observed=ann_output,
total_size=Y_train.shape[0], # IMPORTANT for minibatches
total_size=X_train.shape[0], # IMPORTANT for minibatches
dims="obs_id",
)
return neural_network
Expand Down

0 comments on commit f53b581

Please sign in to comment.