<!DOCTYPE html>
<html>
<head>
<link rel="canonical" href="https://hardmath123.github.io/clogs-and-leaks.html"/>
<link rel="stylesheet" type="text/css" href="/static/base.css"/>
<title>Clogs and Leaks: Why My Tensors Won't Flow - Comfortably Numbered</title>
<meta http-equiv="Content-Type" content="text/html; charset=UTF-8"/>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0, maximum-scale=1.0, user-scalable=no" />
<link rel="alternate" type="application/rss+xml" title="Comfortably Numbered" href="/feed.xml" />
<!--
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script>
MathJax.Hub.Config({
tex2jax: {inlineMath: [['$','$']]}
});
</script>
-->
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css" integrity="sha384-Um5gpz1odJg5Z4HAmzPtgZKdTBHZdw8S29IecapCSB31ligYPhHQZMIlWLYQGVoc" crossorigin="anonymous">
<script defer src="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.js" integrity="sha384-YNHdsYkH6gMx9y3mRkmcJ2mFUjTd0qNQQvY9VYZgQd7DcN7env35GzlmFaZ23JGp" crossorigin="anonymous"></script>
<script defer src="https://cdn.jsdelivr.net/npm/[email protected]/dist/contrib/auto-render.min.js" integrity="sha384-vZTG03m+2yp6N6BNi5iM4rW4oIwk5DfcNdFfxkk9ZWpDriOkXX8voJBFrAO7MpVl" crossorigin="anonymous"></script>
<script>
document.addEventListener("DOMContentLoaded", function() {
renderMathInElement(document.body, {
// customised options
// • auto-render specific keys, e.g.:
delimiters: [
{left: '$$', right: '$$', display: true},
{left: '$', right: '$', display: false},
{left: '\\begin{align}', right: '\\end{align}', display: true},
{left: '\\(', right: '\\)', display: false},
{left: '\\[', right: '\\]', display: true}
],
// • rendering keys, e.g.:
throwOnError : false
});
});
</script>
</head>
<body>
<header id="header">
<script src="static/main.js"></script>
<div>
<a href="/"><span class="left-word">Comfortably</span> <span class="right-word">Numbered</span></a>
</div>
</header>
<article id="postcontent" class="centered">
<section>
<h1>Clogs and Leaks: Why My Tensors Won't Flow</h1>
<center><em><p>Defining a class of tricky bugs in PyTorch programs</p>
</em></center>
<h4>Monday, November 23, 2020 · 5 min read</h4>
<blockquote>
<p><em>Note:</em> Jekyll and Hyde — is how this blog goes. There are weeks of
rational thought and weeks of irrational ramblings. This fall has been much
of the latter, for good reason, but here is a break(!) in the clouds.</p>
<p>Okay, okay. The truth is that I wrote this essay for my CS class this quarter
(leave it to Pat to assign an essay for a CS class!). But then, I think I
have reached the age where every written assignment in college is to be
treated as an opportunity to say something I otherwise would not have a
chance to say… and so the essay, lightly edited, finds its way to this
blog.</p>
</blockquote>
<hr>
<p>In this essay I want to describe two kinds of tricky bugs that might creep into
your PyTorch programs. I call the bugs “clogs” and “leaks.” In my mind “clogs”
and “leaks” reveal an exciting possible research direction for anyone
interested in designing better APIs for automatic differentiation in machine
learning.</p>
<p><em>Note: The examples presented, though in principle timeless, were tested using
Python 3.8.5 running PyTorch 1.6.0.</em></p>
<h2 id="preliminaries-pytorch-s-pipes">Preliminaries: PyTorch’s Pipes</h2>
<p>If you are familiar with PyTorch internals, you can skip this section. If not,
a brief review of a wonderful topic: how does PyTorch differentiate your code
for gradient descent? The technical term for PyTorch’s approach is <em>tape-based
reverse-mode automatic differentiation</em>. As you perform arithmetic computations
on your variables, PyTorch tracks the intermediate values in a computation
graph. When you want to differentiate a value, you call <code>.backward()</code> on that
value. PyTorch then walks <em>backwards</em> along this computation graph, computing
the derivatives at each step and accumulating them according to the chain rule.
Eventually, the leaf nodes of the graph contain the derivatives you asked for.</p>
<p>Let me give a small example. Suppose we wanted to compute $d2x^2/dx|_{x=3}$.
We might write a program that looks like this:</p>
<pre><code class="lang-python">x = torch.tensor(3., requires_grad=True)
y = 2 * x**2
y.backward()
print(x.grad)
</code></pre>
<p>This program generates the following computation graph.</p>
<p><img src="static/clogs-and-leaks/forward.png" alt="Forward"></p>
<p>When you call <code>.backward()</code>, PyTorch’s automatic differentiation engine walks
backwards along the graph, computing derivatives at <em>each</em> step. By the chain
rule, the product of these gives the overall derivative we sought.</p>
<p><img src="static/clogs-and-leaks/backward.png" alt="Forward"></p>
<h2 id="clogs">Clogs</h2>
<p>Now, consider this simple PyTorch program to compute $d(\sqrt{x} +
x)/dx|_{x=4}$. What do you expect to be printed?</p>
<pre><code class="lang-python">x = torch.tensor(4., requires_grad=True)
y = sqrt(x) + x
y.backward()
print(x.grad)
</code></pre>
<p>A casual user or AP calculus student would <em>expect</em> to see 1.25 printed, of
course. But what <em>actually</em> gets printed is 1. Why?</p>
<p>Ah! I didn’t show you the full program: I hid the imports. It turns out that
the first line of this program is <code>from math import sqrt</code>, <em>not</em> <code>from torch
import sqrt</code>. Now, the Python standard library’s <code>math.sqrt</code> is not a
PyTorch-differentiable function, and so PyTorch is unable to track the flow of
derivatives through <code>sqrt(x)</code>.</p>
<p>As a result of this bug, backpropagation gets “stuck” on the way back, and only
the derivative of <code>x</code>, i.e. 1, is deposited. This is a clog — the gradients
can’t flow! In the computation graph below, the dotted arrow represents the
clog.</p>
<p><img src="static/clogs-and-leaks/clog.png" alt="Clog graph"></p>
<p>The reason calling <code>math.sqrt()</code> on a PyTorch tensor is not a runtime error is
that PyTorch tensors implicitly convert to “raw” floating-point numbers as
needed. Most of the time this is a useful and indispensable feature. But I
believe this situation should <em>at the very least</em> raise an error or a warning.
While the example I presented was reasonably straightforward, there are <em>many</em>
different ways to “clog” backpropagation, with varying degrees of insidiousness
(for example, what happens when you mutate a variable in place?). It can be a
nightmare to debug such situations when something goes wrong — that is, if you
notice the bug in the first place!</p>
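<p>To make the contrast concrete, here is the full program written out twice:
once with the hidden <code>math.sqrt</code> import, and once using
<code>torch.sqrt</code>, which stays on the computation graph. This side-by-side
is my own sketch; the essay above deliberately shows only the buggy version.</p>
<pre><code class="lang-python">import torch
from math import sqrt  # the culprit

x = torch.tensor(4., requires_grad=True)
y = sqrt(x) + x        # math.sqrt silently converts x to a plain Python float
y.backward()
print(x.grad)          # tensor(1.): the sqrt term's contribution is lost

x = torch.tensor(4., requires_grad=True)
y = torch.sqrt(x) + x  # torch.sqrt is tracked by autograd
y.backward()
print(x.grad)          # tensor(1.2500), as the calculus says
</code></pre>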
<p><em>By the way:</em> the celebrated “reparametrization trick” that powers variational
autoencoders is really just a workaround for a gradient clog problem. To train
a variational autoencoder, you need to compute the derivative of a sample of a
probability distribution with respect to the distribution’s parameters (e.g.
the mean $\mu$ and variance $\sigma^2$ of a Gaussian distribution).
Unfortunately, naïvely sampling from a parametrized distribution abruptly
truncates the computation graph with respect to the parameters, because the
random number generator is not differentiable all the way through — who <em>knows</em>
what <em>it’s</em> doing! The solution is to sample from a standard unit normal
distribution (where $\mu=0$ and $\sigma=1$), and then re-scale the sample
by multiplying by $\sigma$ and adding $\mu$. Of course, multiplication and
addition <em>are</em> easily differentiable, and so the gradients can now flow.
Problem solved!</p>
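<p>Here is a minimal sketch of the trick, my own illustration rather than code
from the original essay. The naïve sample uses
<code>torch.distributions.Normal(mu, sigma).sample()</code>, which runs without
gradient tracking; the reparametrized version rebuilds an equivalent sample from
a unit normal using tracked arithmetic.</p>
<pre><code class="lang-python">import torch
from torch.distributions import Normal

mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(2.0, requires_grad=True)

# Naive sampling is a clog: .sample() runs without gradient tracking,
# so the result has no path back to mu or sigma.
clogged = Normal(mu, sigma).sample()
print(clogged.requires_grad)   # False

# Reparametrization: draw eps from N(0, 1), then rescale with tracked ops.
eps = torch.randn(())
sample = mu + sigma * eps
sample.backward()
print(mu.grad, sigma.grad)     # d(sample)/d(mu) = 1, d(sample)/d(sigma) = eps
</code></pre>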
<h2 id="leaks">Leaks</h2>
<p>Now, consider this slightly more complicated PyTorch program. We are going to
implement a silly reinforcement learning algorithm. Here is the situation:
There is a truck driving on the road with constant velocity, and your goal is
to catch up to it and drive right alongside the truck. At each timestep you are
allowed to choose your velocity, and then you’re told how far you are from the
truck.</p>
<p>The setup:</p>
<pre><code class="lang-python">truck_velocity = torch.tensor(3.142)
truck_position = torch.tensor(2.718)
def get_measurement(car_position):
    global truck_position
    truck_position = truck_position + truck_velocity
    return torch.abs(truck_position - car_position)
</code></pre>
<p>And a simple online gradient-based learning algorithm:</p>
<pre><code class="lang-python">my_velocity = torch.tensor(0.01)
my_position = torch.tensor(0.)
for i in range(500):
    my_velocity.requires_grad_()
    my_position = my_position + my_velocity
    loss = get_measurement(my_position)
    loss.backward()
    my_velocity = my_velocity.detach() - my_velocity.grad * 0.01
</code></pre>
<p>Unlike last time, there’s nothing up my sleeve here — this is all reasonable
PyTorch code. This code actually works just fine.</p>
<p>But, if you run it for long enough (say, 1000 iterations), you’ll notice
something odd: each step starts taking longer and longer. The algorithm is
<a href="https://accidentallyquadratic.tumblr.com">accidentally quadratic</a>! You can see
this behavior quite clearly in this graph, which shows a linear growth in
iteration time from step to step (the spikes are garbage collection pauses).</p>
<p><img src="static/clogs-and-leaks/graph.png" alt="Time graph, essentially linear and
increasing"></p>
<p>How can this be? Isn’t each loop doing the same calculation?</p>
<p>Here is one hypothesis: if you’ve read <a href="https://arxiv.org/abs/1909.13371">this
paper</a> you might look to see if we’re
<code>.detach()</code>-ing <code>my_velocity</code>. The <code>.detach()</code> function snips off all incoming
edges to a node in the computation graph, essentially creating an artificial
clog. If we forget to do that, the gradients would “leak” back in time across
multiple steps in the graph, all the way back to the first step, and each
iteration would therefore take longer and longer — just as we’re observing.</p>
<p>But, alas, this is not the source of the bug: as you can see, we <em>are</em>
detaching <code>my_velocity</code> when we update it. So, what’s really going on here?</p>
<p>It’s tricky! The leak is in <code>my_position</code>, which subtly depends on <em>all</em>
previous values of <code>my_velocity</code> and therefore makes backpropagation compute
gradients for <em>all</em> previous timesteps. The dataflow diagram below hopefully
clarifies this point. Notice how each <code>velocity</code> has its parent nodes detached
(thanks to the call to <code>.detach()</code>!), but <code>loss</code> still has an indirect
dependence on the chain of <code>positions</code>.</p>
<p><img src="static/clogs-and-leaks/leak.png" alt="Leak graph"></p>
<p>Finding the correct place to insert the line <code>my_position =
my_position.detach()</code> is left as a not-quite-trivial exercise to the reader.
Beware! Putting it in the <em>wrong</em> place will either have no effect <em>or</em> cause
<code>my_velocity</code> to always have gradient 0.</p>
<p>Just like memory leaks, gradient leaks can be extremely sneaky. They pop up
whenever your inference is “stateful” — think of applications like physics
controllers, reinforcement learning, animated graphics, RNNs, and so on. I
would not be surprised if many popular implementations of such algorithms <em>do</em>
have “gradient leak” bugs. However, the bugs usually only manifest themselves
visibly when the inference passes through enough timesteps for the leak to
compound. Just like a dripping tap, you might not notice your losses until you
get the bill at the end of the month… and then you need to track down the
source of the leak and figure out the right way to fix it.</p>
<h2 id="plungers-and-patches-an-appeal-for-plumbing-">Plungers and patches? An appeal for PLumbing…</h2>
<p>In the long term, how can we protect ourselves from this class of bugs? One
potential solution is to embed the API inside a language whose type system
tracks the creation of the computation graph. You might be able to use
well-understood techniques like <em>taint analysis</em> or <em>linear types</em> (pun not
intended), which traditionally track the flow of <em>information</em>, to now track the
flow of <em>differentiability</em> through the program.</p>
<p>Let me be slightly more concrete about this suggestion. In our “clog” example,
a good type system might detect that <code>sqrt</code> cuts off the computation graph,
and, knowing that $y$ does not directly depend on $x$ in the expected way
anymore, complain at compile-time when we try to request $dy/dx$. In our
“leak” example, a good type system might notice that the “old” <code>my_position</code>
effectively goes out of scope when it is re-assigned, and therefore it might
complain that an unreachable reference to it actually does persist through the
new <code>my_position</code>. Such checks seem very reasonable to demand from a modern
type system.</p>
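<p>We do not need a new language to approximate the first of these checks at
runtime. The sketch below is my own, and it is a runtime check rather than the
compile-time analysis I am appealing for: it uses the public
<code>torch.autograd.grad</code> API to ask whether a value has any
differentiable path back to a given leaf.</p>
<pre><code class="lang-python">import math
import torch

def flows_to(output, leaf):
    """True iff backpropagating from `output` would reach `leaf`."""
    if output.grad_fn is None:   # nothing was recorded at all
        return False
    (g,) = torch.autograd.grad(output, leaf, allow_unused=True, retain_graph=True)
    return g is not None

x = torch.tensor(4., requires_grad=True)
print(flows_to(torch.sqrt(x) + x, x))           # True: gradients can flow
print(flows_to(torch.tensor(math.sqrt(x)), x))  # False: a fully clogged value
</code></pre>
<p>Note that this only catches a <em>complete</em> clog: the
<code>sqrt(x) + x</code> example would still pass, because the <code>+ x</code>
branch reaches <code>x</code> even though the <code>sqrt</code> branch does not.
Distinguishing those cases operation by operation is exactly the kind of
bookkeeping a type system or static analysis could do for us.</p>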
</section>
<div id="comment-breaker">◊ ◊ ◊</div>
</article>
<footer id="footer">
<div>
<ul>
<li><a href="https://github.com/kach">
Github</a></li>
<li><a href="feed.xml">
Subscribe (RSS feed)</a></li>
<li><a href="https://twitter.com/hardmath123">
Twitter</a></li>
<li><a href="https://creativecommons.org/licenses/by-nc/3.0/deed.en_US">
CC BY-NC 3.0</a></li>
</ul>
</div>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','//www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-46120535-1', 'hardmath123.github.io');
ga('require', 'displayfeatures');
ga('send', 'pageview');
</script>
</footer>
</body>
</html>