SGD in LMS Adaptive Filtering vs SGD for Deep Learning

Robert 2023-04-02 (16:29)

Howard, Jeremy and Gugger, Sylvain. Deep Learning for Coders with fastai and PyTorch. First edition, rel. 2021-11-05. Boston: O'Reilly Media, 2020.

Whew! Chapter 4 of the Deep Learning book was a bear. Titled “Under the Hood: Training a Digit Classifier,” the chapter walks through the steps of implementing a linear model–basically a single layer of a neural network–and training it via stochastic gradient descent (SGD).

I’m writing this article to solidify my thoughts on SGD. As it turns out, I’m already familiar with SGD: in old-school signal processing, SGD is basically how adaptive filters are trained via least mean square (LMS) adaptation. I’m going to explore how the two algorithms are related, then look at what role PyTorch plays.

LMS Adaptive Filtering

Acoustic echo cancellation is a typical LMS adaptive filter scenario:

                  +-----------+  x[n]                          x[n]
       +----------| loudspkr  |<-------------+---------------------
       |          +-----------+              |
       |                                     |
       v                                     v
 +---------------+           +-------------------------+
 | room response |           | Adaptive Echo Canceller |
 |     h[n]      |           | estimate w[n]           |<-----+
 +---------------+           +-------------------------+      |
       |                                  |                   |
       |          +--------+              | y^[n]             |
       |          |        |              v                   | feedback
       |          |        |          +------+                |
       +--------->|  mic   |--------->| diff |----------------+---->
                  |        |   y[n]   |      |      e[n]
                  +--------+          +------+

Given a current guess of the channel, \(\mathbf{w}[n]\), and the channel input \(\mathbf{x}[n]\), I compute the corresponding estimate of the channel output:

\[ \begin{equation} \hat{y}[n]= \mathbf{w}[n]^T \mathbf{x}[n] \end{equation} \] where \(\mathbf{w}[n]\) is the vector of filter tap weights and \(\mathbf{x}[n]\) is the vector of the most recent input samples.

In other words, the adaptive echo canceller tries to model a linear channel response \(\mathbf{h}[n]\) with a FIR filter \(\mathbf{w}[n]\). It accomplishes this by guessing \(\mathbf{w}\), examining the difference between the actual and estimated filter output, then updating \(\mathbf{w}\) to nudge it closer to the actual value. The correction process–the so-called adaptation–is SGD.

In LMS, I first need to choose a cost function which specifies how far off my estimated signal \(\hat{y}[n]\) is from the actual \(y[n]\). The usual procedure is to pick a cost function \(J[n]\) which is the squared difference between \(y[n]\) and \(\hat{y}[n]\), e.g. \[ \begin{equation} J[n]= |e[n]|^2 \end{equation} \] where

\[ \begin{equation} e[n] = y[n] - \hat{y}[n] \end{equation} \] Finally, I can nudge \(\mathbf{w}[n]\) towards the optimum answer by walking the vector in the opposite direction of the gradient of \(J[n]\):

\[ \begin{equation} \mathbf{w}[n+1] = \mathbf{w}[n] - \mu \nabla J[n] \end{equation} \] where \(\mu\) is a small scalar step size that controls the adaptation rate.

In most explanations, \(J[n]\) is expressed as the average squared error over time, and then the authors assert that they’re looking at an instantaneous result, so the average gets tossed. I’ll just dispense with the averaging right at the start.

The tricky bit is computing the gradient of \(J[n]\):

\[ \begin{equation} \nabla J[n]= \dfrac{\partial J[n]}{\partial \mathbf{w}} \end{equation} \]

\(J[n]\) is a “composition” of functions: \(J[n]\) is a function of \(e[n]\), and \(e[n]\) is a function of \(\mathbf{w}[n]\). Thus I need to use the chain rule \[ \frac{ df(g(x))}{dx} = f'(g(x))\, g'(x) \]

Applying the chain rule to \(J[n]\) (and dropping the time index n for now): \[ \begin{align} \nabla J &= \dfrac{\partial J}{\partial \mathbf{w}} \\ &= \dfrac{\partial |e|^2}{\partial \mathbf{w}} \\ &= 2e \, \dfrac{\partial e}{\partial \mathbf{w}} \end{align} \]

Plugging in the definition of \(e[n]\) yields: \[ \begin{align} \dfrac{\partial J}{\partial \mathbf{w}} &= 2e \, \dfrac{\partial (y - \hat{y})}{\partial \mathbf{w}} \\ &= 2e \, \dfrac{\partial (y - \mathbf{w}^T\mathbf{x})}{\partial \mathbf{w}} \\ &= 2e\, (0 - \mathbf{x}) \\ &= 2(y - \mathbf{w}^T\mathbf{x})(-\mathbf{x}) \\ &= -2(y - \mathbf{w}^T\mathbf{x})\, \mathbf{x} \\ &= -2(y - \hat{y})\, \mathbf{x} \\ &= -2e\mathbf{x} \\ \end{align} \]

Plugging this back into my iterative function, I finally get my familiar friend, the LMS update equation: \[ \boxed{ \begin{align} \mathbf{w}[n+1] &= \mathbf{w}[n] - \mu \nabla J[n] \\ &= \mathbf{w}[n] + 2 \mu\, e[n]\, \mathbf{x}[n] \\ \end{align} } \]
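To make the update concrete, here is a minimal NumPy sketch of the LMS loop (the function and variable names are my own, not from any particular library):

```python
import numpy as np

def lms_filter(x, y, num_taps=32, mu=0.01):
    """Sketch of an LMS adaptive filter.

    x: reference signal driving the loudspeaker
    y: microphone signal containing the echo to cancel
    Returns the error signal e and the final coefficient estimate w.
    """
    w = np.zeros(num_taps)                     # initial guess of the channel
    e = np.zeros(len(x))
    for n in range(num_taps - 1, len(x)):
        x_n = x[n - num_taps + 1:n + 1][::-1]  # [x[n], x[n-1], ..., x[n-M+1]]
        y_hat = w @ x_n                        # y^[n] = w[n]^T x[n]
        e[n] = y[n] - y_hat                    # e[n] = y[n] - y^[n]
        w = w + 2 * mu * e[n] * x_n            # w[n+1] = w[n] + 2 mu e[n] x[n]
    return e, w
```

For a quick sanity check, synthesize y by filtering x through a known FIR channel and verify that w converges toward those known taps.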

SGD with a Second Order Polynomial

In chapter 4, Jeremy and Sylvain walk through the example of using SGD to estimate the coefficients of a 2nd order polynomial (representing the velocity over time of a roller coaster over a hump): \[ v[t] = v_0t^2 + v_1t + v_2 \] where \(v_0\), \(v_1\), and \(v_2\) are the coefficients to be estimated.

I’m going to rewrite \(v[t]\) in matrix notation: \[ v[t] = v^T x \] where \(v \equiv [v_0 \quad v_1 \quad v_2]^T\) and \(x \equiv [t^2 \quad t \quad 1]^T\).

The goal is to iteratively estimate a vector \(\hat{v}\) such that \[ \hat{v}[t] \approx v[t] \]

As in the LMS algorithm, I choose a cost function which tells me how far away the estimate is from the actual (\(\hat{v}[t] - v[t]\))–except now, because we’re in Deep Learning Land, I call it a “loss” function \(L[t]\): \[ L[t] = E[(v[t] - \hat{v}[t])^2] \] The \(E[\quad]\) denotes that we’re looking at the mean squared difference over a range of t values. This is important because I’ll apply a single update only after running N points of data (a “mini-batch”) through the model; the \(E[\quad]\) operator boils multiple results down to a single result.
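To make the \(E[\quad]\) bookkeeping concrete, here is a minimal NumPy sketch of the mini-batch loss (the function and variable names are my own):

```python
import numpy as np

def mse_loss(v_hat, ts, vs):
    """Mean squared error of the polynomial model over one mini-batch.

    v_hat: current coefficient estimate [v0, v1, v2]
    ts:    the N time points t_0 ... t_{N-1} in the mini-batch
    vs:    the measured velocities v(t_0) ... v(t_{N-1})
    """
    X = np.stack([ts**2, ts, np.ones_like(ts)], axis=1)  # rows are x_i = [t_i^2, t_i, 1]
    return np.mean((vs - X @ v_hat) ** 2)                # L = E[(v - v^)^2]
```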

Just like before, I will iterate the estimate toward the correct value by stepping it in the opposite direction of the gradient of L: \[ \hat{v}[t+1] = \hat{v}[t] - \mu \nabla L[t] \] where \(\mu\) is the scalar learning rate.

Sharpening my pencil and getting out the really big eraser, I again apply the chain rule and work out \(\nabla L[t]\) – dropping the time index t for now: \[ \begin{align} \nabla L &= \dfrac{\partial \, E[(v - \hat{v}^Tx)^2]}{\partial \hat{v}} \\ &= E\left[ \dfrac{\partial (v - \hat{v}^Tx)^2}{\partial \hat{v}} \right] && \text{E[] is a linear op, so exchange it with the derivative}\\ &= E\left[ 2(v - \hat{v}^Tx)\, \dfrac{\partial (v - \hat{v}^Tx)}{\partial \hat{v}} \right] && \text{chain rule} \\ &= E[2(v - \hat{v}^Tx)(0-x)] \\ &= -2\, E[(v - \hat{v}^Tx)\, x ] && \text{note that the term in parens is scalar} \end{align} \] Now I expand the \(E[\quad]\) operator by assuming a sample size of N and reintroduce the time indices \(t_0...t_{N-1}\): \[ \begin{align} \nabla L &= -2\, E[(v - \hat{v}^Tx)\, x ] \\ &= \dfrac{-2}{N} \left[ (v(t_0) - \hat{v}^Tx_0)\,x_0 + (v(t_1) - \hat{v}^Tx_1)\,x_1 + \ldots \right] \\ \end{align} \]

Let \(\epsilon(t_i) \equiv v(t_i) - \hat{v}^Tx_i\), which is a scalar value. Then I get:

\[ \nabla L = \dfrac{-2}{N} \left[ \epsilon(t_0)x_0 + \epsilon(t_1)x_1 + \ldots \right] \]

Plugging this definition into the iterative function, I get: \[ \boxed{ \begin{align} \hat{v}[t+1] &= \hat{v}[t] - \mu \nabla L[t] \\ &= \hat{v}[t] + \mu \dfrac{2}{N} \left[ \epsilon(t_0)x_0 + \epsilon(t_1)x_1 + \ldots \right] \end{align} } \] In other words, the update term is the vector \(x_i \equiv [t_i^2 \quad t_i \quad 1]^T\) scaled by the difference \(\epsilon(t_i)\), then averaged over the N points in the mini-batch.

This result looks a great deal like my previous result for the LMS gradient, \(\dfrac{\partial J}{\partial \mathbf{w}} = -2e\mathbf{x}\), only averaged over the N points.
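Here is a minimal NumPy sketch of one such mini-batch update, using the gradient just derived (again, the names are my own invention):

```python
import numpy as np

def sgd_poly_step(v_hat, ts, vs, mu=0.01):
    """One mini-batch SGD step for the 2nd-order polynomial model.

    v_hat: current coefficient estimate [v0, v1, v2]
    ts:    the N time points t_0 ... t_{N-1} in the mini-batch
    vs:    the measured velocities v(t_0) ... v(t_{N-1})
    """
    X = np.stack([ts**2, ts, np.ones_like(ts)], axis=1)  # rows are x_i = [t_i^2, t_i, 1]
    eps = vs - X @ v_hat                                 # eps(t_i) = v(t_i) - v_hat^T x_i
    grad_L = (-2.0 / len(ts)) * (X.T @ eps)              # the hand-derived gradient of L
    return v_hat - mu * grad_L                           # step opposite the gradient
```

Calling this in a loop with fresh mini-batches walks v_hat toward the true coefficients.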

The Role of PyTorch in Gradient Computation

The tricky bit in the above example was calculating the gradient of the loss function, \(\nabla L\). I can imagine that for different applications, the loss function will differ. I don’t want to work out the gradient equation by hand for each case!

Fortunately, PyTorch lets me define tensors which automatically calculate their own gradients using Tensor.requires_grad_(). This “autograd” feature saves me lots of effort! The strange part, at least for me, is the mechanics of the gradient computation: if tensor y contains the result of a string of calculations based on tensor x, e.g. y=myfunc(x), then I get the gradient by calling y.backward() and the results appear in x.grad. It’s weird: I call a method on one object, and the results are magically glommed onto another object. I see the value, however!
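Here is a minimal sketch of that workflow, refitting the roller-coaster polynomial with autograd doing the differentiation (the data and coefficients below are made up for illustration; this is not the book’s exact code):

```python
import torch

# Synthetic "measurements" of velocity over time (made-up coefficients).
ts = torch.linspace(0, 3, 20)
vs = 0.75 * ts**2 - 2.0 * ts + 1.0 + 0.05 * torch.randn(20)

v_hat = torch.randn(3).requires_grad_()  # [v0, v1, v2], tracked by autograd
mu = 0.01

for _ in range(2000):
    preds = v_hat[0] * ts**2 + v_hat[1] * ts + v_hat[2]
    loss = ((vs - preds) ** 2).mean()    # L = E[(v - v^)^2]
    loss.backward()                      # gradient of loss lands in v_hat.grad
    with torch.no_grad():
        v_hat -= mu * v_hat.grad         # step opposite the gradient
        v_hat.grad.zero_()               # clear the accumulated gradient

print(v_hat)  # should approach [0.75, -2.0, 1.0]
```

Note that loss.backward() is called on the result, but the gradient shows up on v_hat – exactly the “glomming” described above.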

URL to the buttersquid diagram:

https://buttersquid.ink/?spore=bNobwRALmBcYDYEsB2BTMAaMAPAjDALAJyZYBMMAbAAyYCee0xYt50pNYKAhjDpigCMYHOAHssBTGNoxSmCHBjg4UWADMUKACYCuAYwDWGeADdeAX3PpwqsACcUeqCVl1efMAHdeAVkwALSTA1PRgwAAJjNTVhKXEYAFoKOJloBI8FJXhbO1FRAFtwhwBnAAdRJGKUAB0kYzgzaBxLa0gwxFRjXBgAZj9sVgo5ZgZk5kH+Hib+IWgReMYUxOHM6GVbLGAkAF16xuarGzCHJy7ZYdT8TG8m4cDoK+DQ2EjMaNj4BYSepbn5RTW2XaogArloygY7HsLIc2rAOmgSAwcB4yLIPPRKHQJpwphxBB8xBImv1pB9VuswtUQVRaXgpI0qC0jvDkIjsMjUawUXRRtjePi8TNeHFiThSaJUhwKUDYNTaVR6aZhMy4fA2V1OSRuRjkWMWLJBR8CX9PmKJVL-lkVFSaXToXNVbYEZr3NreExMWwOAbvZNjbMPETeAAOX4ZAGU2C0La7Bkw1o5RzObDotxNADs1xgTHu-RCYVewRipuDaWGZNNMptsHyCFC8aaTvaGqRvWGaOgPU9DHY-L9uIDIrNvl+0sjsrA8vtjaZsOdrY5BEenfwjy9+u5RtNJqDC3FY6tgJrU7tiodc8Tx2TZy7PrdXlkPzA92fBZeURL82J6T3qXSR5RmAACCWhcKUEAICYKDhAAono-iiOEADCXBIHoKBwHAKB2LUKDFJB+RcBAMGeLGDoHFesAnCmxL4PeTRMDcYyvm8zwRJ+hJfH+iQRtathaAg0S1BRzasp0bYPCurBrryo7jKG-o7oGoouJ8lqQBOJ60AAeuRjaUSy6oSUuDxhgMlA9oa-Z9oOynDmWCQWsO1a2Cg+nKk285hBAAjxC4Xb9Kkz4QFgGweY0ahcHAVTmNsQA