PnP with Diffusion


1. The Original Optimization Problem

The core idea of plug-and-play Image Restoration (IR) methods is to take a complex inverse problem and separate it into a data term and a prior term 1. This is expressed as the following optimization problem:

\[\hat{x} = \text{argmin}_x \frac{1}{2\sigma_n^2} \|y - H(x)\|^2 + \lambda P(x)\]

  • \(y\): The observed measurement.
  • \(H\): A known degradation operator (e.g., blur, downsampling).
  • \(\sigma_n\): The standard deviation of the Gaussian noise.
  • \(\lambda P(\cdot)\): The prior term, weighted by the regularization parameter \(\lambda\), which encourages the solution to follow the desired data distribution 1.

2. Decoupling the Problem with HQS

To solve this efficiently, the authors use the Half-Quadratic-Splitting (HQS) algorithm to decouple the data and prior terms. By introducing an auxiliary variable \(z\) and a constraint coefficient \(\mu\), the single equation is split into two iteratively solved subproblems:

Prior Subproblem:

\[z_k = \text{argmin}_z \frac{1}{2(\sqrt{\lambda/\mu})^2} \|z - x_k\|^2 + P(z)\]

Data Subproblem:

\[x_{k-1} = \text{argmin}_x \|y - H(x)\|^2 + \mu \sigma_n^2 \|x - z_k\|^2\]

The prior subproblem is effectively a Gaussian denoising problem, while the data subproblem represents a proximal operator that ensures data consistency.
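The alternation between the two subproblems can be sketched on a toy problem. The example below is a minimal illustration under assumptions that are not in the original text: \(H\) is a binary inpainting mask, and the prior is \(P(z) = \frac{1}{2}\|z\|^2\), chosen only because both subproblems then have per-pixel closed forms. The function names and default parameter values are hypothetical.

```python
import numpy as np

def prior_step(x, lam, mu):
    # argmin_z ||z - x||^2 / (2 * lam / mu) + P(z), with the toy prior
    # P(z) = 0.5 * ||z||^2 (an assumption made for illustration).
    return mu * x / (mu + lam)

def data_step(z, y, mask, mu, sigma_n):
    # argmin_x ||y - mask * x||^2 + mu * sigma_n^2 * ||x - z||^2,
    # solved independently for every pixel.
    a = mu * sigma_n**2
    return (mask * y + a * z) / (mask + a)

def hqs(y, mask, lam=1.0, mu=10.0, sigma_n=0.1, n_iters=50):
    x = mask * y                       # initialize from the measurement
    for _ in range(n_iters):
        z = prior_step(x, lam, mu)     # prior subproblem (denoising)
        x = data_step(z, y, mask, mu, sigma_n)  # data subproblem
    return x

rng = np.random.default_rng(0)
x_true = rng.standard_normal(8)
mask = np.array([1, 1, 0, 1, 0, 1, 1, 0], dtype=float)
y = mask * (x_true + 0.1 * rng.standard_normal(8))
x_hat = hqs(y, mask)
```

With a learned denoiser in place of `prior_step`, the loop has the same shape; only the prior step changes.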

3. Solving the Prior Term with Diffusion Models

Instead of using traditional denoisers, DiffPIR utilizes the generative capabilities of score-based diffusion models. In diffusion models, the score function \(s_\theta(x_t, t)\) approximates the gradient of the log-density of perturbed data.

By defining the relative noise level \(\bar{\sigma}_t = \sqrt{\frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}}\) and substituting \(\sqrt{\lambda/\mu} = \bar{\sigma}_t\), the prior subproblem can be rewritten as a proximal operator that uses the diffusion model's score function:

\[z_k \approx x_k + \frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t} s_\theta(x_k)\]

In this context, \(z_k\) represents the predicted clean image, \(\hat{x}_0^{(t)}\).
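The score-based update above can be illustrated with a small NumPy sketch. It assumes that \(x_k\) is a clean image plus Gaussian noise of standard deviation \(\bar{\sigma}_t\), and it replaces the trained network \(s_\theta\) with the exact score of a toy Gaussian prior (clean pixels drawn from \(\mathcal{N}(0, 1)\)). The names `relative_noise_level`, `toy_score`, and `predict_clean` are hypothetical helpers, not part of any DiffPIR code base.

```python
import numpy as np

def relative_noise_level(alpha_bar_t):
    """sigma_bar_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)."""
    return np.sqrt((1.0 - alpha_bar_t) / alpha_bar_t)

def toy_score(x, sigma_bar):
    """Exact score of N(0, (1 + sigma_bar^2) I).

    Stand-in for a trained s_theta: this is the density of x when clean
    pixels are N(0, 1) and the added noise has standard deviation sigma_bar.
    """
    return -x / (1.0 + sigma_bar**2)

def predict_clean(x_k, score_fn, alpha_bar_t):
    """z_k ~= x_k + sigma_bar_t^2 * s(x_k), the prior-subproblem update."""
    sigma_bar = relative_noise_level(alpha_bar_t)
    return x_k + sigma_bar**2 * score_fn(x_k, sigma_bar)

alpha_bar_t = 0.5                      # gives sigma_bar_t = 1
rng = np.random.default_rng(0)
x_k = rng.standard_normal(4)           # a noisy toy "image"
z_k = predict_clean(x_k, toy_score, alpha_bar_t)
```

For this toy prior the update shrinks every pixel by the factor \(1/(1+\bar{\sigma}_t^2)\), which is the posterior mean of the clean pixel given the noisy one. This is why \(z_k\) can be read as a denoised estimate.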

4. Solving the Data Term

Once the diffusion model has estimated the clean image \(\hat{x}_0^{(t)}\), the estimate is fed into the data subproblem so that it agrees with the observed measurement \(y\):

\[\hat{x}_0^{(t)} = \text{argmin}_x \|y - H(x)\|^2 + \rho_t \|x - \hat{x}_0^{(t)}\|^2\]

Here, \(\rho_t = \lambda(\sigma_n / \bar{\sigma}_t)^2\). For many tasks, such as image deblurring or super-resolution, this problem has a fast closed-form solution. If no closed-form solution is available, it can be approximated using a first-order gradient descent step:

\[\hat{x}_0^{(t)} \approx \hat{x}_0^{(t)} - \frac{\bar{\sigma}_t^2}{2\lambda \sigma_n^2} \nabla_{\hat{x}_0^{(t)}} \|y - H(\hat{x}_0^{(t)})\|^2\]
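Both variants of the data step can be sketched for one simple case. The example assumes \(H\) is a binary inpainting mask, an illustrative choice that makes the closed-form solution a per-pixel expression. The function names and the parameter values are hypothetical.

```python
import numpy as np

def rho(lam, sigma_n, sigma_bar_t):
    """rho_t = lam * (sigma_n / sigma_bar_t)^2."""
    return lam * (sigma_n / sigma_bar_t) ** 2

def data_step_closed_form(x0_hat, y, mask, rho_t):
    # Minimizer of ||y - mask * x||^2 + rho_t * ||x - x0_hat||^2.
    # Setting the gradient to zero gives a per-pixel weighted average.
    return (mask * y + rho_t * x0_hat) / (mask + rho_t)

def data_step_gradient(x0_hat, y, mask, rho_t):
    # One gradient step on ||y - mask * x||^2 with step size
    # sigma_bar_t^2 / (2 * lam * sigma_n^2) = 1 / (2 * rho_t).
    grad = 2.0 * mask * (mask * x0_hat - y)
    return x0_hat - grad / (2.0 * rho_t)

rng = np.random.default_rng(0)
mask = np.array([1, 0, 1, 1, 0, 1], dtype=float)
y = mask * rng.standard_normal(6)      # masked measurement
x0_hat = rng.standard_normal(6)        # stand-in for the diffusion estimate
rho_t = rho(lam=8.0, sigma_n=0.5, sigma_bar_t=1.0)

x_closed = data_step_closed_form(x0_hat, y, mask, rho_t)
x_grad = data_step_gradient(x0_hat, y, mask, rho_t)
```

In both variants the unobserved pixels keep the value proposed by the prior, because the measurement carries no information about them. The single gradient step is only an approximation and need not coincide with the closed-form minimizer.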

5. DiffPIR Final Sampling Step

Once the image \(\hat{x}_0^{(t)}\) is updated to be data-consistent, the algorithm uses an estimation-correction strategy similar to DDIM to move to the next reverse diffusion step \(x_{t-1}\).

First, the effective predicted noise \(\hat{\epsilon}\) is calculated based on the updated image:

\[\hat{\epsilon}(x_t,y) = \frac{1}{\sqrt{1 - \bar{\alpha}_t}} (x_t - \sqrt{\bar{\alpha}_t}\hat{x}_0^{(t)})\]

Finally, a noise hyperparameter \(\zeta\) is introduced to balance the corrected predicted noise \(\hat{\epsilon}\) and standard Gaussian noise \(\epsilon_t \sim \mathcal{N}(0, I)\). The final reverse sampling step is calculated as:

\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\hat{x}_0^{(t)} + \sqrt{1 - \bar{\alpha}_{t-1}}(\sqrt{1 - \zeta}\hat{\epsilon} + \sqrt{\zeta}\epsilon_t)\]

This process loops backwards from timestep \(T\) down to 1, progressively removing noise while constraining the output to match the target measurements.
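A single reverse step built from the two formulas above can be sketched as follows. The function name `diffpir_update` and its argument names are hypothetical. The data-consistent estimate `x0_hat` is taken as given here rather than produced by a real diffusion model.

```python
import numpy as np

def diffpir_update(x_t, x0_hat, alpha_bar_t, alpha_bar_prev, zeta, rng):
    """Move from x_t to x_{t-1} given the data-consistent estimate x0_hat."""
    # Effective predicted noise implied by x_t and the corrected x0_hat.
    eps_hat = (x_t - np.sqrt(alpha_bar_t) * x0_hat) / np.sqrt(1.0 - alpha_bar_t)
    # Fresh standard Gaussian noise.
    eps_t = rng.standard_normal(x_t.shape)
    # zeta = 0 keeps only the predicted noise; zeta = 1 uses only fresh noise.
    noise = np.sqrt(1.0 - zeta) * eps_hat + np.sqrt(zeta) * eps_t
    return np.sqrt(alpha_bar_prev) * x0_hat + np.sqrt(1.0 - alpha_bar_prev) * noise

rng = np.random.default_rng(0)
x0 = rng.standard_normal(5)            # toy clean signal
e = rng.standard_normal(5)             # noise realization
alpha_bar_t, alpha_bar_prev = 0.3, 0.6
x_t = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * e

# With zeta = 0 and a perfect estimate, the step re-noises x0 with the
# same noise realization at the lower noise level of step t-1.
x_prev = diffpir_update(x_t, x0, alpha_bar_t, alpha_bar_prev, zeta=0.0, rng=rng)
```

In a full sampler this function would be called once per timestep, after the prior step and the data step have produced `x0_hat` for that timestep.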