PnP with Diffusion
1. The Original Optimization Problem
The core idea of plug-and-play Image Restoration (IR) methods is to take a complex inverse problem and separate it into a data term and a prior term [1]. This is expressed as the following optimization problem:
\[\hat{x} = \text{argmin}_x \frac{1}{2\sigma_n^2} \|y - H(x)\|^2 + \lambda P(x)\]
- \(y\): The observed measurement.
- \(H\): A known degradation operator (e.g., blur, downsampling).
- \(\sigma_n\): The standard deviation of the Gaussian noise.
- \(\lambda P(\cdot)\): The prior term, weighted by the regularization parameter \(\lambda\), which steers the solution toward the desired data distribution [1].
2. Decoupling the Problem with HQS
To solve this efficiently, the DiffPIR authors use the Half-Quadratic Splitting (HQS) algorithm to decouple the data and prior terms. By introducing an auxiliary variable \(z\) and a constraint coefficient \(\mu\), the single problem is split into two subproblems that are solved alternately:
Prior Subproblem:
\[z_k = \text{argmin}_z \frac{1}{2(\sqrt{\lambda/\mu})^2} \|z - x_k\|^2 + P(z)\]
Data Subproblem:
\[x_{k-1} = \text{argmin}_x \|y - H(x)\|^2 + \mu \sigma_n^2 \|x - z_k\|^2\]
The prior subproblem is effectively a Gaussian denoising problem, while the data subproblem represents a proximal operator that ensures data consistency.
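The alternation between the two subproblems can be sketched on a toy problem. This is a minimal illustration, not the DiffPIR implementation: it assumes \(H\) is a plain matrix, and it swaps in soft-thresholding (the exact solution of the prior subproblem for the hypothetical prior \(P(z) = \|z\|_1\)) where a learned denoiser would normally sit. The names `hqs` and `soft_threshold` are made up for this example.

```python
import numpy as np

def soft_threshold(v, tau):
    # Exact solution of argmin_z 1/(2*tau) * ||z - v||^2 + ||z||_1.
    # Stand-in for a Gaussian denoiser at noise level sqrt(tau).
    return np.sign(v) * np.maximum(np.abs(v) - tau, 0.0)

def hqs(y, H, sigma_n, lam, mu, num_iters=50):
    """Toy HQS loop for a linear operator H given as a matrix (assumption)."""
    n = H.shape[1]
    rho = mu * sigma_n**2                 # weight of ||x - z||^2 in the data subproblem
    A = H.T @ H + rho * np.eye(n)         # normal-equation matrix, fixed across iterations
    Hty = H.T @ y
    x = Hty.copy()                        # simple initialization
    for _ in range(num_iters):
        # Prior subproblem: denoise x at noise level sqrt(lam / mu).
        z = soft_threshold(x, lam / mu)
        # Data subproblem: argmin_x ||y - Hx||^2 + rho * ||x - z||^2 (closed form).
        x = np.linalg.solve(A, Hty + rho * z)
    return x

if __name__ == "__main__":
    y = np.array([3.0, -2.0])
    print(hqs(y, np.eye(2), sigma_n=0.5, lam=1.0, mu=2.0))
```

Replacing `soft_threshold` with any other Gaussian denoiser is what makes the scheme "plug-and-play": the data subproblem does not need to know which prior the denoiser encodes.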
3. Solving the Prior Term with Diffusion Models
Instead of using traditional denoisers, DiffPIR utilizes the generative capabilities of score-based diffusion models. In diffusion models, the score function \(s_\theta(x_t, t)\) approximates the gradient of the log-density of perturbed data.
By defining the relative noise level \(\bar{\sigma}_t = \sqrt{\frac{1-\bar{\alpha}_t}{\bar{\alpha}_t}}\) and substituting \(\sqrt{\lambda/\mu} = \bar{\sigma}_t\), the prior subproblem becomes a Gaussian denoising step at noise level \(\bar{\sigma}_t\), which can be evaluated with the diffusion model's score function:
\[z_k \approx x_k + \frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t} s_\theta(x_k)\]
In this context, \(z_k\) is the clean image predicted at timestep \(t\), written \(x_0^{(t)}\).
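In practice the clean-image prediction is usually computed directly from the noisy sample \(x_t\). The sketch below assumes the common variance-preserving parameterization \(x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon\), under which the update above, with \(x_k = x_t/\sqrt{\bar{\alpha}_t}\), is equivalent to the two expressions used here. The function names are illustrative, and the score or noise prediction is passed in as an array rather than produced by a real network.

```python
import numpy as np

def relative_noise_level(alpha_bar_t):
    # sigma_bar_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)
    return np.sqrt((1.0 - alpha_bar_t) / alpha_bar_t)

def predict_clean_from_score(x_t, score, alpha_bar_t):
    # Clean-image estimate from the score evaluated at x_t:
    # (x_t + (1 - alpha_bar_t) * score) / sqrt(alpha_bar_t)
    return (x_t + (1.0 - alpha_bar_t) * score) / np.sqrt(alpha_bar_t)

def predict_clean_from_eps(x_t, eps_pred, alpha_bar_t):
    # Same estimate for a noise-prediction network, using
    # score = -eps_pred / sqrt(1 - alpha_bar_t).
    return (x_t - np.sqrt(1.0 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_bar_t)

if __name__ == "__main__":
    alpha_bar_t = 0.25
    x0 = np.array([1.0, 2.0])
    eps = np.array([0.5, -1.0])
    x_t = np.sqrt(alpha_bar_t) * x0 + np.sqrt(1.0 - alpha_bar_t) * eps
    # With the true noise supplied, the clean image is recovered exactly.
    print(predict_clean_from_eps(x_t, eps, alpha_bar_t))
```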
4. Solving the Data Term
The clean image predicted by the diffusion model, \(x_0^{(t)}\), is then fed into the data subproblem, which produces a data-consistent estimate \(\hat{x}_0^{(t)}\) that agrees with the measurement \(y\):
\[\hat{x}_0^{(t)} = \text{argmin}_x \|y - H(x)\|^2 + \rho_t \|x - x_0^{(t)}\|^2\]
Here, \(\rho_t = \lambda(\sigma_n / \bar{\sigma}_t)^2\). For many tasks, such as image deblurring or super-resolution, this problem has a fast closed-form solution. If no closed-form solution is available, it can be approximated with a single first-order gradient descent step starting from the denoiser output \(x_0^{(t)}\):
\[\hat{x}_0^{(t)} \approx x_0^{(t)} - \frac{\bar{\sigma}_t^2}{2\lambda \sigma_n^2} \nabla_{x_0^{(t)}} \|y - H(x_0^{(t)})\|^2\]
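Both variants of the data step can be written out for the simplest case. The sketch below assumes \(H\) is a linear operator stored as a matrix, so the closed form reduces to one linear solve and the gradient of \(\|y - Hx\|^2\) is \(2H^\top(Hx - y)\). The step size \(\bar{\sigma}_t^2 / (2\lambda\sigma_n^2)\) equals \(1/(2\rho_t)\). Here `x0_pred` stands for the denoiser output and the return value for the data-consistent estimate; all function names are hypothetical.

```python
import numpy as np

def rho_t(lam, sigma_n, sigma_bar_t):
    # rho_t = lambda * (sigma_n / sigma_bar_t)^2
    return lam * (sigma_n / sigma_bar_t) ** 2

def data_step_closed_form(x0_pred, y, H, rho):
    # argmin_x ||y - Hx||^2 + rho * ||x - x0_pred||^2
    # => (H^T H + rho I) x = H^T y + rho * x0_pred
    n = H.shape[1]
    return np.linalg.solve(H.T @ H + rho * np.eye(n), H.T @ y + rho * x0_pred)

def data_step_gradient(x0_pred, y, H, rho):
    # One gradient step on ||y - Hx||^2 with step size 1 / (2 * rho).
    grad = 2.0 * H.T @ (H @ x0_pred - y)
    return x0_pred - grad / (2.0 * rho)

if __name__ == "__main__":
    H = np.diag([1.0, 0.0])            # toy inpainting mask: second pixel unobserved
    y = np.array([2.0, 0.0])
    x0_pred = np.array([0.0, 5.0])
    print(data_step_closed_form(x0_pred, y, H, rho=1.0))
    print(data_step_gradient(x0_pred, y, H, rho=1.0))
```

In this toy case the unobserved pixel is left at the denoiser's value by both variants, while the observed pixel is pulled toward the measurement; the two results differ there because the gradient step is only a first-order approximation of the exact minimizer.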
5. DiffPIR Final Sampling Step
Once the image \(\hat{x}_0^{(t)}\) is updated to be data-consistent, the algorithm uses an estimation-correction strategy similar to DDIM to move to the next reverse diffusion step \(x_{t-1}\).
First, the effective predicted noise \(\hat{\epsilon}\) is calculated based on the updated image:
\[\hat{\epsilon}(x_t,y) = \frac{1}{\sqrt{1 - \bar{\alpha}_t}} (x_t - \sqrt{\bar{\alpha}_t}\hat{x}_0^{(t)})\]
Finally, a noise hyperparameter \(\zeta\) is introduced to balance the corrected predicted noise \(\hat{\epsilon}\) and standard Gaussian noise \(\epsilon_t \sim \mathcal{N}(0, I)\). The final reverse sampling step is calculated as:
\[x_{t-1} = \sqrt{\bar{\alpha}_{t-1}}\hat{x}_0^{(t)} + \sqrt{1 - \bar{\alpha}_{t-1}}(\sqrt{1 - \zeta}\hat{\epsilon} + \sqrt{\zeta}\epsilon_t)\]
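The two formulas above translate almost line for line into code. The following is a minimal sketch with illustrative function names; the random generator is passed in explicitly so the stochastic part is reproducible.

```python
import numpy as np

def effective_noise(x_t, x0_hat, alpha_bar_t):
    # eps_hat = (x_t - sqrt(alpha_bar_t) * x0_hat) / sqrt(1 - alpha_bar_t)
    return (x_t - np.sqrt(alpha_bar_t) * x0_hat) / np.sqrt(1.0 - alpha_bar_t)

def reverse_step(x_t, x0_hat, alpha_bar_t, alpha_bar_prev, zeta, rng):
    """One reverse sampling step from x_t to x_{t-1}."""
    eps_hat = effective_noise(x_t, x0_hat, alpha_bar_t)
    eps_t = rng.standard_normal(x_t.shape)      # fresh Gaussian noise
    # zeta = 0 keeps only the corrected noise; zeta = 1 keeps only fresh noise.
    mixed = np.sqrt(1.0 - zeta) * eps_hat + np.sqrt(zeta) * eps_t
    return np.sqrt(alpha_bar_prev) * x0_hat + np.sqrt(1.0 - alpha_bar_prev) * mixed

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x0_hat = np.array([1.0, -1.0])
    x_t = 0.5 * x0_hat + np.sqrt(0.75) * np.array([0.5, 2.0])
    print(reverse_step(x_t, x0_hat, 0.25, 0.64, zeta=0.0, rng=rng))
```

With \(\zeta = 0\) the step is deterministic, and with \(\bar{\alpha}_{t-1} = 1\) it returns \(\hat{x}_0^{(t)}\) unchanged, which is what happens at the last step of the loop.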
This process loops backwards from timestep \(T\) down to 1, progressively removing noise while constraining the output to match the target measurements.
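To see the three steps working together, the loop can be run end to end on a toy problem. The sketch below rests on several assumptions that are not part of the method itself: the data prior is taken to be \(\mathcal{N}(0, s^2 I)\), so the clean-image prediction has an exact formula and no trained network is needed; \(H\) is a matrix; and the \(\bar{\alpha}_t\) schedule is an arbitrary linear ramp with \(\bar{\alpha}_0 = 1\). The function name `diffpir_toy` and its default parameter values are hypothetical.

```python
import numpy as np

def diffpir_toy(y, H, sigma_n, alpha_bar, prior_std=1.0, lam=1.0, zeta=0.3, seed=0):
    """Toy sampling loop. alpha_bar[0] must be 1 and alpha_bar[t] < 1 for t >= 1."""
    rng = np.random.default_rng(seed)
    n = H.shape[1]
    T = len(alpha_bar) - 1
    x = rng.standard_normal(n)                       # x_T: pure Gaussian noise
    for t in range(T, 0, -1):
        a_t, a_prev = alpha_bar[t], alpha_bar[t - 1]

        # 1) Prior step. For a N(0, prior_std^2 I) prior the posterior mean
        #    E[x_0 | x_t] is linear in x_t (stand-in for a diffusion network).
        s2 = prior_std**2
        x0_pred = np.sqrt(a_t) * s2 / (a_t * s2 + 1.0 - a_t) * x

        # 2) Data step (closed form for linear H).
        sigma_bar_sq = (1.0 - a_t) / a_t
        rho = lam * sigma_n**2 / sigma_bar_sq
        x0_hat = np.linalg.solve(H.T @ H + rho * np.eye(n), H.T @ y + rho * x0_pred)

        # 3) Sampling step: mix corrected noise with fresh noise.
        eps_hat = (x - np.sqrt(a_t) * x0_hat) / np.sqrt(1.0 - a_t)
        eps_t = rng.standard_normal(n)
        mixed = np.sqrt(1.0 - zeta) * eps_hat + np.sqrt(zeta) * eps_t
        x = np.sqrt(a_prev) * x0_hat + np.sqrt(1.0 - a_prev) * mixed
    return x

if __name__ == "__main__":
    H = np.eye(4)[:2]                                # observe the first two of four entries
    y = np.array([2.0, -1.0])
    alpha_bar = np.linspace(1.0, 0.01, 51)           # hypothetical schedule, T = 50
    x = diffpir_toy(y, H, sigma_n=0.05, alpha_bar=alpha_bar)
    print("restored:", x)
    print("residual:", np.linalg.norm(H @ x - y))
```

Because \(\rho_t\) is small while \(\bar{\sigma}_t\) is large, the early iterations lean almost entirely on the measurement, and the prior gains weight only as the noise level drops. Swapping the Gaussian posterior mean for a real denoising network changes only the first step.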