
Section 4.5.3 Non-convex Bilevel Optimization

In this section, we discuss the application of the compositional gradient estimation technique to non-convex bilevel optimization defined by

\[\begin{equation}\label{eqn:bi} \begin{aligned} \min_{\mathbf{w}\in\mathbb{R}^d}& f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\\ \mathbf{u}^*(\mathbf{w})&=\arg\min_{\mathbf{u}\in\mathbb{R}^{d'}} g(\mathbf{w}, \mathbf{u}), \end{aligned} \end{equation}\]

where \(g\) is twice differentiable and \(\mu_g\)-strongly convex in terms of \(\mathbf{u}\). Let \(F(\mathbf{w}) =f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})).\) The following lemma states the gradient of the objective \(F(\mathbf{w})\).

Lemma 4.24 We have \[\begin{align*} \nabla F(\mathbf{w}) & = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) - \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\]

By the optimality condition of \(\mathbf{u}^*(\mathbf{w})\), we have \[\begin{align*} \nabla_2 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))=0. \end{align*}\] By taking derivative on both sides, using the chain rule, and the implicit function theorem, we obtain \[\begin{align*} \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) + \nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\frac{\partial \mathbf{u}^*(\mathbf{w})}{\partial \mathbf{w}}=0. \end{align*}\] Hence \[\begin{align*} \frac{\partial \mathbf{u}^*(\mathbf{w})}{\partial \mathbf{w}}=-(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\] Thus, \[\begin{align*} \nabla F(\mathbf{w}) & = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) + \frac{\partial \mathbf{u}^*(\mathbf{w})}{\partial \mathbf{w}}^{\top}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\\ & = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) - \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\]
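As a sanity check on Lemma 4.24, the following minimal sketch compares the closed-form hypergradient with a central finite-difference gradient of \(F\). The quadratic instance (the matrices `A`, `B`, the vector `c`, and all function names) is a hypothetical toy example chosen for illustration, not something taken from the text.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_prime = 3, 4

# Hypothetical toy instance (an assumption for illustration):
#   g(w, u) = 0.5 * u^T A u - u^T B w
#   f(w, u) = 0.5 * ||u - c||^2 + 0.5 * ||w||^2
M = rng.standard_normal((d_prime, d_prime))
A = M @ M.T + np.eye(d_prime)   # symmetric positive definite: g is strongly convex in u
B = rng.standard_normal((d_prime, d))
c = rng.standard_normal(d_prime)

def u_star(w):
    # Solves the optimality condition grad_2 g(w, u) = A u - B w = 0.
    return np.linalg.solve(A, B @ w)

def F(w):
    u = u_star(w)
    return 0.5 * np.sum((u - c) ** 2) + 0.5 * np.sum(w ** 2)

def hypergradient(w):
    # Formula of Lemma 4.24: grad_1 f - grad_21 g^T (grad_22 g)^{-1} grad_2 f.
    u = u_star(w)
    grad1_f = w
    grad2_f = u - c
    grad21_g = -B               # Jacobian of grad_2 g with respect to w, shape (d', d)
    grad22_g = A
    return grad1_f - grad21_g.T @ np.linalg.solve(grad22_g, grad2_f)

def finite_difference(fun, w, h=1e-5):
    # Central differences, one coordinate at a time.
    grad = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w)
        e[i] = h
        grad[i] = (fun(w + e) - fun(w - e)) / (2 * h)
    return grad

w0 = rng.standard_normal(d)
print(np.max(np.abs(hypergradient(w0) - finite_difference(F, w0))))
```

Note that the linear system is solved with `np.linalg.solve` rather than by forming the inverse explicitly; this is a numerical convenience and does not change the formula.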

Let us define \[\begin{align*} &\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))=\nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) - \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\]

If we can establish the Lipschitz continuity of \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\) with respect to its second argument and the Lipschitz continuity of \(\mathbf{u}^*(\mathbf{w})\), then a similar compositional gradient estimation technique can be leveraged. Let \(\mathbf{u}_t\) be an estimate that tracks \(\mathbf{u}^*(\mathbf{w}_t)\). It can be updated by SGD: \[\begin{align} \mathbf{u}_{t+1} = \mathbf{u}_t - \gamma_t \nabla_2 g(\mathbf{w}_t, \mathbf{u}_t; \zeta_t). \end{align}\]

With \(\mathbf{u}_t\), the gradient at \(\mathbf{w}_t\) can be estimated by \[\begin{align} \mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top}(\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t))^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align}\] However, another challenge is to handle the Hessian inverse \((\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t))^{-1}\), which itself has a compositional structure. We will discuss three different ways to tackle this challenge. If we have a stochastic estimator of \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) denoted by \(\mathbf{v}_{t}\), then we update the model parameter by: \[\begin{align} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \mathbf{v}_{t}. \end{align}\]

4.5.3.1 Approach 1: The MA Estimator

If the lower-level problem is low-dimensional such that the inverse of the Hessian matrix can be efficiently computed, we can estimate \(\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t)\) by an MA estimator: \[\begin{align*} H_{22,t} = S_{\mu_g}[(1-\beta) H_{22,t-1} + \beta \nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t; \zeta_{2,t})], \end{align*}\] where \(S_{\mu_g}[\cdot]\) is a projection operator that projects a matrix onto the set of matrices whose minimum eigenvalue is lower bounded by \(\mu_g\), and \(\mu_g\) is the lower bound of the eigenvalues of \(\nabla_{22} g(\mathbf{w}, \mathbf{u})\). The projection ensures that \([H_{22,t}]^{-1}\) is Lipschitz continuous with respect to \(H_{22,t}\).
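The MA update of the Hessian estimate can be sketched as follows. This is a minimal illustration under stated assumptions: the projection \(S_{\mu_g}\) is realized by symmetrizing the matrix and clipping its eigenvalues from below at \(\mu_g\), which is one natural choice (the text does not prescribe an implementation), and the names `project_min_eig` and `ma_hessian_update` as well as the noisy diagonal test Hessian are hypothetical.

```python
import numpy as np

def project_min_eig(H, mu):
    # One possible realization of S_mu[H] (an assumption): symmetrize,
    # then raise every eigenvalue below mu up to mu.
    H_sym = 0.5 * (H + H.T)
    eigvals, eigvecs = np.linalg.eigh(H_sym)
    return (eigvecs * np.maximum(eigvals, mu)) @ eigvecs.T

def ma_hessian_update(H_prev, hess_sample, beta, mu):
    # H_t = S_mu[(1 - beta) * H_{t-1} + beta * (stochastic Hessian sample)]
    return project_min_eig((1.0 - beta) * H_prev + beta * hess_sample, mu)

rng = np.random.default_rng(0)
mu_g, beta = 0.5, 0.1
true_hessian = np.diag([0.5, 1.0, 3.0])   # hypothetical lower-level Hessian

H = np.eye(3)
for _ in range(200):
    noise = rng.standard_normal((3, 3))
    sample = true_hessian + 0.5 * (noise + noise.T)   # symmetric noisy sample
    H = ma_hessian_update(H, sample, beta, mu_g)

H_inv = np.linalg.inv(H)
print(np.linalg.eigvalsh(H).min(), np.linalg.norm(H_inv, 2))
```

Because every eigenvalue of `H` is at least `mu_g`, the spectral norm of `H_inv` is at most `1 / mu_g`, which is the property the projection is meant to provide.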

Then a vanilla stochastic gradient estimator at \(\mathbf{w}_t\) and its MA estimator are computed by

\[\begin{equation}\label{eqn:bio-hma} \begin{aligned} &\mathbf{z}_{t} = \nabla_1 f(\mathbf{w}_t, \mathbf{u}_t; \xi_{t}) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t; \zeta'_{2,t})^{\top}(H_{22,t})^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t; \xi_{t}),\\ &\mathbf{v}_{t} =(1-\beta)\mathbf{v}_{t-1} + \beta \mathbf{z}_t. \end{aligned} \end{equation}\]

Convergence Analysis

The proof is largely similar to that of Theorem 4.3. We provide a proof sketch below. Recall that \[\begin{align*} &\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top}(\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t))^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\] Define: \[\begin{align*} \hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top} H_{22,t}^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\]

First, similar to Lemma 4.9, we have the following: \[\begin{align} F(\mathbf{w}_{t+1})& \leq F(\mathbf{w}_t) + \frac{\eta_t}{2} \|\mathbf{v}_t - \nabla F(\mathbf{w}_t)\|_2^2-\frac{\eta_t}{2}\|\nabla F(\mathbf{w}_t)\|_2^2- \frac{1}{4\eta_t} \|\mathbf{w}_{t+1} - \mathbf{w}_t\|_2^2. \end{align}\]

We establish a recursion of the error \(\|\mathbf{v}_t - \nabla F(\mathbf{w}_t)\|_2^2\) similar to Lemma 4.7 by noting that \(\mathbb E_{\xi_t, \zeta'_{2,t}}[\mathbf{z}_t]=\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)\) and there exists \(\sigma>0\) such that \(\mathbb E_{\xi_t, \zeta'_{2,t}}[\|\mathbf{z}_t - \hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)\|_2^2]\leq \sigma^2\). Thus, Lemma 4.7 implies that

\[\begin{align}\label{eqn:bio-v-err} &\mathbb E_{\xi_t, \zeta'_{2,t}}\left[\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\|_2^2\right] \leq (1-\beta_t)\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\|_2^2 \\ & \quad + \frac{2L_F^2}{\beta_t}\|\mathbf{w}_{t-1}-\mathbf{w}_t\|_2^2 + 4\beta_t\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|_2^2 + \beta_t^2\sigma^2.\notag \end{align}\]

Then, we bound \(\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)-\nabla F(\mathbf{w}_t)\|_2^2\) by \[\begin{align*} \|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)-\nabla F(\mathbf{w}_t)\|_2^2&\leq 2\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) -\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) \|^2_2\\ & +2\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|^2_2\\ & \leq O(\|H_{22,t} - \nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\|_2^2) + O(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2). \end{align*}\]

As a result, we have \[\begin{align*} \mathbb E[\|\mathbf{v}_t - &\nabla F(\mathbf{w}_t)\|_2^2] \leq (1-\beta_t)\|\mathbf{v}_{t-1}- \nabla F(\mathbf{w}_{t-1})\|^2 + \frac{2L_F^2}{\beta_t}\|\mathbf{w}_t - \mathbf{w}_{t-1}\|_2^2\\ &+ \beta_t (O(\|H_{22,t} - \nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\|_2^2) + O(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2)) + \beta_t^2O(\sigma^2). \end{align*}\] This result is similar to that in Lemma 4.8.

We can further build the error recursion of \(\|H_{22,t} - \nabla_{22}g(\mathbf{w}_t,\mathbf{u}_t)\|_2^2\) similar to Lemma 4.1, and the error recursion of \(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2\) similar to Lemma 4.20. Combining these results, we can establish a complexity of \(O(1/\epsilon^4)\) for finding an \(\epsilon\)-stationary solution of \(F(\cdot)\) in expectation.

4.5.3.2 Approach 2: The Neumann Series (Matrix Taylor Approximation)

If the lower-level problem is high-dimensional such that it is prohibitive to form and invert the Hessian, one approach is to leverage the Neumann series: \[\begin{align} A^{-1}= \sum_{i=0}^\infty (I -A)^i, \quad \text{ if }\quad A\succ 0 \text{ and } \|A\|\leq 1. \end{align}\] Hence, if \(\|\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\|\leq L_{22}\), we estimate the inverse of \(\frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\), yielding \[\begin{align} \left(\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^{-1}\approx \frac{1}{L_{22}}\sum_{i=0}^{K-1} \left(I - \frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^i. \end{align}\] This can be further estimated stochastically: we sample \(k\) uniformly at random from \(\{0,\ldots,K- 1\}\) and independent samples \(\zeta_1,\ldots,\zeta_k\), and then estimate the Hessian inverse by \[\begin{align} Q_{22,t} = \left\{\begin{array}{ll}\frac{K}{L_{22}}\prod_{i=1}^{k} \left(I - \frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t; \zeta_{i})\right)& \text{ if } k\geq 1\\ \frac{K}{L_{22}}I & \text{ if } k=0\end{array}\right.. \end{align}\] This can be justified by \[\begin{align*} \mathbb{E}[Q_{22,t}] &= \frac{1}{K}\frac{K}{L_{22}}I + \frac{K-1}{K}\mathbb{E}_{k\sim\{1,\ldots, K-1\}}\left[\frac{K}{L_{22}}\prod_{i=1}^{k} \left(I - \frac{1}{L_{22}}\mathbb{E}[\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t; \zeta_{i})]\right)\right]\\ & = \mathbb{E}_{k}\frac{K}{L_{22}}\left(I - \frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^k = \sum_{k=0}^{K-1}\frac{1}{L_{22}}\left(I -\frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^k. \end{align*}\]
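The truncated Neumann series and the randomized estimator \(Q_{22,t}\) can be sketched as follows. This is an illustrative sketch under stated assumptions: the test matrix `H`, the noise model in `noisy_hessian`, and the function names are hypothetical. With noiseless Hessian samples, averaging the estimator over all values of \(k\) reproduces the truncated series exactly, which is what the expectation computation above states.

```python
import numpy as np

def neumann_truncated(H, L, K):
    # (1 / L) * sum_{i=0}^{K-1} (I - H / L)^i
    n = H.shape[0]
    R = np.eye(n) - H / L
    term = np.eye(n)
    total = np.zeros((n, n))
    for _ in range(K):
        total += term
        term = term @ R
    return total / L

def neumann_q(sample_hessian, k, n, L, K):
    # One draw of Q for a given k: (K / L) * product of k factors (I - H_i / L),
    # where each H_i is a fresh call to sample_hessian(); k = 0 gives (K / L) * I.
    Q = np.eye(n)
    for _ in range(k):
        Q = Q @ (np.eye(n) - sample_hessian() / L)
    return (K / L) * Q

rng = np.random.default_rng(0)
n, K = 4, 50
M = rng.standard_normal((n, n))
H = M @ M.T + np.eye(n)                  # hypothetical positive definite Hessian
eigs = np.linalg.eigvalsh(H)             # ascending order
mu, L = eigs[0], eigs[-1]

# Truncation error versus the geometric bound (1 / mu) * (1 - mu / L)^K.
truncation_error = np.linalg.norm(np.linalg.inv(H) - neumann_truncated(H, L, K), 2)
bound = (1.0 / mu) * (1.0 - mu / L) ** K

# Exact average over k = 0, ..., K-1 when every Hessian sample is noiseless.
mean_Q = sum(neumann_q(lambda: H, k, n, L, K) for k in range(K)) / K

def noisy_hessian():
    # Assumed noise model: symmetric Gaussian perturbation of H.
    E = rng.standard_normal((n, n))
    return H + 0.05 * (E + E.T)

k = int(rng.integers(0, K))              # uniform on {0, ..., K-1}
Q_sample = neumann_q(noisy_hessian, k, n, L, K)
print(truncation_error, bound)
```

In a high-dimensional setting one would typically apply the factors to the vector \(\nabla_2 f\) through Hessian-vector products rather than forming matrices; the dense version here is only meant to make the algebra visible.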

Then the vanilla gradient estimator of \(\mathbf{w}_t\) and its MA estimator are computed by

\[\begin{equation}\label{eqn:bio-neu} \begin{aligned} &\mathbf{z}_{t} = \nabla_1 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_{1,t}) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t; \zeta'_{2,t})^{\top}Q_{22,t}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_{1,t}),\\ &\mathbf{v}_{t} =(1-\beta)\mathbf{v}_{t-1} + \beta \mathbf{z}_t. \end{aligned} \end{equation}\]

Convergence Analysis

We provide a proof sketch below. Note that \(\mathbf{z}_t\) is an unbiased stochastic estimator of \[\begin{align*} \hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top}\mathbb{E}[Q_{22,t}]\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\] We decompose the estimation error of \(\mathbf{v}_t\) similarly as in (\(\ref{eqn:bio-v-err}\)) and bound \(\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|_2^2\) by \[\begin{align*} \|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|_2^2 &\leq 2\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) -\nabla F(\mathbf{w}_t)\|_2^2 \\ & + 2\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \mathcal{M}(\mathbf{w}_t,\mathbf{u}_t)\|_2^2 . \end{align*}\] The first term on the right-hand side can be bounded as before. To bound the last term, since \[\begin{equation*} \left[\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t)\right]^{-1} = \mathbb{E}[Q_{22,t}] + \frac{1}{L_{22}} \sum_{i=K}^{\infty} \left[I - \frac{1}{L_{22}} \nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t)\right]^i, \end{equation*}\] we have \[\begin{align*} &\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \hat{\mathcal{M}}(\mathbf{w}_t,\mathbf{u}_t)\|_2^2\leq O(\|\left[\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t)\right]^{-1} - \mathbb{E}[Q_{22,t}]\|_2^2),\\ &\left\| \left[\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t)\right]^{-1} - \mathbb{E}[Q_{22,t}] \right\|_2 \le \frac{1}{L_{22}} \sum_{i=K}^{\infty} \left\|I - \frac{1}{L_{22}} \nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t)\right\|_2^i \le \frac{1}{\mu_g} \left(1 - \frac{\mu_g}{L_{22}}\right)^K. \end{align*}\] As a result, if \(K=O(\frac{L_{22}}{\mu_g}\log(1/(\mu_g\beta_t\sigma^2)))\), then \(\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \hat{\mathcal{M}}(\mathbf{w}_t,\mathbf{u}_t)\|_2^2\leq O(\beta_t\sigma^2)\). 
Then similar to the analysis of Approach 1, we can establish a complexity of \(O(1/\epsilon^4)\) for finding an \(\epsilon\)-stationary solution of \(F(\cdot)\) in expectation.

4.5.3.3 Approach 3: The Penalty Method

An alternative approach to avoid computing the Hessian inverse and Jacobian matrices is to reformulate the problem as a constrained optimization problem: \[\begin{align*} \min_{\mathbf{w},\mathbf{u}} &\quad f(\mathbf{w}, \mathbf{u}) \\ \text{s.t.} &\quad g(\mathbf{w}, \mathbf{u}) \leq \min_{\mathbf{u}} g(\mathbf{w}, \mathbf{u}). \end{align*}\] This constrained problem can be addressed using a penalty method (see Section 6.7): \[\begin{align*} \min_{\mathbf{w},\mathbf{u}} \; f(\mathbf{w}, \mathbf{u}) + \lambda \big( g(\mathbf{w}, \mathbf{u}) - \min_{\mathbf{y}} g(\mathbf{w}, \mathbf{y}) \big)_+, \end{align*}\] where \(\lambda > 0\) is a penalty parameter and \((\cdot)_+\) denotes the positive part. Since \(g(\mathbf{w}, \mathbf{u}) \geq \min_{\mathbf{y}} g(\mathbf{w}, \mathbf{y})\), the formulation simplifies to:

\[\begin{align}\label{eqn:bio-minmax} &\min_{\mathbf{w},\mathbf{u}} \; f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - \min_{\mathbf{y}} g(\mathbf{w}, \mathbf{y}) \right)\\ &= \min_{\mathbf{w},\mathbf{u}} \max_{\mathbf{y}} f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right). \end{align}\]

If both \(f\) and \(g\) are smooth and \(g\) is strongly convex in its second argument, the resulting formulation becomes a non-convex strongly-concave min-max problem, which can be effectively addressed using the SMDA algorithm with the following update for \(t\geq 1\):

\[\begin{equation}\label{eqn:bi-smda} \begin{aligned} &\mathbf{y}_{t+1} = \mathbf{y}_{t} - \gamma_t\lambda\nabla_2 g(\mathbf{w}_t, \mathbf{y}_{t}; \xi_t),\\ &\mathbf{z}_t = \nabla f(\mathbf{w}_t, \mathbf{u}_{t}; \zeta_t) + \lambda \left(\nabla g(\mathbf{w}_t, \mathbf{u}_{t}; \xi_t) - \left[\nabla_1 g(\mathbf{w}_t, \mathbf{y}_{t}; \xi_t)\atop 0\right]\right),\\ &\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1} + \beta_t \mathbf{z}_t, \\ &\left[\mathbf{w}_{t+1}\atop \mathbf{u}_{t+1}\right] = \left[\mathbf{w}_t\atop \mathbf{u}_{t}\right] - \eta_t \mathbf{v}_{t}. \end{aligned} \end{equation}\]
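The update can be sketched on a scalar toy problem as follows. This is a minimal sketch under stated assumptions, not the author's implementation: the instance \(f(w,u)=\frac{1}{2}(u-1)^2+\frac{1}{2}w^2\), \(g(w,u)=\frac{1}{2}u^2-uw\), the function names, and the step sizes are all hypothetical, exact gradients stand in for stochastic ones, and \(\beta_t=1\) is used so that the run is deterministic. The \(\mathbf{y}\) step is written as gradient ascent on the min-max objective in \(\mathbf{y}\), which is a descent step on \(g(\mathbf{w}_t,\cdot)\).

```python
import numpy as np

def smda_penalty_step(w, u, y, v, grad_f, grad_g, lam, eta, gamma, beta):
    # All gradients are evaluated at the current iterates (w_t, u_t, y_t).
    fw, fu = grad_f(w, u)
    gw_u, gu_u = grad_g(w, u)
    gw_y, gu_y = grad_g(w, y)
    # z_t: gradient of f + lam * (g(w, u) - g(w, y)) with respect to (w, u).
    z = np.concatenate([fw + lam * (gw_u - gw_y), fu + lam * gu_u])
    # Ascent on the min-max objective in y, i.e. descent on g(w_t, .).
    y_next = y - gamma * lam * gu_y
    v_next = (1.0 - beta) * v + beta * z
    d = w.shape[0]
    return w - eta * v_next[:d], u - eta * v_next[d:], y_next, v_next

# Hypothetical toy instance: f = 0.5 * (u - 1)^2 + 0.5 * w^2, g = 0.5 * u^2 - u * w.
def grad_f(w, u):
    return w, u - 1.0

def grad_g(w, u):
    return -u, u - w

lam, eta, gamma, beta = 10.0, 0.01, 0.05, 1.0
w, u, y = np.zeros(1), np.zeros(1), np.zeros(1)
v = np.zeros(2)
for _ in range(3000):
    w, u, y, v = smda_penalty_step(w, u, y, v, grad_f, grad_g, lam, eta, gamma, beta)

print(w[0], u[0], y[0])
```

For this instance the penalized problem has the stationary point \(w=\lambda/(1+2\lambda)\), \(u=(1+\lambda)/(1+2\lambda)\), while the original bilevel objective \(F(w)=\frac{1}{2}(w-1)^2+\frac{1}{2}w^2\) is minimized at \(w=1/2\), so the gap shrinks at rate \(O(1/\lambda)\).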

Convergence Analysis

The convergence analysis of (\(\ref{eqn:bi-smda}\)) for the min–max problem (\(\ref{eqn:bio-minmax}\)) follows a similar approach to that of Theorem 4.5 for SMDA. However, a remaining challenge lies in converting the convergence result for the min–max formulation into that of the original problem. To address this, we provide the detailed convergence analysis below. We begin by stating the following assumption.

Assumption 4.9

Regarding the problem (\(\ref{eqn:bi}\)), the following conditions hold:

\[\begin{align} \|\nabla f(\mathbf{w}_1, \mathbf{u}_1) - \nabla f(\mathbf{w}_2, \mathbf{u}_2)\|_2 \leq L_f\left\|\left(\mathbf{w}_1\atop \mathbf{u}_1\right) - \left(\mathbf{w}_2\atop \mathbf{u}_2\right)\right\|_2. \end{align}\]

\[\begin{align} \|\nabla g(\mathbf{w}_1, \mathbf{u}_1) - \nabla g(\mathbf{w}_2, \mathbf{u}_2)\|_2 \leq L_g\left\|\left(\mathbf{w}_1\atop \mathbf{u}_1\right) - \left(\mathbf{w}_2\atop \mathbf{u}_2\right)\right\|_2. \end{align}\]

\[\begin{align} &\mathbb{E}[\|\nabla f(\mathbf{w}, \mathbf{u}; \zeta)- \nabla f(\mathbf{w}, \mathbf{u})\|_2^2]\leq \sigma_f^2,\\ &\mathbb{E}[\|\nabla g(\mathbf{w}, \mathbf{u}; \xi)- \nabla g(\mathbf{w}, \mathbf{u})\|_2^2]\leq \sigma_g^2. \end{align}\]

Let us define \(\bar{\mathbf{w}} : = (\mathbf{w}, \mathbf{u})\) and

\[\begin{align} &\bar f(\bar{\mathbf{w}}, \mathbf{y}): = f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right),\\ &\bar F(\bar{\mathbf{w}}) : = \max_{\mathbf{y}}\bar f(\bar{\mathbf{w}}, \mathbf{y}). \end{align}\]

Then \[\begin{align*} &\nabla_1\bar f(\bar{\mathbf{w}}, \mathbf{y})=\nabla f(\mathbf{w}, \mathbf{u}) + \lambda \left(\nabla g(\mathbf{w}, \mathbf{u}) - \left[\nabla_1 g(\mathbf{w}, \mathbf{y})\atop 0\right]\right),\\ &\nabla_2\bar f(\bar{\mathbf{w}}, \mathbf{y})= -\lambda \nabla_2 g(\mathbf{w}, \mathbf{y}),\\ &\nabla_1\bar f(\bar{\mathbf{w}}, \mathbf{y};\varepsilon)=\nabla f(\mathbf{w}, \mathbf{u}; \zeta) + \lambda \left(\nabla g(\mathbf{w}, \mathbf{u};\xi) - \left[\nabla_1 g(\mathbf{w}, \mathbf{y}; \xi)\atop 0\right]\right),\\ &\nabla_2\bar f(\bar{\mathbf{w}}, \mathbf{y};\xi)= -\lambda \nabla_2 g(\mathbf{w}, \mathbf{y};\xi). \end{align*}\] where \(\varepsilon=(\zeta, \xi)\). We first show \(\bar f(\bar{\mathbf{w}}, \mathbf{y})\) satisfies the conditions in Assumption 4.8.

Lemma 4.25

Under Assumption 4.9, we have

Proof.
(i) is obvious. The Lipschitz continuity of \(\nabla_1 \bar f(\bar{\mathbf{w}}, \mathbf{y})\) follows from that of \(\nabla f(\mathbf{w}, \mathbf{u})\) and \(\nabla g(\mathbf{w}, \mathbf{u})\). For (iii), we have \[\begin{align*} &\|\nabla_2\bar f(\bar{\mathbf{w}}_1, \mathbf{y}_1)- \nabla_2\bar f(\bar{\mathbf{w}}_2, \mathbf{y}_2)\|_2= \lambda \|\nabla_2 g(\mathbf{w}_1, \mathbf{y}_1)-\nabla_2 g(\mathbf{w}_2, \mathbf{y}_2)\|_2\\ &\leq \lambda \|\nabla g(\mathbf{w}_1, \mathbf{y}_1)-\nabla g(\mathbf{w}_2, \mathbf{y}_2)\|_2 \leq \lambda L_g\left\|\left(\mathbf{w}_1\atop \mathbf{y}_1\right) - \left(\mathbf{w}_2\atop \mathbf{y}_2\right)\right\|_2\\ &\leq \lambda L_g(\|\mathbf{w}_1 - \mathbf{w}_2\|_2 + \|\mathbf{y}_1 -\mathbf{y}_2\|_2)\leq \lambda L_g(\|\bar{\mathbf{w}}_1 - \bar{\mathbf{w}}_2\|_2 + \|\mathbf{y}_1 -\mathbf{y}_2\|_2). \end{align*}\] It is trivial to prove (iv). The last result follows from \(\max_{\mathbf{y}}\bar f(\bar{\mathbf{w}}, \mathbf{y})\geq f(\mathbf{w}, \mathbf{u})> -\infty\).

Theorem 4.6

Suppose Assumption 4.9 holds. By setting \[\begin{align*} &\beta_t=\beta= \frac{\epsilon^2}{9\sigma_f^2 + 18\lambda^2\sigma_g^2}, \\ &\gamma_t=\gamma=\frac{\mu_g\epsilon^2}{96(L_f+2L_g\lambda)^2\lambda\sigma_g^2},\\ &\eta_t= \\ & \quad\min\left\{\frac{\beta}{\sqrt{8}(L_f+2L_g\lambda)(1+L_g/\mu_g)},\frac{\gamma\mu_g\lambda}{16\sqrt{3}(L_f + 2L_g\lambda)L_g},\frac{1}{2(L_f+2L_g\lambda)(1+L_g/\mu_g)}\right\} \end{align*}\] in (\(\ref{eqn:bi-smda}\)), the following holds

\[\begin{align}\label{eq:smda-c} \mathbb{E}\left[\frac{1}{T}\sum_{t=0}^{T-1}\left\{\frac{1}{4}\|\mathbf{v}_{t}\|_2^2 + \|\nabla \bar F(\bar{\mathbf{w}}_t)\|_2^2\right\}\right]\leq \epsilon^2, \end{align}\]

with an iteration complexity of

\[\begin{align}\label{eqn:bi-smda-t} T&=O\left(\max\left\{\frac{C_\Upsilon \lambda}{\epsilon^2}, \frac{C_\Upsilon(\lambda\sigma_f^2+\lambda^3\sigma_g^2)}{\epsilon^4}, \frac{C_\Upsilon \lambda^3\sigma_g^2}{\epsilon^4\mu_g^2}\right\}\right), \end{align}\] where \(C_\Upsilon= 2(\bar F(\bar{\mathbf{w}}_0) - \min_{\bar{\mathbf{w}}}\bar F(\bar{\mathbf{w}})) + \frac{1}{\sqrt{8}\bar L_F}\|\mathbf{v}_0 - \nabla \bar F(\bar{\mathbf{w}}_0)\|_2^2 + \frac{L_1}{\sqrt{3}\kappa}\|\mathbf{y}_0 - \mathbf{y}^*(\mathbf{w}_0)\|_2^2\) and \(\kappa = L_g/\mu_g\) .

Proof.
We map the problem into the setting in Theorem 4.5 with \(L_1=L_f + 2L_g\lambda\), \(L_{21}=L_g\lambda\), \(L_{22}=L_g\lambda\), \(\mu=\mu_g\lambda\), \(\kappa = L_{21}/(\mu_g\lambda)=L_g/\mu_g\), \(L_F = L_1(1+\kappa)=(L_f+2L_g\lambda)(1+L_g/\mu_g)\), \(\sigma_1^2 = 3\sigma_f^2 + 6\lambda^2\sigma_g^2\), \(\sigma_2^2 = \lambda^2\sigma_g^2\). Then, substituting these values into the result of Theorem 4.5, we obtain the desired conclusion, retaining only the dependence on \(\lambda\), as it will be set to a sufficiently large value.

Convergence of the original function

Next, we derive the convergence of the original function in terms of \(\|\nabla F(\mathbf{w})\|_2\). We need the following additional assumption.

Assumption 4.10

(i) \(g\) is twice differentiable and \(\nabla_{21}g(\mathbf{w}, \mathbf{u})\) and \(\nabla_{22} g(\mathbf{w}, \mathbf{u})\) are \(L_{gg}\)-Lipschitz continuous; and (ii) \(\|\nabla_2 f(\mathbf{w}, \mathbf{u})\|_2\leq G_f\).

Lemma 4.26

Let \(\mathbf{u}^*_{\lambda}(\mathbf{w})=\arg\min_{\mathbf{u}}\bar F(\mathbf{w}, \mathbf{u})\), \(\mathbf{u}^*(\mathbf{w})=\arg\min_{\mathbf{u}}g(\mathbf{w}, \mathbf{u})\). Under Assumption 4.10(i), we have \[\begin{align*} \|\nabla F(\mathbf{w}) - \nabla_1\bar F(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))\|_2&\leq L_f(1+\frac{L_g}{\mu_g})\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2 \\ &+ L_{gg}\lambda(1+\frac{L_g}{\mu_g}) \|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|^2_2. \end{align*}\]

Proof.
Let \(\mathbf{u}^* = \mathbf{u}^*(\mathbf{w})\). Then, \[\begin{align*} \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}) &= \nabla_{1} f(\mathbf{w}, \mathbf{u}) + \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*)) \\ \nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) &= \nabla_{2} f(\mathbf{w}, \mathbf{u}) + \lambda \nabla_{2} g(\mathbf{w}, \mathbf{u}). \end{align*}\] Due to Lemma 4.24, we have

\[\begin{equation} \label{eq:31_new} \begin{split} &\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}) = \nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u}) \\ & -\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} f(\mathbf{w}, \mathbf{u}^*) - \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*)). \end{split} \end{equation}\]

We can rearrange terms for \((\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*))\) as the following:

\[\begin{equation} \label{eq:32_new} \begin{split} \nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) &= \nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*) \\ & + \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*). \end{split} \end{equation}\]

To continue, we leverage the following equality: \[\begin{equation*} \begin{split} \mathbf{u} - \mathbf{u}^* =& -\nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) \\ & +\nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*)). \end{split} \end{equation*}\] By the optimality condition for \(\mathbf{u}^*\), \(\nabla_{2} g(\mathbf{w}, \mathbf{u}^*) = 0\), and \(\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u})=\nabla_{2} f(\mathbf{w}, \mathbf{u}) + \lambda \nabla_{2} g(\mathbf{w}, \mathbf{u})\), we can express \(\mathbf{u} - \mathbf{u}^*\) as

\[\begin{equation}\label{eq:33_new} \begin{split} \mathbf{u} - \mathbf{u}^* &= -\nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) \\ & + \frac{1}{\lambda} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) - \nabla_{2} f(\mathbf{w}, \mathbf{u})). \end{split} \end{equation}\]

Combining (\(\ref{eq:33_new}\)) and (\(\ref{eq:32_new}\)) and then plugging the result back to (\(\ref{eq:31_new}\)), we have \[\begin{equation*} \begin{split} &\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}) = \nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u}) \\ & -\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} f(\mathbf{w}, \mathbf{u}^*) \\ &- \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*))\\ & +\lambda \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) \\ & - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) - \nabla_{2} f(\mathbf{w}, \mathbf{u})). \end{split} \end{equation*}\]

As a result, we have \[\begin{equation*} \begin{split} &\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u})+ \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) \\ & = \nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u}) \\ & -\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{2} f(\mathbf{w}, \mathbf{u})) \\ &- \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*))\\ & +\lambda \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) . \end{split} \end{equation*}\]

By Assumption 4.10 we have \[\begin{align*} &\|\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*)\|_2\leq L_{gg}\|\mathbf{u} - \mathbf{u}^*\|^2_2,\\ &\|\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)\|_2 \le L_{gg} \|\mathbf{u} - \mathbf{u}^*\|^2_2. \end{align*}\] By Assumption 4.9 and the \(\mu_g\)-strong convexity of \(g\) in \(\mathbf{u}\), we have \[\begin{align*} &\|\nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u})\|_2\leq L_f\|\mathbf{u}^*-\mathbf{u}\|_2,\\ &\|\nabla_{2} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{2} f(\mathbf{w}, \mathbf{u})\|_2\leq L_f\|\mathbf{u}^*-\mathbf{u}\|_2,\\ &\|\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1}\|_2\leq \frac{L_g}{\mu_g}. \end{align*}\] Thus, we have \[\begin{equation*} \begin{split} &\|\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u})+ \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u})\|_2 \\ & \leq L_f(1+\frac{L_g}{\mu_g})\|\mathbf{u} - \mathbf{u}^*\|_2 + L_{gg}\lambda(1+\frac{L_g}{\mu_g}) \|\mathbf{u} - \mathbf{u}^*\|^2_2. \end{split} \end{equation*}\] Plugging in \(\mathbf{u}=\mathbf{u}^*_\lambda(\mathbf{w}) = \arg\min_{\mathbf{u}}\bar F(\mathbf{w}, \mathbf{u})\), for which \(\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}) ) =0\), we have \[\begin{equation*} \begin{split} &\|\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))\|_2 \\ & \leq L_f(1+\frac{L_g}{\mu_g})\|\mathbf{u}^*_\lambda(\mathbf{w}) - \mathbf{u}^*\|_2 + L_{gg}\lambda(1+\frac{L_g}{\mu_g}) \|\mathbf{u}^*_\lambda(\mathbf{w}) - \mathbf{u}^*\|^2_2. \end{split} \end{equation*}\]

Next, we bound \(\|\mathbf{u}^*_\lambda(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2\).

Lemma 4.27

Under Assumption 4.10(ii), we have \(\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2 \le \frac{G_f}{\lambda \mu_g}\).

Proof.
By the definitions of \(\mathbf{u}^*_\lambda(\mathbf{w})\), \(\mathbf{u}^*(\mathbf{w})\), we have \[\begin{align*} &\mathbf{u}^*_\lambda(\mathbf{w}) = \arg\min_{\mathbf{u}} \frac{1}{\lambda}f(\mathbf{w}, \mathbf{u}) + g(\mathbf{w}, \mathbf{u}), \\ &\mathbf{u}^*(\mathbf{w}) = \arg\min_{\mathbf{u}} g(\mathbf{w}, \mathbf{u}). \end{align*}\] By the optimality condition, \[\begin{align*} &\frac{1}{\lambda}\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w})) + \nabla_2 g(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))=0,\\ &\nabla_2 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))=0. \end{align*}\] Since \(g(\mathbf{w}, \mathbf{u})\) is \(\mu_g\)-strongly convex w.r.t \(\mathbf{u}\) for any \(\mathbf{w}\), then we have \[\begin{align*} g(\mathbf{w},\mathbf{u}^*_{\lambda}(\mathbf{w}))\geq & g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) + \nabla_2 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})) \\ & + \frac{\mu_g}{2}\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2^2,\\ g(\mathbf{w},\mathbf{u}^*(\mathbf{w}))\geq & g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w})) + \nabla_2 g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))^{\top}(\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w})) \\ & + \frac{\mu_g}{2}\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2^2. \end{align*}\] Adding these two inequalities yields: \[\begin{align*} \mu_g\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2^2&\leq -\nabla_2 g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))^{\top}(\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w}))\\ & =\frac{1}{\lambda}\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))^{\top}(\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w}))\\ &\leq \frac{1}{\lambda}\|\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))\|_2\|(\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w}))\|_2. 
\end{align*}\] Dividing both sides by \(\|\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w})\|_2\) and noting \(\|\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))\|_2\leq G_f\) concludes the proof.
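The bound of Lemma 4.27 can be checked numerically on a simple instance. The sketch below is illustrative only and rests on stated assumptions: for a fixed \(\mathbf{w}\) it takes a hypothetical quadratic lower-level function \(g(\mathbf{u})=\frac{1}{2}\mathbf{u}^{\top}A\mathbf{u}-\mathbf{b}^{\top}\mathbf{u}\) and a linear upper-level function \(f(\mathbf{u})=\mathbf{c}^{\top}\mathbf{u}\), so that \(\nabla_2 f=\mathbf{c}\) and \(G_f=\|\mathbf{c}\|_2\); both minimizers are then available in closed form.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# Hypothetical instance for a fixed w (an assumption for illustration):
#   g(u) = 0.5 * u^T A u - b^T u,   f(u) = c^T u
M = rng.standard_normal((n, n))
A = M @ M.T + np.eye(n)                  # Hessian of g in u
b = rng.standard_normal(n)
c = rng.standard_normal(n)

mu_g = np.linalg.eigvalsh(A)[0]          # strong convexity modulus of g
G_f = np.linalg.norm(c)                  # bound on ||grad_2 f||

u_star = np.linalg.solve(A, b)           # argmin_u g(u)

results = []
for lam in [1.0, 10.0, 100.0]:
    # Optimality condition of (1 / lam) * f + g:  c / lam + A u - b = 0.
    u_lam = np.linalg.solve(A, b - c / lam)
    gap = np.linalg.norm(u_lam - u_star)
    bound = G_f / (lam * mu_g)
    results.append((lam, gap, bound))
    print(lam, gap, bound)
```

In this instance the gap equals \(\|A^{-1}\mathbf{c}\|_2/\lambda\), so it decays exactly like \(1/\lambda\), in line with the lemma.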

Corollary 4.2

Under the same setting as in Theorem 4.6 with \(\lambda=O(\frac{1}{\epsilon})>2L_f/\mu_g\), and assuming that Assumption 4.10 holds and \(\|\mathbf{y}_0-\mathbf{y}^*(\mathbf{w}_0)\|^2_2\leq O(\epsilon)\), the following holds

\[\begin{align} \mathbb{E}\left[\|\nabla F(\mathbf{w}_\tau)\|_2\right]\leq O(\epsilon), \end{align}\]

with an iteration complexity of

\[\begin{align}\label{eq:bi-smda-t} T&=O\left(\max\left\{\frac{1}{\epsilon^3}, \frac{\sigma_f^2}{\epsilon^5}, \frac{\sigma_g^2}{\epsilon^7}\right\}\right), \end{align}\] where \(\tau\in\{0,\ldots, T-1\}\) is randomly sampled.

Proof.
Combining Lemma 4.26 and Lemma 4.27, we have \[\begin{align*} \|\nabla F(\mathbf{w}_\tau)\|_2 & \leq \|\nabla F(\mathbf{w}_\tau) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau))\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau))\|_2\\ &\leq L_f(1+\frac{L_g}{\mu_g})\frac{G_f}{\mu_g\lambda} + L_{gg}\lambda(1+\frac{L_g}{\mu_g})\frac{G_f^2}{\mu_g^2\lambda^2}\\ & + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau)) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2\\ &\leq \frac{2L_fL_gG_f}{\mu_g^2\lambda} + \frac{2L_{gg}L_gG_f^2}{\mu_g^3\lambda}\\ & + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau)) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2, \end{align*}\] where the first inequality is the triangle inequality and the last one uses \(1+L_g/\mu_g\leq 2L_g/\mu_g\). Since \(\bar F(\mathbf{w}, \mathbf{u})\) is \((\lambda\mu_g - L_f)\)-strongly convex w.r.t. \(\mathbf{u}\), Lemma 1.6(c) implies that \[\begin{align*} (\lambda\mu_g - L_f)&\|\mathbf{u}^*_{\lambda}(\mathbf{w}_\tau) - \mathbf{u}_\tau\|_2^2\leq \frac{1}{(\lambda\mu_g - L_f)}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau) - \nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau))\|_2^2\\ & = \frac{1}{(\lambda\mu_g - L_f)}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau) \|^2_2. 
\end{align*}\] Due to \(\nabla_1 \bar F(\mathbf{w}, \mathbf{u}) = \nabla_1 f(\mathbf{w}, \mathbf{u}) + \lambda ( \nabla_1 g(\mathbf{w}, \mathbf{u}) - \nabla_1 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))\), we have \[\begin{align*} &\|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau)) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 \leq (L_f + \lambda L_g)\|\mathbf{u}^*_{\lambda}(\mathbf{w}_\tau) - \mathbf{u}_\tau\|_2\\ &\leq \frac{(L_f + \lambda L_g)}{(\lambda\mu_g - L_f)}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2\\ &\leq \frac{2(\lambda\mu_g/2+\lambda L_g)}{\lambda\mu_g}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2=\frac{\mu_g+2L_g}{\mu_g}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 \end{align*}\] where the last inequality uses \(L_f\leq \lambda\mu_g/2\). Combining the above inequalities, we obtain \[\begin{align*} \|\nabla F(\mathbf{w}_\tau)\|_2 &\leq \frac{2L_fL_gG_f}{\mu_g^2\lambda} + \frac{2L_{gg}L_gG_f^2}{\mu_g^3\lambda} \\ &+ \frac{\mu_g+ 2L_g}{\mu_g}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2. \end{align*}\] From Theorem 4.6, we have \[\begin{align*} \mathbb{E}\!\left[\|\nabla_2 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2^2 + \|\nabla_1 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2^2\right] \le \epsilon^2. \end{align*}\] Hence, it follows that \(\mathbb{E}[\|\nabla_2 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2] \le \epsilon\) and \(\mathbb{E}[\|\nabla_1 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2] \le \epsilon\). If \(\lambda = O(1/\epsilon)\), then \(\mathbb{E}[\|\nabla F(\mathbf{w}_\tau)\|_2] \le O(\epsilon)\). The iteration complexity can be established by substituting \(\lambda = O(1/\epsilon)\) into Theorem 4.6 and noting that \(C_\Upsilon = O(1)\) when \(\|\mathbf{y}_0 - \mathbf{y}^*(\mathbf{w}_0)\|_2^2 \le O(\epsilon)\).

Critical

The complexity of \(O(1/\epsilon^7)\) is not the state-of-the-art sample complexity achievable under the same assumptions. Indeed, a double-loop large-batch method (similar to the one presented in Section 4.5.1.1 for solving the min-max problem \(\min_{\mathbf{w},\mathbf{u}} \max_{\mathbf{y}} f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right)\)) can yield a superior sample complexity of \(O(1/\epsilon^6)\) for achieving the stationarity condition \(\mathbb{E}[\|\nabla F(\mathbf{w})\|_2]\leq\epsilon\).

To see this, we apply a similar argument as in Section 4.5.1.1. Let \(F_\lambda(\mathbf{w}):=\min_{\mathbf{u}} \max_{\mathbf{y}} f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right)\). Then \[ \nabla F_\lambda(\mathbf{w})=\nabla_1 f(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w})) + \lambda \left(\nabla_1 g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w})) - \nabla_1 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\right). \]
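As a sanity check on this formula, the following minimal sketch evaluates \(\nabla F_\lambda\) in closed form on a toy quadratic instance and compares it with the hypergradient of Lemma 4.24. The instance, \(g(\mathbf{w},\mathbf{u}) = \frac{1}{2}\mathbf{u}^\top A\mathbf{u} - \mathbf{u}^\top K\mathbf{w}\) and \(f(\mathbf{w},\mathbf{u}) = \frac{1}{2}\|\mathbf{u}-\mathbf{c}\|_2^2 + \frac{\rho}{2}\|\mathbf{w}\|_2^2\), and all names in the code (`A`, `K`, `c`, `rho`) are assumptions made for illustration and do not come from the text.

```python
import numpy as np

# Toy instance (an assumption for illustration, not from the text):
#   g(w, u) = 0.5 * u^T A u - u^T K w
#   f(w, u) = 0.5 * ||u - c||^2 + 0.5 * rho * ||w||^2
rng = np.random.default_rng(0)
d, dp = 3, 4
M = rng.standard_normal((dp, dp))
A = M @ M.T / dp + np.eye(dp)           # symmetric with eigenvalues >= 1, so mu_g >= 1
K = 0.5 * rng.standard_normal((dp, d))  # coupling between w and u
c = rng.standard_normal(dp)
rho = 1.0

def u_star(w):
    # Lower-level solution u*(w): solves A u = K w
    return np.linalg.solve(A, K @ w)

def u_star_lam(w, lam):
    # Minimizer of f(w, u) + lam * g(w, u): solves (I + lam A) u = c + lam K w
    return np.linalg.solve(np.eye(dp) + lam * A, c + lam * (K @ w))

def grad_F(w):
    # Lemma 4.24 with grad_21 g = -K, grad_22 g = A, grad_2 f = u - c
    return rho * w + K.T @ np.linalg.solve(A, u_star(w) - c)

def grad_F_lam(w, lam):
    # grad_1 f + lam * (grad_1 g(w, u_lam) - grad_1 g(w, u*)), where grad_1 g(w, u) = -K^T u
    return rho * w - lam * (K.T @ (u_star_lam(w, lam) - u_star(w)))

w = rng.standard_normal(d)
lams = [1e1, 1e2, 1e3, 1e4]
errs = [np.linalg.norm(grad_F_lam(w, lam) - grad_F(w)) for lam in lams]
for lam, err in zip(lams, errs):
    print(f"lambda = {lam:8.0f}   gap = {err:.3e}   lambda * gap = {lam * err:.3e}")
```

On this instance the product \(\lambda\,\|\nabla F_\lambda(\mathbf{w}) - \nabla F(\mathbf{w})\|_2\) should stay roughly constant as \(\lambda\) grows, which is the \(O(1/\lambda)\) behavior the analysis relies on.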

Given \(\mathbf{w}_t\), let \(\mathbf{u}_t, \mathbf{u}_{\lambda, t}\) denote approximations of \(\mathbf{u}^*(\mathbf{w}_t), \mathbf{u}^*_\lambda(\mathbf{w}_t)\) and let \[\begin{align*} &\mathbf{v}_t = \frac{1}{B}\sum_{i=1}^B\nabla_1 f(\mathbf{w}_t, \mathbf{u}_{\lambda,t}; \zeta_i) + \lambda \frac{1}{B}\sum_{i=1}^B\left(\nabla_1 g(\mathbf{w}_t, \mathbf{u}_{\lambda,t}; \xi_i) - \nabla_1 g(\mathbf{w}_t, \mathbf{u}_t; \xi_i)\right),\\ &\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \mathbf{v}_t. \end{align*}\]

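Then, taking expectation over the mini-batch samples and the randomness of the inner solvers, we have \[ \mathbb{E}\left[\|\mathbf{v}_t - \nabla F_\lambda(\mathbf{w}_t)\|_2^2\right] \leq \frac{O(\lambda^2)}{B} + O\left(\lambda^2\,\mathbb{E}\left[\|\mathbf{u}_{\lambda,t}-\mathbf{u}^*_\lambda(\mathbf{w}_t)\|_2^2\right]\right) + O\left(\lambda^2\,\mathbb{E}\left[\|\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)\|_2^2\right]\right). \]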

Following Corollary 3.9, applying SGD to solve \(\min_{\mathbf{u}} f(\mathbf{w}, \mathbf{u}) + \lambda g(\mathbf{w}, \mathbf{u})\) for \[ T_0=\mathcal{O}\left(\max\left\{\frac{L_f+\lambda L_g}{\lambda \mu_g\sqrt{\epsilon}}, \frac{\sigma_f^2 + \lambda^2 \sigma_g^2}{(\lambda\mu_g)^2 (\epsilon/\lambda)^2}\right\}\right)=O\left(\frac{\lambda^2}{\epsilon^2}\right) \] iterations yields \(\mathbb{E}[\|\mathbf{u}_{\lambda,t} - \mathbf{u}^*_{\lambda}(\mathbf{w}_t)\|_2^2] \leq \frac{\epsilon^2}{\lambda^2}\).

Similarly, applying SGD to solve \(\min_{\mathbf{u}} g(\mathbf{w}, \mathbf{u})\) for \[ T_1=\mathcal{O}\left(\max\left\{\frac{L_g}{\mu_g\sqrt{\epsilon}}, \frac{\sigma_g^2}{\mu_g^2 (\epsilon/\lambda)^2}\right\}\right)=O\left(\frac{\lambda^2}{\epsilon^2}\right) \] iterations yields \(\mathbb{E}[\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2] \leq \frac{\epsilon^2}{\lambda^2}\).

Hence, with \(B=O(\lambda^2/\epsilon^2)\), we have \(\mathbb{E}[\|\mathbf{v}_t - \nabla F_\lambda(\mathbf{w}_t)\|_2^2] \leq O(\epsilon^2)\).

Then, with \(\eta=O(1/L_{F_\lambda})\) and \(T=O(L_{F_\lambda}/\epsilon^2)\) iterations, we can prove that \(\mathbb{E}[\|\nabla F_{\lambda}(\mathbf{w}_\tau)\|_2]\leq O(\epsilon)\) following Lemma 4.9. Since Lemma 4.26 and Lemma 4.27 indicate that \(\|\nabla F_{\lambda}(\mathbf{w}_\tau) - \nabla F(\mathbf{w}_\tau)\|_2 \leq O(1/\lambda)\), with \(\lambda=O(1/\epsilon)\), we have \(\mathbb{E}[\|\nabla F(\mathbf{w}_\tau)\|_2]\leq O(\epsilon)\).
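The double-loop large-batch scheme described above can be sketched on a toy quadratic instance as follows. This is only an illustration under stated assumptions: the instance (`A`, `K`, `c`, `rho`), the additive Gaussian noise model (`sigma`), and the hand-picked values of \(\lambda\), \(T\), \(T_0\), \(T_1\), \(B\) are hypothetical and do not follow the prescribed \(O(\cdot)\) scalings. The inner solver is constant-step SGD with tail averaging, a simplification chosen for this sketch rather than the schedule of Corollary 3.9, and the last iterate is returned instead of a randomly selected \(\mathbf{w}_\tau\).

```python
import numpy as np

# Toy instance (illustrative assumption, not from the text):
#   g(w, u) = 0.5 * u^T A u - u^T K w,  f(w, u) = 0.5 * ||u - c||^2 + 0.5 * rho * ||w||^2,
# with additive Gaussian noise of scale sigma on every stochastic gradient.
rng = np.random.default_rng(0)
d, dp = 3, 4
M = rng.standard_normal((dp, dp))
A = M @ M.T / dp + np.eye(dp)           # eigenvalues >= 1, so g is strongly convex in u
K = 0.5 * rng.standard_normal((dp, d))
c = rng.standard_normal(dp)
rho, sigma = 1.0, 0.01
L_g = np.linalg.eigvalsh(A)[-1]         # smoothness constant of g in u

def grad_F(w):
    # Exact hypergradient from Lemma 4.24, used only to monitor progress
    u = np.linalg.solve(A, K @ w)
    return rho * w + K.T @ np.linalg.solve(A, u - c)

def noise(*shape):
    return sigma * rng.standard_normal(shape)

def sgd(stoch_grad, u, L, steps):
    # Constant step 1/L; return the average of the second half of the iterates
    avg = np.zeros_like(u)
    for k in range(steps):
        u = u - stoch_grad(u) / L
        if k >= steps // 2:
            avg += u
    return avg / (steps - steps // 2)

def double_loop(w, lam, T, T0, T1, B, eta):
    u, u_lam = np.zeros(dp), np.zeros(dp)   # warm-started across outer iterations
    for _ in range(T):
        # Inner loop 1: u_t approximates argmin_u g(w_t, u)
        u = sgd(lambda v: A @ v - K @ w + noise(dp), u, L_g, T1)
        # Inner loop 2: u_{lam,t} approximates argmin_u f(w_t, u) + lam * g(w_t, u)
        u_lam = sgd(
            lambda v: (v - c) + noise(dp) + lam * (A @ v - K @ w + noise(dp)),
            u_lam, 1.0 + lam * L_g, T0)
        # Large-batch estimator v_t; the same xi_i appears in both grad_1 g terms,
        # so with additive noise it cancels in the difference on this toy model
        zeta, xi = noise(B, d), noise(B, d)
        g1_f = (rho * w + zeta).mean(axis=0)
        g1_g_diff = ((-K.T @ u_lam + xi) - (-K.T @ u + xi)).mean(axis=0)
        w = w - eta * (g1_f + lam * g1_g_diff)
    return w

w0 = 10.0 * np.ones(d)
# Smoothness constant of F on this instance; here it also upper-bounds that of F_lambda
L_F = rho + np.linalg.norm(np.linalg.solve(A, K), 2) ** 2
w_out = double_loop(w0, lam=10.0, T=100, T0=200, T1=200, B=100, eta=1.0 / L_F)
print("initial ||grad F|| =", np.linalg.norm(grad_F(w0)))
print("final   ||grad F|| =", np.linalg.norm(grad_F(w_out)))
```

The residual gradient norm at the output mixes two effects discussed in the text: the \(O(1/\lambda)\) gap between \(\nabla F_\lambda\) and \(\nabla F\), and the estimation error of \(\mathbf{v}_t\) caused by inexact inner solutions being amplified by \(\lambda\).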

The sample complexity is \[ BT + T_0T + T_1T = O\left(\frac{\lambda^2 L_{F_\lambda}}{\epsilon^4}\right) = O\left(\frac{L_{F_\lambda}}{\epsilon^6}\right). \]
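The arithmetic behind this count can be checked with a short script. It is a sketch only: all hidden constants in the \(O(\cdot)\) terms are set to one and `L_F_lam` defaults to one, both of which are assumptions made for illustration.

```python
import math

def sample_complexity(eps, L_F_lam=1.0):
    # All O(.) constants are set to 1 (an assumption for illustration).
    lam = 1.0 / eps              # lambda = O(1/eps)
    B = lam**2 / eps**2          # batch size B = O(lambda^2 / eps^2)
    T0 = lam**2 / eps**2         # inner SGD steps for u_{lambda,t}
    T1 = lam**2 / eps**2         # inner SGD steps for u_t
    T = L_F_lam / eps**2         # outer iterations
    return B * T + T0 * T + T1 * T

for eps in (0.1, 0.05, 0.025):
    print(f"eps = {eps:<6}  samples ~ {sample_complexity(eps):.3e}")
```

Halving \(\epsilon\) multiplies the count by \(2^6 = 64\), matching the \(O(1/\epsilon^6)\) rate when \(L_{F_\lambda}\) does not grow with \(\lambda\).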

We can establish \(L_{F_\lambda} = O(1)\) independent of \(\lambda\) (Chen et al. 2025, Lemma B.7). Hence the total sample complexity is \(O(1/\epsilon^6)\).

It remains an open problem to develop a single-loop stochastic algorithm that achieves \(O(1/\epsilon^6)\) complexity without requiring a large batch size or assuming mean-square smoothness (see the next section for more discussion).
