Section 4.5 Structured Optimization with Compositional Gradient
In this section, we extend the compositional optimization technique to address other structured optimization problems, including min-max optimization, min-min optimization, and bilevel optimization. These problems share a common structure in the form of a compositional gradient, denoted by \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\), where \(\mathcal{M}\) is a mapping that is Lipschitz continuous with respect to its second argument, and \(\mathbf{u}^*(\mathbf{w})\) is defined as the solution to a strongly convex optimization problem:
\[\begin{align} \label{eqn:auxu} \mathbf{u}^*(\mathbf{w}) = \arg\min_{\mathbf{u} \in \mathcal{U}} h(\mathbf{w}, \mathbf{u}). \end{align}\]
This structure generalizes the compositional setting: the gradient of a compositional function \(f(g(\mathbf{w}))\) takes the form \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) = \nabla g(\mathbf{w})\nabla f(\mathbf{u}^*(\mathbf{w}))\) with
\[\mathbf{u}^*(\mathbf{w}) = \arg\min_{\mathbf{u}} \|\mathbf{u} - g(\mathbf{w})\|_2^2.\]
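The correspondence above can be checked numerically. The following is a minimal sketch, assuming a hypothetical inner map \(g(\mathbf{w}) = (w_1^2, w_1 w_2)\) and outer function \(f(\mathbf{u}) = \sin u_1 + \sin u_2\), neither of which comes from the text. It compares \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\) with a finite-difference gradient of \(f(g(\mathbf{w}))\).

```python
import math

def g(w):
    # Hypothetical inner map g: R^2 -> R^2, chosen only for illustration.
    return [w[0] ** 2, w[0] * w[1]]

def grad_g(w):
    # Transposed Jacobian of g: entry [i][j] is d g_j / d w_i.
    return [[2 * w[0], w[1]],
            [0.0, w[0]]]

def f(u):
    # Hypothetical outer function.
    return math.sin(u[0]) + math.sin(u[1])

def grad_f(u):
    return [math.cos(u[0]), math.cos(u[1])]

def M(w, u):
    # M(w, u) = grad_g(w) grad_f(u), matching the form in the text.
    G, d = grad_g(w), grad_f(u)
    return [sum(G[i][j] * d[j] for j in range(2)) for i in range(2)]

def u_star(w):
    # The minimizer of ||u - g(w)||_2^2 over u is u = g(w).
    return g(w)

def numerical_grad(w, eps=1e-6):
    # Central finite differences of F(w) = f(g(w)).
    out = []
    for i in range(len(w)):
        wp, wm = list(w), list(w)
        wp[i] += eps
        wm[i] -= eps
        out.append((f(g(wp)) - f(g(wm))) / (2 * eps))
    return out

w = [0.7, -1.3]
analytic = M(w, u_star(w))
numeric = numerical_grad(w)
```

Here `analytic` and `numeric` should agree up to finite-difference error, which is the chain rule restated in the \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\) notation.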
The high-level idea underlying the algorithms and analysis presented below is as follows. To estimate \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\) at \(\mathbf{w}_t\), we maintain an auxiliary variable \(\mathbf{u}_t\) that tracks the optimal solution \(\mathbf{u}^*(\mathbf{w}_t)\); rather than solving (\(\ref{eqn:auxu}\)) exactly, \(\mathbf{u}_t\) is obtained by a single update step on (\(\ref{eqn:auxu}\)) at \(\mathbf{w}_t\). A key aspect of the analysis is that the error incurred by using \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) in place of \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}^*(\mathbf{w}_t))\) is controlled by the estimation error \(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2\), due to the Lipschitz continuity of \(\mathcal{M}\) with respect to its second argument:
\[\begin{align} \|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \mathcal{M}(\mathbf{w}_t, \mathbf{u}^*(\mathbf{w}_t))\|_2^2 \leq O(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2). \end{align}\]
Moreover, because \(\mathbf{u}^*(\mathbf{w})\) solves a strongly convex problem, it is Lipschitz continuous with respect to \(\mathbf{w}\) under standard smoothness conditions on \(h\). This allows us to construct a recursion for \(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2\) and thereby bound the error accumulated over iterations.
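One way such a recursion can arise is sketched below, under assumptions not stated in the text above: \(h(\mathbf{w}, \cdot)\) is \(\mu\)-strongly convex and \(L_h\)-smooth, \(\mathbf{u}^*(\cdot)\) is \(L_u\)-Lipschitz, and \(\mathbf{u}_t\) is produced by one deterministic projected gradient step \(\mathbf{u}_t = \Pi_{\mathcal{U}}[\mathbf{u}_{t-1} - \gamma \nabla_{\mathbf{u}} h(\mathbf{w}_t, \mathbf{u}_{t-1})]\) with \(\gamma \leq 1/L_h\). The symbols \(\mu\), \(L_h\), \(L_u\), \(\gamma\), and \(\delta_t = \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2\) are introduced here for illustration only. The contraction of the gradient step, followed by Young's inequality with parameter \(\gamma\mu/2\), gives
\[\begin{align*} \delta_t &\leq (1 - \gamma\mu)\,\|\mathbf{u}_{t-1} - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 \\ &\leq (1 - \gamma\mu)\Big[\big(1 + \tfrac{\gamma\mu}{2}\big)\,\delta_{t-1} + \big(1 + \tfrac{2}{\gamma\mu}\big) L_u^2\, \|\mathbf{w}_t - \mathbf{w}_{t-1}\|_2^2\Big] \\ &\leq \big(1 - \tfrac{\gamma\mu}{2}\big)\,\delta_{t-1} + \big(1 + \tfrac{2}{\gamma\mu}\big) L_u^2\, \|\mathbf{w}_t - \mathbf{w}_{t-1}\|_2^2. \end{align*}\]
The first term contracts geometrically, while the second term is proportional to the squared movement of \(\mathbf{w}\) and hence to the squared step size of the model update. In a stochastic setting an additional variance term would typically appear, which this sketch omits.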
In cases where \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) cannot be computed exactly and is instead approximated by a stochastic estimator \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t; \zeta_t)\), where \(\zeta_t\) is a random variable, we employ a moving average (MA) estimator:
\[\begin{align*} \mathbf{v}_{t} = (1 - \beta_t) \mathbf{v}_{t-1} + \beta_t\, \mathcal{M}(\mathbf{w}_t, \mathbf{u}_t; \zeta_t). \end{align*}\]
The model update is then performed using:
\[\begin{align*} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \mathbf{v}_{t}. \end{align*}\]
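The stochastic scheme above can be illustrated with a small simulation. This is a minimal sketch, not the algorithm analyzed later. It assumes a hypothetical compositional problem with \(g(\mathbf{w}) = A\mathbf{w} - \mathbf{b}\) observed with additive Gaussian noise, \(f(\mathbf{u}) = \tfrac{1}{2}\|\mathbf{u}\|_2^2\), a noisy Jacobian sample, constant parameters \(\eta\), \(\beta\), \(\gamma\), and zero initialization. All names and values (`A`, `W_TRUE`, `run_ma`, and so on) are illustrative choices, and the code uses only the Python standard library.

```python
import random

A = [[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical 3x2 matrix
W_TRUE = [1.0, -2.0]                       # hypothetical minimizer
B = [sum(a * x for a, x in zip(row, W_TRUE)) for row in A]  # b = A w_true

def objective(w):
    # F(w) = f(g(w)) with g(w) = A w - b and f(u) = 0.5 * ||u||^2.
    r = [sum(a * x for a, x in zip(row, w)) - b for row, b in zip(A, B)]
    return 0.5 * sum(x * x for x in r)

def run_ma(T=1000, eta=0.02, beta=0.5, gamma=0.5, sigma=0.1, seed=0):
    rng = random.Random(seed)
    w = [0.0, 0.0]
    u = [0.0, 0.0, 0.0]  # tracks u*(w) = g(w); zero init is an assumption
    v = [0.0, 0.0]       # moving-average estimate of M; zero init is an assumption
    for _ in range(T):
        # Noisy sample g(w_t; xi_t) = A w_t - b + xi_t.
        g_s = [sum(a * x for a, x in zip(row, w)) - b + rng.gauss(0.0, sigma)
               for row, b in zip(A, B)]
        # One stochastic gradient step on 0.5 * ||u - g(w_t)||^2 (same minimizer).
        u = [(1 - gamma) * ui + gamma * gi for ui, gi in zip(u, g_s)]
        # Noisy Jacobian sample, drawn independently of xi_t.
        A_s = [[a + rng.gauss(0.0, sigma) for a in row] for row in A]
        # M(w_t, u_t; zeta_t) = A_s^T grad f(u_t), where grad f(u) = u.
        m = [sum(A_s[i][j] * u[i] for i in range(3)) for j in range(2)]
        # Moving-average estimator v_t, then the model update w_{t+1}.
        v = [(1 - beta) * vj + beta * mj for vj, mj in zip(v, m)]
        w = [wj - eta * vj for wj, vj in zip(w, v)]
    return w
```

Calling `run_ma()` and evaluating `objective` on the result shows the objective falling well below its initial value `objective([0.0, 0.0])`. With constant parameters the iterate settles in a noise-dominated neighborhood of `W_TRUE` rather than converging exactly.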
Alternatively, if \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) is directly computable, the update simplifies to:
\[\begin{align*} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t\, \mathcal{M}(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\]
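The deterministic variant can be sketched in the same way on a hypothetical min-max instance. Assume \(\min_{\mathbf{w}} \max_{\mathbf{u}} \; \mathbf{u}^\top(A\mathbf{w} - \mathbf{b}) - \tfrac{1}{2}\|\mathbf{u}\|_2^2\), so that \(h(\mathbf{w}, \mathbf{u}) = \tfrac{1}{2}\|\mathbf{u}\|_2^2 - \mathbf{u}^\top(A\mathbf{w} - \mathbf{b})\) is strongly convex in \(\mathbf{u}\), \(\mathbf{u}^*(\mathbf{w}) = A\mathbf{w} - \mathbf{b}\), and, by Danskin's theorem, \(\mathcal{M}(\mathbf{w}, \mathbf{u}) = A^\top \mathbf{u}\). This instance, the constant step sizes, and the names below are illustrative assumptions, not the formulation developed in the text.

```python
A = [[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical 3x2 matrix
W_TRUE = [1.0, -2.0]                       # hypothetical solution
B = [sum(a * x for a, x in zip(row, W_TRUE)) for row in A]  # b = A w_true

def residual(w):
    # A w - b, which equals u*(w) for this instance.
    return [sum(a * x for a, x in zip(row, w)) - b for row, b in zip(A, B)]

def run_deterministic(T=300, eta=0.05, gamma=0.5):
    w = [0.0, 0.0]
    u = [0.0, 0.0, 0.0]  # auxiliary variable tracking u*(w_t); zero init assumed
    for _ in range(T):
        r = residual(w)
        # One gradient step on h(w_t, u): grad_u h = u - (A w_t - b).
        u = [ui - gamma * (ui - ri) for ui, ri in zip(u, r)]
        # M(w_t, u_t) = A^T u_t is directly computable here.
        m = [sum(A[i][j] * u[i] for i in range(3)) for j in range(2)]
        # Model update w_{t+1} = w_t - eta * M(w_t, u_t).
        w = [wj - eta * mj for wj, mj in zip(w, m)]
    return w, u
```

For this instance, `run_deterministic()` returns a `w` close to `W_TRUE` and a `u` close to zero, even though \(\mathbf{u}^*(\mathbf{w}_t)\) is never computed exactly inside the loop.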