Neural Proximal Operators


The idea developed below, that a denoiser can stand in for the prior inside an optimization algorithm, is the conceptual leap behind the Plug-and-Play (PnP) and Regularization by Denoising (RED) frameworks, which changed how general noisy inverse problems are solved in practice.

To understand why a generative model can replace a hand-derived step of an optimization algorithm, we have to look at the \(z\)-update through a Bayesian lens. It all hinges on the mathematical definition of a Proximal Operator.

1. The \(z\)-update is a Proximal Operator

Let's look at the standard (scaled-form) ADMM \(z\)-update step again. To make it cleaner, let's collect the quantities that are frozen during this step into a single point, \(v = x^{k+1} + u^k\).

Assuming \(A = I\), \(B = -I\), and \(c = 0\) (meaning our constraint is \(x = z\)), the update is:

\[z^{k+1} = \underset{z}{\text{argmin}} \left( g(z) + \frac{\rho}{2}\|z - v\|_2^2 \right)\]
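For reference, the usual scaled-form ADMM \(z\)-update for a general constraint \(Ax + Bz = c\) is

\[z^{k+1} = \underset{z}{\text{argmin}} \left( g(z) + \frac{\rho}{2}\|Ax^{k+1} + Bz - c + u^k\|_2^2 \right)\]

and substituting \(A = I\), \(B = -I\), \(c = 0\) turns the penalty into a squared distance to \(v\):

\[\|x^{k+1} - z + u^k\|_2^2 = \|z - (x^{k+1} + u^k)\|_2^2 = \|z - v\|_2^2\]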

In optimization, this specific structure (minimizing a function plus a scaled squared distance to a point \(v\)) is called the Proximal Operator of the function \(g\), scaled by \(\frac{1}{\rho}\):

\[z^{k+1} = \text{prox}_{g/\rho}(v)\]
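To make the definition concrete, here is a minimal NumPy sketch that evaluates \(\text{prox}_{g/\rho}(v)\) for a scalar \(v\) by brute-force grid search. It assumes \(g(z) = \lambda|z|\) purely as an example regularizer; for that choice the proximal operator has the well-known closed form of soft-thresholding with threshold \(\lambda/\rho\), so the two can be compared directly. The helper names `prox_numeric` and `soft_threshold` are illustrative, not from any library.

```python
import numpy as np

def prox_numeric(g, v, rho, grid):
    """Evaluate prox_{g/rho}(v) for a scalar v by brute-force search over `grid`:
    argmin_z  g(z) + (rho / 2) * (z - v)**2."""
    objective = g(grid) + 0.5 * rho * (grid - v) ** 2
    return grid[np.argmin(objective)]

def soft_threshold(v, t):
    """Closed-form prox of t * |z|: shrink v toward 0 by t, stopping at 0."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

lam, rho = 0.8, 2.0                       # assumed example: g(z) = lam * |z|
g = lambda z: lam * np.abs(z)
grid = np.linspace(-5.0, 5.0, 200001)     # candidate z values, spacing 5e-5

for v in [-2.0, -0.3, 0.0, 0.25, 1.5]:
    numeric = prox_numeric(g, v, rho, grid)
    closed = soft_threshold(v, lam / rho)  # threshold lam / rho = 0.4
    print(f"v={v:+.2f}  grid search={numeric:+.4f}  soft-threshold={closed:+.4f}")
```

Grid search only works in one or two dimensions; it is used here only to show what the argmin computes, not as a practical solver.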

2. The Bayesian Translation: MAP Denoising

Now, let's look at that exact same equation from the perspective of Bayesian probability and Maximum A Posteriori (MAP) estimation.

Imagine I hand you a noisy image \(v\), and I tell you it was created by taking a clean image \(z\) and adding Additive White Gaussian Noise (AWGN) with a variance of \(\sigma^2\). I ask you to find the most likely clean image \(z\).

Using Bayes' theorem, \(p(z|v) \propto p(v|z)\,p(z)\), so you want to maximize the product of the likelihood and the prior. Taking the negative logarithm turns this into a minimization problem with two terms (the normalizer \(p(v)\) does not depend on \(z\) and drops out):

  1. The Likelihood (\(-\log p(v|z)\)): For Gaussian noise, this equals \(\frac{1}{2\sigma^2}\|z - v\|_2^2\) plus a constant that does not depend on \(z\).
  2. The Prior (\(-\log p(z)\)): The negative logarithm of \(p(z)\), the probability distribution of clean images.

So, the MAP denoising problem is mathematically written as:

\[z_{clean} = \underset{z}{\text{argmin}} \left( -\log p(z) + \frac{1}{2\sigma^2}\|z - v\|_2^2 \right)\]
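As a worked example, assume (purely for illustration) a zero-mean Gaussian prior \(p(z) = \mathcal{N}(0, \tau^2 I)\), so that \(-\log p(z) = \frac{1}{2\tau^2}\|z\|_2^2 + \text{const}\). Setting the gradient of the MAP objective to zero gives the denoiser in closed form:

\[\frac{1}{\tau^2} z + \frac{1}{\sigma^2}(z - v) = 0 \quad\Longrightarrow\quad z_{clean} = \frac{\tau^2}{\tau^2 + \sigma^2}\, v\]

This is linear shrinkage toward zero (a Wiener filter): the larger the noise variance \(\sigma^2\) relative to the prior variance \(\tau^2\), the harder \(v\) is pulled toward the prior mean.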

3. The Leap: Math Equals Denoising

Look closely at the ADMM \(z\)-update and the MAP Denoising equation. They have exactly the same form under the following identification:

  • \(g(z)\) plays the role of \(-\log p(z)\) (the negative log-prior, i.e. the model of what clean data looks like).
  • \(\rho\) plays the role of \(\frac{1}{\sigma^2}\) (the inverse of the noise variance).

This means that evaluating the proximal operator \(\text{prox}_{g/\rho}(v)\) is literally the same computation as running a MAP Gaussian denoiser with noise standard deviation \(\sigma = 1/\sqrt{\rho}\) on the image \(v\).
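The correspondence can be checked numerically. The NumPy sketch below evaluates the same scalar minimization twice, once written as a proximal operator with weight \(\rho\) and once as a MAP denoiser with noise level \(\sigma = 1/\sqrt{\rho}\). The heavy-tailed prior \(-\log p(z) = \log(1 + z^2/s^2)\) is an assumption, chosen only because its proximal operator has no simple closed form; `prox_scalar` and `map_denoise_scalar` are hypothetical names.

```python
import numpy as np

def prox_scalar(v, g, rho, grid):
    """prox_{g/rho}(v) by grid search: argmin_z g(z) + (rho / 2) * (z - v)**2."""
    return grid[np.argmin(g(grid) + 0.5 * rho * (grid - v) ** 2)]

def map_denoise_scalar(v, neg_log_prior, sigma, grid):
    """MAP Gaussian denoiser by grid search:
    argmin_z -log p(z) + (z - v)**2 / (2 * sigma**2)."""
    return grid[np.argmin(neg_log_prior(grid) + (grid - v) ** 2 / (2.0 * sigma**2))]

s = 0.3                                              # assumed prior scale
neg_log_prior = lambda z: np.log1p((z / s) ** 2)     # heavy-tailed prior, up to a constant
grid = np.linspace(-4.0, 4.0, 80001)                 # candidate z values, spacing 1e-4

rho = 4.0
sigma = 1.0 / np.sqrt(rho)                           # the mapping rho = 1 / sigma^2

results = []
for v in [-2.5, -0.4, 0.1, 1.2, 3.0]:
    z_prox = prox_scalar(v, neg_log_prior, rho, grid)
    z_map = map_denoise_scalar(v, neg_log_prior, sigma, grid)
    results.append((v, z_prox, z_map))
    print(f"v={v:+.2f}  prox={z_prox:+.4f}  MAP denoiser={z_map:+.4f}")
```

Both columns agree because, with \(\rho = 1/\sigma^2\), the two objectives are the same function of \(z\).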

For decades, researchers had to hand-design analytical formulas for \(g(z)\) whose proximal operators they could evaluate or approximate, such as Total Variation (which favors piecewise-constant images) or L1-wavelet sparsity. But these hand-crafted priors describe the complex, highly non-linear manifold of natural images only crudely.

4. Enter Generative Models

If the \(z\)-update is just a "denoise this image" step, we don't need an analytical math formula for \(g(z)\) at all. We can implicitly define the prior by using whatever state-of-the-art tool is best at mapping a noisy vector back to the clean data manifold.

This is where modern generative models fit perfectly:

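  • Diffusion Models & Score Matching: By Tweedie's Formula, the minimum-mean-squared-error denoiser at noise level \(\sigma\) satisfies \(\mathbb{E}[z|v] = v + \sigma^2 \nabla_v \log p_\sigma(v)\), where \(p_\sigma\) is the density of clean images convolved with Gaussian noise of variance \(\sigma^2\). Diffusion models are explicitly trained to estimate this score at many noise levels, so the model's denoised estimate at the noise level matching \(\sigma = 1/\sqrt{\rho}\) can be used as the \(z\)-update, giving a far stronger proximal step than hand-crafted priors. (Strictly, this is a posterior-mean rather than a MAP denoiser, so it approximates the proximal operator rather than evaluating it exactly.)
  • Flow Matching: Flow matching models learn the velocity field of a continuous normalizing flow that transports probability mass from a simple base distribution (like Gaussian noise) to the data distribution. For the common straight-line Gaussian probability paths, an estimate of the clean sample can be read off the learned velocity field, so the same model can serve as a data-driven denoiser inside the \(z\)-update.
  • VAEs: A VAE encodes the noisy input \(v\) into a learned, constrained latent space and decodes it, pushing the output \(z\) toward the manifold of "realistic" data; it acts as an approximate, implicit projection operator onto that manifold.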
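Tweedie's formula can be sanity-checked in the one case where everything is closed form. The pure-Python sketch below assumes a 1-D Gaussian prior \(z \sim \mathcal{N}(0, \tau^2)\) (an illustrative choice, not a model of images), so the noisy marginal \(p_\sigma\) is \(\mathcal{N}(0, \tau^2 + \sigma^2)\). It estimates the score by a central finite difference, plugs it into \(v + \sigma^2 \nabla_v \log p_\sigma(v)\), and compares the result with the exact posterior mean. For a Gaussian prior the posterior mean and the MAP estimate coincide; for real image priors they generally do not.

```python
import math

tau, sigma = 1.5, 0.8   # assumed prior std and noise std (illustrative values)

def log_noisy_marginal(v):
    """log p_sigma(v) for z ~ N(0, tau^2), v = z + N(0, sigma^2):
    the noisy marginal is N(0, tau^2 + sigma^2)."""
    var = tau**2 + sigma**2
    return -0.5 * math.log(2 * math.pi * var) - v**2 / (2 * var)

def score(v, h=1e-5):
    """Numerical score d/dv log p_sigma(v) via a central difference."""
    return (log_noisy_marginal(v + h) - log_noisy_marginal(v - h)) / (2 * h)

def tweedie_denoise(v):
    """Tweedie's formula: E[z | v] = v + sigma^2 * score(v)."""
    return v + sigma**2 * score(v)

def posterior_mean(v):
    """Closed-form Gaussian posterior mean, for comparison."""
    return tau**2 / (tau**2 + sigma**2) * v

for v in [-2.0, 0.5, 3.0]:
    print(f"v={v:+.1f}  Tweedie={tweedie_denoise(v):+.6f}  exact={posterior_mean(v):+.6f}")
```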

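In short: You don't plug a denoiser in because it's a heuristic hack. You plug it in because the ADMM algorithm mathematically *asks* for a denoiser, and generative neural networks are among the strongest denoisers available. (A learned denoiser is generally not the exact proximal operator of any explicit \(g\), so convergence guarantees require extra conditions on the network.)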
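The argument can be sketched as a Plug-and-Play ADMM loop in which the \(z\)-update is simply a call to a denoiser. The NumPy sketch below assumes a toy 1-D inpainting problem (data term \(f(x) = \frac{1}{2}\|Mx - y\|_2^2\) with a 0/1 sampling mask \(M\)) and uses a fixed \([1, 2, 1]/4\) smoothing filter as a hypothetical stand-in for a neural denoiser; `pnp_admm` and `smoothing_denoiser` are illustrative names, not a library API.

```python
import numpy as np

def smoothing_denoiser(v):
    """Hypothetical stand-in for a learned denoiser: a circular [1, 2, 1] / 4
    smoothing filter. Any callable mapping a noisy vector to a cleaner one
    could be dropped in here instead."""
    return 0.25 * np.roll(v, 1) + 0.5 * v + 0.25 * np.roll(v, -1)

def pnp_admm(y, mask, denoise, rho=1.0, n_iter=100):
    """Plug-and-Play ADMM for inpainting with f(x) = 0.5 * ||mask * x - y||^2.
    The z-update prox_{g/rho} is replaced by the `denoise` callable."""
    z = y.copy()
    u = np.zeros_like(y)
    for _ in range(n_iter):
        # x-update: closed-form minimizer of f(x) + (rho/2)||x - z + u||^2;
        # the mask is diagonal, so this is an elementwise division.
        x = (mask * y + rho * (z - u)) / (mask + rho)
        # z-update: the proximal step, delegated to the denoiser.
        z = denoise(x + u)
        # Scaled dual update for the constraint x = z.
        u = u + x - z
    return z

# Toy problem: a smooth periodic signal with every other sample missing.
n = 64
t = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
x_true = np.sin(t)
mask = np.zeros(n)
mask[::2] = 1.0                 # observe even-indexed samples only
y = mask * x_true               # missing samples are zero-filled

x_hat = pnp_admm(y, mask, smoothing_denoiser)
err_zero_filled = np.linalg.norm(y - x_true)
err_pnp = np.linalg.norm(x_hat - x_true)
print(f"zero-filled error: {err_zero_filled:.3f}, PnP-ADMM error: {err_pnp:.3f}")
```

Swapping `smoothing_denoiser` for a trained network changes nothing in the loop itself, which is the point of the PnP construction. This particular filter is symmetric with eigenvalues in \([0, 1]\), so it is the exact proximal operator of a convex quadratic and the iteration converges; a neural denoiser offers no such guarantee by default.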


Proof: Why \(-\log p(v|z) = \frac{1}{2\sigma^2} \|z - v\|_2^2 + \text{const}\)

This relationship is one of the most fundamental connections in machine learning and optimization. It shows that minimizing the squared error (the squared L2 norm of the residual) is mathematically identical to maximum-likelihood estimation under the assumption that the noise is i.i.d. Gaussian.

Here is the step-by-step derivation of exactly how the multivariate Gaussian probability density function (PDF) turns into that simple quadratic penalty.

1. The Forward Model

Assume we have a clean signal \(z\) (e.g., a vector of \(N\) pixels). We measure a noisy version of it, \(v\). We assume the noise added to the signal is Additive White Gaussian Noise (AWGN).

We can write this as:

\[v = z + \eta\]

where the noise vector \(\eta\) is drawn from a normal distribution with mean \(0\) and variance \(\sigma^2\), independently for every pixel. Equivalently, its covariance matrix is \(\Sigma = \sigma^2 I\) (where \(I\) is the identity matrix).

Because \(v\) is just \(z\) shifted by this noise, the distribution of \(v\) given a specific \(z\) is also Gaussian, centered at \(z\):

\[v \mid z \sim \mathcal{N}(z, \sigma^2 I)\]
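The forward model is easy to simulate. This NumPy sketch (the signal length, the uniform "clean signal", and \(\sigma\) are all illustrative assumptions) draws AWGN, forms \(v = z + \eta\), and confirms that the residual \(v - z\) has mean close to \(0\) and variance close to \(\sigma^2\).

```python
import numpy as np

rng = np.random.default_rng(0)
n_pixels, sigma = 100_000, 0.3                 # illustrative sizes, not from the text

z = rng.uniform(0.0, 1.0, size=n_pixels)       # stand-in "clean signal"
eta = rng.normal(0.0, sigma, size=n_pixels)    # AWGN: mean 0, variance sigma^2 per pixel
v = z + eta                                    # forward model v = z + eta

residual = v - z
print(f"mean {residual.mean():+.4f}, variance {residual.var():.4f}, sigma^2 = {sigma**2:.4f}")
```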

2. The Multivariate Gaussian PDF

The standard formula for the Probability Density Function of an \(N\)-dimensional multivariate Gaussian is:

\[p(v|z) = \frac{1}{(2\pi)^{N/2} |\Sigma|^{1/2}} \exp\left( -\frac{1}{2} (v - z)^T \Sigma^{-1} (v - z) \right)\]

Let's plug in our specific covariance matrix, \(\Sigma = \sigma^2 I\):

  • The determinant \(|\Sigma|\) becomes \((\sigma^2)^N\).
  • The inverse \(\Sigma^{-1}\) becomes \(\frac{1}{\sigma^2} I\).

Substituting these into the PDF:

\[p(v|z) = \frac{1}{(2\pi \sigma^2)^{N/2}} \exp\left( -\frac{1}{2\sigma^2} (v - z)^T (v - z) \right)\]

In linear algebra, the dot product of a vector with itself, \((v - z)^T (v - z)\), is the exact definition of the squared L2 norm, \(\|v - z\|_2^2\). (And since we are squaring it, \(\|v - z\|_2^2\) is identical to \(\|z - v\|_2^2\)).

So, the PDF simplifies to:

\[p(v|z) = \frac{1}{(2\pi \sigma^2)^{N/2}} \exp\left( -\frac{1}{2\sigma^2} \|z - v\|_2^2 \right)\]
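The substitution \(\Sigma = \sigma^2 I\) can be checked numerically. The sketch below evaluates the general multivariate formula (with an explicit determinant and inverse) and the simplified isotropic one on the same small vectors; the vectors, \(\sigma\), and the function names are arbitrary illustrative choices.

```python
import numpy as np

def gaussian_pdf_general(v, z, cov):
    """Multivariate normal density N(v; z, cov) from the general formula."""
    n = v.size
    diff = v - z
    norm_const = 1.0 / ((2.0 * np.pi) ** (n / 2) * np.sqrt(np.linalg.det(cov)))
    return norm_const * np.exp(-0.5 * diff @ np.linalg.inv(cov) @ diff)

def gaussian_pdf_isotropic(v, z, sigma):
    """The same density after substituting cov = sigma^2 * I."""
    n = v.size
    sq_norm = np.sum((z - v) ** 2)             # ||z - v||_2^2
    return np.exp(-sq_norm / (2.0 * sigma**2)) / (2.0 * np.pi * sigma**2) ** (n / 2)

sigma = 0.7
z = np.array([0.2, -1.0, 0.5, 1.3])
v = np.array([0.4, -0.6, 0.1, 1.0])
p_general = gaussian_pdf_general(v, z, sigma**2 * np.eye(z.size))
p_iso = gaussian_pdf_isotropic(v, z, sigma)
print(p_general, p_iso)
```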

3. Taking the Negative Logarithm

In Maximum A Posteriori (MAP) estimation, we want to maximize the posterior, and this likelihood is one of its factors. However, computers hate multiplying tiny probabilities together (it causes numerical underflow), and exponentials are awkward to optimize directly; the logarithm turns products into sums and undoes the exponential.

Because the logarithm is a strictly increasing function, finding the maximum of \(\log(f(x))\) yields the exact same \(x\) as finding the maximum of \(f(x)\). To turn it into a *minimization* problem (which algorithms like ADMM require), we take the negative natural logarithm (\(-\log\)).
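The underflow problem is not hypothetical. In this NumPy sketch (a 2000-pixel signal with \(\sigma = 1\), both illustrative assumptions), the PDF evaluated directly rounds to exactly 0.0 in double precision, while the negative log-likelihood is an ordinary finite number.

```python
import numpy as np

rng = np.random.default_rng(1)
n, sigma = 2000, 1.0                       # illustrative: 2000 pixels, unit noise
z = np.zeros(n)
v = z + rng.normal(0.0, sigma, size=n)
sq_norm = np.sum((z - v) ** 2)             # ||z - v||_2^2, about n * sigma^2

# Direct evaluation: (2*pi)^(-1000) and exp(-~1000) both underflow to 0.0.
p_direct = np.power(2.0 * np.pi * sigma**2, -n / 2) * np.exp(-sq_norm / (2.0 * sigma**2))

# Log domain: the same quantity expressed as a sum of two moderate numbers.
neg_log_p = 0.5 * n * np.log(2.0 * np.pi * sigma**2) + sq_norm / (2.0 * sigma**2)

print("direct PDF:", p_direct)             # 0.0
print("-log PDF:  ", neg_log_p)            # finite (a few thousand)
```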

Let's apply \(-\log\) to our PDF:

\[-\log p(v|z) = -\log \left( \frac{1}{(2\pi \sigma^2)^{N/2}} \exp\left( -\frac{1}{2\sigma^2} \|z - v\|_2^2 \right) \right)\]

Using the logarithm rule \(\log(a \cdot b) = \log(a) + \log(b)\), we can split the equation into two parts: the normalization constant and the exponential term.

\[-\log p(v|z) = -\log \left( \frac{1}{(2\pi \sigma^2)^{N/2}} \right) - \log \left( \exp\left( -\frac{1}{2\sigma^2} \|z - v\|_2^2 \right) \right)\]

Because \(\log(\exp(x)) = x\), the exponential function perfectly cancels out:

\[-\log p(v|z) = \frac{N}{2}\log(2\pi \sigma^2) + \frac{1}{2\sigma^2} \|z - v\|_2^2\]
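As a check on the algebra, this sketch computes \(-\log p(v|z)\) as the sum of the constant and quadratic terms above and compares it with \(-\log\) of the directly evaluated PDF, on a small example where nothing underflows. The helper `neg_log_likelihood` and the numbers are hypothetical.

```python
import numpy as np

def neg_log_likelihood(v, z, sigma):
    """-log p(v|z) under AWGN, written as constant term + quadratic term."""
    n = v.size
    constant = 0.5 * n * np.log(2.0 * np.pi * sigma**2)     # contains no z
    quadratic = np.sum((z - v) ** 2) / (2.0 * sigma**2)     # the only z-dependence
    return constant + quadratic

sigma = 0.5
z = np.array([1.0, 0.0, -0.5])
v = np.array([1.2, -0.1, -0.4])

# Direct PDF (safe here: only 3 dimensions, so nothing underflows).
pdf = np.exp(-np.sum((z - v) ** 2) / (2.0 * sigma**2)) / (2.0 * np.pi * sigma**2) ** (z.size / 2)
print(neg_log_likelihood(v, z, sigma), -np.log(pdf))
```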

4. Dropping the Constants

Remember the goal: we are going to use this expression inside an \(\text{argmin}\) over \(z\), as part of the MAP denoising objective:

\[z_{clean} = \underset{z}{\text{argmin}} \left( -\log p(z) - \log p(v|z) \right)\]

Look at the first term of our result: \(\frac{N}{2}\log(2\pi \sigma^2)\). There is no \(z\) in this term. It is entirely composed of constants (\(N\), \(\pi\), \(\sigma\)).

When you are trying to find the minimum of a function with respect to \(z\), adding or subtracting a constant shifts the whole curve up or down, but it *does not change the horizontal coordinate where the bottom of the bowl is located*. Therefore, the optimization algorithm can completely ignore it.

When we drop the constant term, what remains equals \(-\log p(v|z)\) up to an additive constant, which is all the \(\text{argmin}\) cares about:

\[-\log p(v|z) = \frac{1}{2\sigma^2} \|z - v\|_2^2 + \text{const}\]
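A small numerical check that the constant really is irrelevant: the sketch below minimizes a scalar MAP objective over a grid with and without the \(\frac{N}{2}\log(2\pi\sigma^2)\) term (here \(N = 1\)) and finds the same minimizer. The prior \(-\log p(z) = |z|\) and all numbers are illustrative assumptions; for this prior the exact minimizer is \(v - \sigma^2\) whenever \(v > \sigma^2\).

```python
import numpy as np

sigma, v = 0.6, 1.4                        # illustrative noise level and observation
n = 1                                      # a single scalar "pixel"
grid = np.linspace(-3.0, 3.0, 60001)       # candidate values of z, spacing 1e-4
neg_log_prior = np.abs(grid)               # assumed prior: -log p(z) = |z| + const

quadratic = (grid - v) ** 2 / (2.0 * sigma**2)
constant = 0.5 * n * np.log(2.0 * np.pi * sigma**2)

z_with_const = grid[np.argmin(neg_log_prior + quadratic + constant)]
z_without_const = grid[np.argmin(neg_log_prior + quadratic)]
print(z_with_const, z_without_const)       # same grid point, near v - sigma^2 = 1.04
```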

The Takeaway

This is why the \(\frac{\rho}{2}\|Ax + Bz - c\|_2^2\) penalty in the Augmented Lagrangian is so natural. When you set up ADMM, you are implicitly treating the "disagreement" (the residual \(Ax + Bz - c\)) as if it were Gaussian noise with variance \(\frac{1}{\rho}\).

