
Section 4.5 Structured Optimization with Compositional Gradient

In this section, we extend the compositional optimization technique to address other structured optimization problems, including min-max optimization, min-min optimization, and bilevel optimization. These problems share a common structure in the form of a compositional gradient, denoted by \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\), where \(\mathcal{M}\) is a mapping that is Lipschitz continuous with respect to its second argument, and \(\mathbf{u}^*(\mathbf{w})\) is defined as the solution to a strongly convex optimization problem:

\[\begin{align} \label{eqn:auxu} \mathbf{u}^*(\mathbf{w}) = \arg\min_{\mathbf{u} \in \mathcal{U}} h(\mathbf{w}, \mathbf{u}). \end{align}\]

This structure generalizes the gradient of a compositional function \(f(g(\mathbf{w}))\), which takes the form \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) = \nabla g(\mathbf{w})\nabla f(\mathbf{u}^*(\mathbf{w}))\) with

\[\mathbf{u}^*(\mathbf{w}) = \arg\min_{\mathbf{u}} \|\mathbf{u} - g(\mathbf{w})\|_2^2.\]

The high-level idea underlying the algorithms and analysis presented below is summarized as follows. To estimate \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\) at \(\mathbf{w}_t\), we use an auxiliary variable \(\mathbf{u}_t\) to track the optimal solution \(\mathbf{u}^*(\mathbf{w}_t)\), obtained by performing one update step on (\(\ref{eqn:auxu}\)) at \(\mathbf{w}_t\). A key aspect of the analysis is that the error of the approximation \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) is controlled by the estimation error \(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2\), due to the Lipschitz continuity of \(\mathcal{M}\):

\[\begin{align} \|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \mathcal{M}(\mathbf{w}_t, \mathbf{u}^*(\mathbf{w}_t))\|_2^2 \leq O(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2). \end{align}\]

Moreover, since \(\mathbf{u}^*(\mathbf{w})\) is the solution to a strongly convex problem and hence is Lipschitz continuous with respect to \(\mathbf{w}\), we can construct a recursion for \(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2\) to effectively bound the cumulative error over iterations.

In cases where \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) cannot be computed exactly and is instead approximated by a stochastic estimator \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t; \zeta_t)\), where \(\zeta_t\) is a random variable, we employ a moving average (MA) estimator:

\[\begin{align*} \mathbf{v}_{t} = (1 - \beta_t) \mathbf{v}_{t-1} + \beta_t\, \mathcal{M}(\mathbf{w}_t, \mathbf{u}_t; \zeta_t). \end{align*}\]

The model update is then performed using:

\[\begin{align*} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \mathbf{v}_{t}. \end{align*}\]

Alternatively, if \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) is directly computable, the update simplifies to:

\[\begin{align*} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t\, \mathcal{M}(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\]
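Before specializing to particular problems, the following minimal sketch shows the generic template in code. It assumes two user-supplied oracles, `one_step_u` (a single stochastic update on the strongly convex problem (\(\ref{eqn:auxu}\))) and `M_stoch` (a stochastic estimate of \(\mathcal{M}(\mathbf{w}, \mathbf{u}; \zeta)\)); both names and the constant step sizes are illustrative assumptions, not notation from the text.

```python
# A minimal sketch of the generic MA scheme; `one_step_u` and `M_stoch`
# are hypothetical user-supplied oracles (assumptions for illustration).
import numpy as np

def compositional_loop(w, u, one_step_u, M_stoch, eta=0.01, beta=0.1, T=1000):
    v = np.zeros_like(w)               # MA gradient estimator v_t
    for _ in range(T):
        u = one_step_u(w, u)           # track u*(w_t) with one cheap update
        z = M_stoch(w, u)              # vanilla stochastic estimator of M(w_t, u_t)
        v = (1 - beta) * v + beta * z  # moving-average (MA) estimator
        w = w - eta * v                # model update
    return w, u
```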

4.5.1 Non-convex Min-Max Optimization

We consider a non-convex min-max optimization problem:

\[\begin{align} \label{eqn:mm} \min_{\mathbf{w}\in\mathbb{R}^d}\max_{\mathbf{u}\in\mathcal U} f(\mathbf{w}, \mathbf{u}) := \mathbb{E}_{\xi}[f(\mathbf{w}, \mathbf{u}; \xi)], \end{align}\]

where \(f(\mathbf{w},\mathbf{u})\) is continuous and differentiable and \(\mathcal U\) is a closed convex set. Let \(F(\mathbf{w}) = \max_{\mathbf{u}\in\mathcal U} f(\mathbf{w}, \mathbf{u})\). Denote by \(\nabla_1f(\cdot,\cdot)\) and \(\nabla_2 f(\cdot, \cdot)\) the partial gradients with respect to the first and second variables, respectively.

Assumption 4.8

Regarding the problem (\(\ref{eqn:mm}\)), the following conditions hold:

  1. \(f(\mathbf{w}, \mathbf{u})\) is \(\mu\)-strongly concave in terms of \(\mathbf{u}\), and \(\mathbf{u}^*(\mathbf{w})=\arg\max_{\mathbf{u}\in\mathcal U}f(\mathbf{w}, \mathbf{u})\) exists for any \(\mathbf{w}\).

  2. \(\nabla_1 f(\mathbf{w}, \mathbf{u})\) is \(L_1\)-Lipschitz continuous such that

\[\begin{align} \|\nabla_1 f(\mathbf{w}, \mathbf{u}) - \nabla_1 f(\mathbf{w}', \mathbf{u}')\|_2 \leq L_1(\|\mathbf{w} - \mathbf{w}'\|_2 + \|\mathbf{u} - \mathbf{u}'\|_2). \end{align}\]

  3. \(\nabla_2 f(\mathbf{w}, \mathbf{u})\) is \(L_{21}\)-Lipschitz continuous with respect to the first variable and \(L_{22}\)-Lipschitz continuous with respect to the second variable:

\[\begin{align} &\|\nabla_2 f(\mathbf{w}, \mathbf{u}) - \nabla_2 f(\mathbf{w}', \mathbf{u}')\|_2\leq L_{21} \|\mathbf{w} - \mathbf{w}'\|_2 +L_{22} \|\mathbf{u} - \mathbf{u}'\|_2. \end{align}\]

  4. There exist \(\sigma_1,\sigma_2\) such that

\[\begin{align} &\mathbb{E}[\|\nabla_1f(\mathbf{w}, \mathbf{u}; \xi)- \nabla_1 f(\mathbf{w}, \mathbf{u})\|_2^2]\leq \sigma_1^2,\\ &\mathbb{E}[\|\nabla_2f(\mathbf{w}, \mathbf{u}; \xi)- \nabla_2 f(\mathbf{w}, \mathbf{u})\|_2^2]\leq \sigma_2^2. \end{align}\]

  5. \(F_*=\min\limits_{\mathbf{w}} F(\mathbf{w})> -\infty\).

4.5.1.1 A Double-Loop Large Mini-Batch Method

Let us first consider a straightforward approach that updates \(\mathbf{w}_t\) using a large-batch gradient estimator

\[ \mathbf{v}_t = \frac{1}{B} \sum_{i=1}^B \nabla_1 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_{i,t}), \]

and computes \(\mathbf{u}_t\) via an inner-loop SGD with \(K\) updates. It suffices to have \(K = O(L_1^2\sigma_2^2 / (\mu^2 \epsilon^2))\) (by Corollary 3.9) such that

\[\begin{align*} &\mathbb{E}[\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2] \leq \frac{\epsilon^2}{L_1^2}. \end{align*}\]

If \(B = O(\sigma_1^2/\epsilon^2)\), following Lemma 4.18 below we have

\[\begin{align*} &\mathbb{E}[\|{\mathbf{v}}_t - \nabla F(\mathbf{w}_t)\|^2_2]\leq \mathbb{E}\bigg[\bigg\|\frac{1}{B} \sum_{i=1}^B \nabla_1 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_{i,t}) - \nabla_1 f(\mathbf{w}_t, \mathbf{u}^*(\mathbf{w}_t))\bigg\|_2^2\bigg] \\ &\leq O\left(\frac{\sigma_1^2}{B} + L_1^2\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2\right)\leq \epsilon^2. \end{align*}\]

Combining this with Lemma 4.9, we can set the step size \(\eta_t = O(1/L_F)\) and the number of iterations \(T = O(L_F/\epsilon^2)\), yielding an overall sample complexity of

\[BT + KT = O\left(\frac{L_F\sigma_1^2}{\epsilon^4} + \frac{L_FL_1^2\sigma_2^2}{\mu^2\epsilon^4}\right).\]
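A sketch of this double-loop scheme in code, assuming hypothetical mini-batch oracles `grad2` and `grad1` for \(\nabla_2 f\) and \(\nabla_1 f\) and a projection `proj_U` onto \(\mathcal U\) (all three names are illustrative, not fixed notation):

```python
# A sketch of the double-loop large mini-batch method; `grad1`, `grad2`,
# and `proj_U` are hypothetical oracles supplied by the user.
def double_loop(w, u, grad1, grad2, proj_U, eta, gamma, B, K, T):
    for _ in range(T):
        for _ in range(K):                 # inner SGD: drive u toward u*(w)
            u = proj_U(u + gamma * grad2(w, u, batch=1))
        v = grad1(w, u, batch=B)           # large-batch estimator of grad F(w)
        w = w - eta * v                    # outer gradient step
    return w
```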

4.5.1.2 A Stochastic Momentum Method

We present a solution method in Algorithm 12, referred to as SMDA (Stochastic Momentum Descent-Ascent). The method begins by updating the dual variable using stochastic gradient ascent (Step 5), then computes the vanilla and moving average gradient estimators for the primal variable (Steps 6 and 7), and finally updates the primal variable using the MA estimator (Step 8). When \(\beta_t = 1\), the method reduces to Algorithm 7. However, setting \(\beta_t < 1\) is crucial for achieving improved complexity. Conceptually, the method shares similarities with SCMA.


Algorithm 12: SMDA

  1. Input: learning rate schedules \(\{\eta_t\}_{t=1}^{T}\), \(\{\gamma_t\}_{t=1}^{T}, \{\beta_t\}_{t=1}^T\); starting points \(\mathbf{w}_0\), \(\mathbf{u}_1, \mathbf{v}_0\)
  2. \(\mathbf{w}_1 = \mathbf{w}_0 - \eta_0\mathbf{v}_0\)
  3. For \(t=1,\dotsc,T\)
  4.  Sample \(\zeta_t\)
  5.  Update \(\mathbf{u}_{t+1} = \Pi_{\mathcal U}[\mathbf{u}_{t} + \gamma_t \nabla_2 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_t)]\)
  6.  Compute the vanilla gradient estimator \(\mathbf{z}_t = \nabla_1 f(\mathbf{w}_t, \mathbf{u}_{t}; \zeta_t)\)
  7.  Update the MA gradient estimator \(\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1} + \beta_t \mathbf{z}_t\)
  8.  Update the model by \(\mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \mathbf{v}_{t}\)
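The following runnable toy instance of SMDA is included for illustration. It assumes \(f(\mathbf{w}, \mathbf{u}; \xi) = (\mathbf{w} - \xi)^{\top}\mathbf{u} - \frac{1}{2}\|\mathbf{u}\|_2^2\) with \(\xi\sim\mathcal N(\bar{\mathbf{x}}, I)\) and \(\mathcal U=\mathbb{R}^d\), so that \(\mathbf{u}^*(\mathbf{w})=\mathbf{w}-\bar{\mathbf{x}}\) and \(F(\mathbf{w})=\frac{1}{2}\|\mathbf{w}-\bar{\mathbf{x}}\|_2^2\); the instance and step sizes are assumptions chosen only so the solution is known in closed form.

```python
# A runnable toy instance of SMDA (Algorithm 12) on an assumed problem with
# u*(w) = w - x_bar and F(w) = ||w - x_bar||^2 / 2, minimized at w = x_bar.
import numpy as np

rng = np.random.default_rng(0)
d, T = 5, 20000
x_bar = rng.normal(size=d)
eta, gamma, beta = 0.02, 0.05, 0.1

w = np.zeros(d); u = np.zeros(d); v = np.zeros(d)
for t in range(T):
    xi = x_bar + rng.normal(size=d)     # sample zeta_t
    u = u + gamma * ((w - xi) - u)      # Step 5: ascent on nabla_2 f (U = R^d, no projection)
    z = u.copy()                        # Step 6: nabla_1 f(w, u; zeta_t) = u here
    v = (1 - beta) * v + beta * z       # Step 7: MA gradient estimator
    w = w - eta * v                     # Step 8: primal update

print(np.linalg.norm(w - x_bar))        # w approaches x_bar
```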

Convergence Analysis

We will prove the convergence of the gradient norm of \(F(\mathbf{w})\). We first prove the following lemmas.

Lemma 4.17

Let \(\mathbf{u}^*(\mathbf{w})=\arg\max_{\mathbf{u}\in\mathcal U}f(\mathbf{w}, \mathbf{u})\). Under Assumption 4.8(i), (iii), \(\mathbf{u}^*(\cdot)\) is \(\kappa\)-Lipschitz continuous with \(\kappa=\frac{L_{21}}{\mu}\).

Proof

Let us consider \(\mathbf{w}_1, \mathbf{w}_2\). By the optimality condition of \(\mathbf{u}^*(\mathbf{w}_1)\) and \(\mathbf{u}^*(\mathbf{w}_2)\) for a concave function, we have \[\begin{align*} \nabla_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_1))^{\top}(\mathbf{u} - \mathbf{u}^*(\mathbf{w}_1))\leq 0,\quad \forall \mathbf{u}\in\mathcal U\\ \nabla_2 f(\mathbf{w}_2, \mathbf{u}^*(\mathbf{w}_2))^{\top}(\mathbf{u} - \mathbf{u}^*(\mathbf{w}_2))\leq 0,\quad\forall \mathbf{u}\in\mathcal U. \end{align*}\] Let \(\mathbf{u}=\mathbf{u}^*(\mathbf{w}_2)\) in the first inequality and \(\mathbf{u}=\mathbf{u}^*(\mathbf{w}_1)\) in the second inequality, and add them together to obtain \[\begin{align*} (\nabla_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_1))-\nabla_2 f(\mathbf{w}_2, \mathbf{u}^*(\mathbf{w}_2)))^{\top}(\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1))\leq 0. \end{align*}\] Since \(-f(\mathbf{w}_1,\cdot)\) is \(\mu\)-strongly convex, due to Lemma 1.6, we have \[\begin{align*} (\nabla_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_1))&-\nabla_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_2)))^{\top}(\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1))\\ &\geq \mu\|\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1)\|_2^2. \end{align*}\] Combining these two inequalities we have \[\begin{align*} \mu\|\mathbf{u}^*(\mathbf{w}_2) &- \mathbf{u}^*(\mathbf{w}_1)\|_2^2\leq (\nabla_2 f(\mathbf{w}_2, \mathbf{u}^*(\mathbf{w}_2))-\nabla_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_2)))^{\top}(\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1))\\ &\leq \|\nabla_2 f(\mathbf{w}_2, \mathbf{u}^*(\mathbf{w}_2))-\nabla_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_2))\|_2\|\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1)\|_2\\ &\leq L_{21}\|\mathbf{w}_2 - \mathbf{w}_1\|_2\|\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1)\|_2. \end{align*}\] Thus, \[\begin{align*} \|\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1)\|_2\leq \frac{L_{21}}{\mu}\|\mathbf{w}_2 - \mathbf{w}_1\|_2. \end{align*}\]

Lemma 4.18

Under Assumption 4.8(i) and (ii), \(\nabla F(\mathbf{w})=\nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\), and it is \(L_F\)-Lipschitz continuous with \(L_F=L_1(1+\kappa)\).

Proof

If \(\mathcal U\) is bounded, Danskin’s theorem implies that \(\nabla F(\mathbf{w}) = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\). If \(\mathcal U\) is unbounded, we have \[\begin{align} \nabla F(\mathbf{w}) = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) + \frac{\partial \mathbf{u}^*(\mathbf{w})}{\partial \mathbf{w}}^{\top}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})), \end{align}\] where the last equality follows from \(\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))=0\). To establish the Lipschitz continuity of \(\nabla F(\mathbf{w})\), let us consider \(\mathbf{w}_1\) and \(\mathbf{w}_2\). We have \[\begin{align*} &\|\nabla F(\mathbf{w}_1)-\nabla F(\mathbf{w}_2)\|_2 =\|\nabla_1 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_1)) -\nabla_1 f(\mathbf{w}_2, \mathbf{u}^*(\mathbf{w}_2))\|_2 \\ &\leq L_1(\|\mathbf{w}_1 - \mathbf{w}_2\|_2 + \|\mathbf{u}^*(\mathbf{w}_1) - \mathbf{u}^*(\mathbf{w}_2)\|_2)\leq L_1(1+\kappa)\|\mathbf{w}_1- \mathbf{w}_2\|_2. \end{align*}\]

Next, we prove two lemmas similar to Lemma 4.8 and Lemma 4.1, regarding the recursion of gradient estimation error and the estimation error of \(\mathbf{u}\), respectively. The descent lemma (Lemma 4.9) still holds.

Lemma 4.19

It holds that \[\begin{align*} &\mathbb{E}_{\zeta_t}\left[\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\|^2_2\right] \leq (1-\beta_t)\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\|_2^2 + \frac{2L_F^2}{\beta_t}\|\mathbf{w}_{t-1}-\mathbf{w}_t\|_2^2 \\ & + 4L_1^2\beta_t\|\mathbf{u}_{t} - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + \beta_t^2\sigma_1^2. \end{align*}\]

Proof

Let \(\mathbf{z}_t = \nabla_1 f(\mathbf{w}_t, \mathbf{u}_{t};\zeta_t)\) and \(\mathcal{M}_t = \mathbb{E}_t[\mathbf{z}_t]= \nabla_1 f(\mathbf{w}_t, \mathbf{u}_{t})\). Then \(\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1} + \beta_t \mathbf{z}_t\). Note that \(\mathbb{E}_t[\|\mathbf{z}_t - \mathcal{M}_t\|_2^2]\leq \sigma_1^2\) by Assumption 4.8, and that \(\|\mathcal{M}_t - \nabla F(\mathbf{w}_t)\|_2^2 \leq L_1^2 \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2\) by the Lipschitz continuity of \(\nabla_1 f\) and Lemma 4.18. Plugging these into Lemma 4.7 finishes the proof.

Lemma 4.20

Suppose Assumption 4.8(i), (iii), (iv) hold. Consider the update \(\mathbf{u}_{t+1} = \Pi_{\mathcal U} [\mathbf{u}_t + \gamma_t \nabla_2 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_t)]\). If \(\gamma_t<1/L_{22}\leq 1/\mu\), we have \[\begin{equation*} \begin{split} \mathbb{E}_t[\|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_{t+1})\|_2^2] & \leq (1-\frac{\gamma_t \mu}{2}) \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + \frac{3\kappa^2}{\gamma_t \mu} \mathbb{E}_t[\|\mathbf{w}_t - \mathbf{w}_{t+1}\|_2^2]\\ & + 2\gamma_t^2 \sigma_2^2. \end{split} \end{equation*}\]

Proof

By Lemma 3.8, if \(\gamma_t<1/L_{22}\) we have \[\begin{equation} \begin{split} \mathbb{E}_t[\|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_t)\|_2^2]\leq (1-\gamma_t\mu)\|\mathbf{u}_{t} - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + \gamma_t^2\sigma_2^2. \end{split} \end{equation}\] Then, \[\begin{equation*} \begin{split} &\mathbb{E}_t[\|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_{t+1})\|_2^2] \leq (1+\frac{\gamma_t \mu}{2})\mathbb{E}_t[\|\mathbf{u}_{t} - \mathbf{u}^*(\mathbf{w}_t)\|_2^2] \\ &+ (1+\frac{2}{\gamma_t\mu})\mathbb{E}_t[\|\mathbf{u}^*(\mathbf{w}_t) - \mathbf{u}^*(\mathbf{w}_{t+1})\|_2^2] \\ & \leq (1+\frac{\gamma_t \mu}{2}) (1-\gamma_t \mu) \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + (1+\frac{\gamma_t \mu}{2}) \gamma_t^2 \sigma_2^2 \\ & + \frac{2+\gamma_t\mu}{\gamma_t\mu} \kappa^2 \mathbb{E}_t[\|\mathbf{w}_t - \mathbf{w}_{t+1}\|_2^2] \\ & \leq (1-\frac{\gamma_t \mu}{2}) \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + 2\gamma_t^2 \sigma_2^2 + \frac{3\kappa^2}{\gamma_t \mu} \mathbb{E}_t[\|\mathbf{w}_t - \mathbf{w}_{t+1}\|_2^2], \end{split} \end{equation*}\] where the first inequality uses Young’s inequality, and the last inequality uses \(\gamma_t\mu<1\).

Finally, we can prove the following theorem regarding the convergence of SMDA.

Theorem 4.5

Suppose Assumption 4.8 holds. By setting \(\beta_t=\beta= \epsilon^2/(3\sigma_1^2)\), \(\gamma_t=\gamma=\mu\epsilon^2/(96L_1^2\sigma_2^2)\) and \(\eta_t=\eta= \min(\frac{\beta}{\sqrt{8}L_F},\frac{\gamma\mu}{16\sqrt{3}L_1\kappa},\frac{1}{2L_F})\) in SMDA, then the following holds

\[\begin{align} \label{eqn:smda-c} \mathbb E\left[\frac{1}{T}\sum_{t=0}^{T-1}\left\{\frac{1}{4}\|\mathbf{v}_{t}\|_2^2 + \|\nabla F(\mathbf{w}_t)\|_2^2\right\}\right]\leq \epsilon^2, \end{align}\]

with an iteration complexity of

\[\begin{align} \label{eqn:smda-t} T&=O\left(\max\left\{\frac{C_\Upsilon L_F}{\epsilon^2}, \frac{C_\Upsilon\sigma_1^2L_F}{\epsilon^4}, \frac{C_\Upsilon L_1^3\kappa\sigma_2^2}{\epsilon^4\mu^2}\right\}\right), \end{align}\]

where \(C_\Upsilon= 2(F(\mathbf{w}_0) - F_*) + \frac{1}{\sqrt{8}L_F}\|\mathbf{v}_0 - \nabla F(\mathbf{w}_0)\|_2^2 + \frac{L_1}{\sqrt{3}\kappa}\|\mathbf{u}_0 - \mathbf{u}^*(\mathbf{w}_0)\|_2^2\).

💡 Why it matters
The MA gradient estimator in SMDA is critical to obtaining a complexity of \(O(1/\epsilon^4)\). If we simply update the primal variable by SGD, the algorithm becomes SGDA. The convergence analysis of SGDA for non-convex min-max problems suffers from either a large batch-size requirement or slow convergence. In particular, SGDA with a batch size of \(O(1/\epsilon^2)\) can find an \(\epsilon\)-stationary solution in \(O(1/\epsilon^2)\) iterations when the problem is smooth in terms of the primal and dual variables and strongly concave in terms of the dual variable, yielding a sample complexity of \(O(1/\epsilon^4)\). If using a constant batch size \(O(1)\), SGDA may need \(O(1/\epsilon^8)\) iterations to find an \(\epsilon\)-stationary solution (Lin et al. 2020).

Proof

The proof is similar to that of Theorem 4.3. Recall the three inequalities in Lemma 4.9, Lemma 4.19, and Lemma 4.20 that we have proved so far: \[\begin{align*} (*)\;&F(\mathbf{w}_{t+1}) \leq F(\mathbf{w}_t) + \frac{\eta}{2} \|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\|_2^2- \frac{\eta}{2}\|\nabla F(\mathbf{w}_t)\|_2^2-\frac{\eta}{4}\|\mathbf{v}_{t}\|_2^2,\\ (\sharp)\;&\mathbb{E}\left[\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\|_2^2\right] \leq \mathbb{E}\left[(1-\beta)\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\|_2^2 + \frac{2L_F^2\eta^2}{\beta}\|\mathbf{v}_{t-1}\|_2^2\right] \\ & + \mathbb{E}\left[4L_1^2\beta\|\mathbf{u}_{t} - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + \beta^2\sigma_1^2\right],\\ (\diamond)\;&\mathbb{E}\|\mathbf{u}_{t} - \mathbf{u}^*(\mathbf{w}_{t})\|_2^2 \leq \mathbb{E}\left[(1-\frac{\gamma \mu}{2})\|\mathbf{u}_{t-1} - \mathbf{u}^*(\mathbf{w}_{t-1})\|_2^2 + 2\gamma^2 \sigma_2^2 + \frac{3\kappa^2\eta^2}{\gamma \mu} \|\mathbf{v}_{t-1}\|_2^2\right]. \end{align*}\] Letting \(\bar\gamma=\gamma\mu/2\), the last inequality becomes: \[\begin{align*} (\diamond)\;&\mathbb{E}\|\mathbf{u}_{t} - \mathbf{u}^*(\mathbf{w}_{t})\|_2^2 \leq \mathbb{E}\left[(1-\bar\gamma) \|\mathbf{u}_{t-1} - \mathbf{u}^*(\mathbf{w}_{t-1})\|_2^2 + 8\bar\gamma^2 \frac{\sigma_2^2}{\mu^2} + \frac{3\kappa^2\eta^2}{2\bar\gamma} \|\mathbf{v}_{t-1}\|_2^2\right]. \end{align*}\] Let us define \(A_t= 2(F(\mathbf{w}_t) - F_*)\) and \(B_t = \|\nabla F(\mathbf{w}_t)\|_2^2\), \(\Gamma_t = \|\mathbf{v}_{t}\|_2^2/2\), \(\Delta_t= \|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\|_2^2\), \(\delta_{t}=\|\mathbf{u}_{t} - \mathbf{u}^*(\mathbf{w}_{t})\|_2^2\). Then the three inequalities \((*),(\sharp),(\diamond)\) satisfy the conditions in Lemma 4.10 with \(C_1=4L_1^2, C_2=4L_F^2, C_3=3\kappa^2, \sigma^2=\sigma_1^2, {\sigma'}^2=8\sigma_2^2/\mu^2\). If \(\eta, \beta, \bar\gamma\) satisfy \[\begin{align*} &\beta =\frac{\epsilon^2}{3\sigma^2} =\frac{\epsilon^2}{3\sigma_1^2},\quad \bar\gamma =\frac{\epsilon^2}{6C_1{\sigma'}^2}=\frac{\epsilon^2\mu^2}{192L_1^2\sigma_2^2}, \\ & \eta=\min(\frac{1}{2L_F},\frac{\beta}{\sqrt{4C_2}}, \frac{\bar\gamma}{\sqrt{8C_1C_3}}) = \min(\frac{1}{2L_F},\frac{\beta}{4L_F}, \frac{\bar\gamma}{\sqrt{96}L_1\kappa}), \end{align*}\] then (\(\ref{eqn:smda-c}\)) holds, and the iteration complexity becomes \[\begin{align*} T&=O\left(\max\left\{\frac{C_\Upsilon L_F}{\epsilon^2}, \frac{C_\Upsilon \sigma^2\sqrt{C_2}}{\epsilon^4}, \frac{C_\Upsilon\sqrt{C_1C_3}C_1{\sigma'}^2}{\epsilon^4}\right\}\right)\\ &=O\left(\max\left\{\frac{C_\Upsilon L_F}{\epsilon^2}, \frac{C_\Upsilon\sigma_1^2L_F}{\epsilon^4}, \frac{C_\Upsilon L_1^3\kappa\sigma_2^2}{\epsilon^4\mu^2}\right\}\right). \end{align*}\]

Critical

It is worth mentioning that an improved complexity of \(O(1/\epsilon^3)\) can be achieved by employing the STORM gradient estimator for both the primal and dual variables under the mean-square smooth condition of the objective.

4.5.2 Non-convex Min-Min Optimization

We can extend SMDA to solving a non-convex strongly-convex min-min problem: \[\begin{align}\label{eqn:minmin} \min_{\mathbf{w}\in\mathbb{R}^d}\min_{\mathbf{u}\in\mathcal U} f(\mathbf{w}, \mathbf{u}): = \mathbb{E}_{\xi}[f(\mathbf{w}, \mathbf{u}; \xi)], \end{align}\] where \(f(\mathbf{w},\mathbf{u})\) is smooth, non-convex in terms of \(\mathbf{w}\), and strongly convex in terms of \(\mathbf{u}\), and \(\mathcal U\) is a closed convex set. If \(\mathbf{u}^*(\mathbf{w})=\arg\min_{\mathbf{u}\in\mathcal U} f(\mathbf{w}, \mathbf{u})\) exists and is unique, then we have \(\nabla F(\mathbf{w}) = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\). Hence, its gradient also exhibits a compositional structure, where the inner function \(\mathbf{u}^*(\mathbf{w})\) is the solution to a strongly convex problem.

SMDA can be modified by replacing the \(\mathbf{u}\) update with \[\mathbf{u}_{t+1} = \Pi_{\mathcal U}[\mathbf{u}_{t} - \gamma_t \nabla_2 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_t)].\] Then, the same convergence result as in the last subsection can be established for the min-min problem; the details are omitted here.


Algorithm 13: A novel method for weakly convex minimization

  1. Input: learning rate schedules \(\{\eta_t\}_{t=1}^{T}\), \(\{\gamma_t\}_{t=1}^{T}\); starting points \(\mathbf{w}_1\), \(\mathbf{u}_1, \mathbf{v}_1\)
  2. For \(t=1,\dotsc, T\)
  3.  Sample \(\zeta_t\) and compute \(\mathcal{G}(\mathbf{u}_t; \zeta_t)\in\partial g(\mathbf{u}_t; \zeta_t)\)
  4.  Update \(\mathbf{u}_{t+1} = \mathbf{u}_t - \gamma_t(\mathcal{G}(\mathbf{u}_t; \zeta_t) + 2\rho (\mathbf{u}_t - \mathbf{w}_t))\)
  5.  Update \(\mathbf{w}_{t+1} = (1-2\eta_t\rho)\mathbf{w}_t + 2\eta_t\rho \mathbf{u}_{t}\)

4.5.2.1 Application to weakly convex minimization

Next, we present an application to solving weakly convex minimization problems: \[\begin{align} \min_{\mathbf{w}} F(\mathbf{w}) := \mathbb{E}[g(\mathbf{w}; \zeta)], \end{align}\] where \(F(\mathbf w) > -\infty\) is \(\rho\)-weakly convex, as discussed in Chapter 2.

As argued in Section 3.1.4, an \(\epsilon\)-stationary solution of the Moreau envelope of \(F(\mathbf{w})\) corresponds to a nearly \(\epsilon\)-stationary solution of the original problem. Hence, we consider optimizing the Moreau envelope directly: \[\begin{align}\label{eqn:minmin-weak} \min_{\mathbf{w}} F_{\rho}(\mathbf{w}) := \min_{\mathbf{u}} \mathbb{E}[g(\mathbf{u}; \zeta)] + \rho \|\mathbf{u} - \mathbf{w}\|_2^2. \end{align}\] Define \(f(\mathbf{w}, \mathbf{u}) = \mathbb{E}[g(\mathbf{u}; \zeta)] + \rho \|\mathbf{u} - \mathbf{w}\|_2^2\). Then \(f(\mathbf{w}, \mathbf{u})\) is \(\rho\)-strongly convex with respect to \(\mathbf{u}\) due to the \(\rho\)-weak convexity of \(F\).

For updating \(\mathbf{u}\), we use the standard SGD update: \[\begin{align} \mathbf{u}_{t+1} = \mathbf{u}_t - \gamma_t(\mathcal{G}(\mathbf{u}_t; \zeta_t) + 2\rho (\mathbf{u}_t - \mathbf{w}_t)), \end{align}\] where \(\mathcal{G}(\mathbf{u}_t; \zeta_t)\in\partial g(\mathbf{u}_t; \zeta_t)\). For updating \(\mathbf{w}\), we apply GD with the gradient given by \(\nabla_1 f(\mathbf{w}_t,\mathbf{u}_t) = 2\rho(\mathbf{w}_t - \mathbf{u}_t)\): \[\begin{align} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t 2\rho (\mathbf{w}_t - \mathbf{u}_t) = (1-2\eta_t\rho)\mathbf{w}_t + 2\eta_t\rho\mathbf{u}_t. \end{align}\] We present the updates in Algorithm 13. An interesting observation about this algorithm is that the \(\mathbf{u}\) update is similar to the momentum update, except that the momentum term \(\mathbf{u}_t - \mathbf{u}_{t-1}\) is replaced by \(\mathbf{u}_t - \mathbf{w}_t\), where \(\mathbf{w}_t\) is an MA weight vector.
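As a concrete illustration of Algorithm 13, the sketch below minimizes the convex (hence \(\rho\)-weakly convex) toy objective \(F(\mathbf{w}) = \mathbb{E}_{\zeta}\|\mathbf{w}-\zeta\|_1\) with \(\zeta\sim\mathcal N(\mathbf{m}, I)\), whose minimizer is \(\mathbf{m}\); the bounded subgradient \(\mathcal{G}(\mathbf{u};\zeta)=\mathrm{sign}(\mathbf{u}-\zeta)\) matches the condition of Lemma 4.22. The instance and step sizes are assumptions for illustration only.

```python
# A runnable sketch of Algorithm 13 on an assumed toy objective
# F(w) = E ||w - zeta||_1, zeta ~ N(m, I), minimized at w = m.
import numpy as np

rng = np.random.default_rng(1)
d, T = 5, 50000
m = rng.normal(size=d)
rho, eta, gamma = 1.0, 0.01, 0.01

w = np.zeros(d); u = np.zeros(d)
for t in range(T):
    zeta = m + rng.normal(size=d)
    G = np.sign(u - zeta)                             # stochastic subgradient of g
    u = u - gamma * (G + 2 * rho * (u - w))           # SGD step on f(w_t, .)
    w = (1 - 2 * eta * rho) * w + 2 * eta * rho * u   # GD step on the Moreau envelope

print(np.linalg.norm(w - m))                          # w approaches the median m
```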

Convergence Analysis

Let us first prove the following lemma.

Lemma 4.21

We have (i) \(F_\rho\) is \(L_F\)-smooth with \(L_F=6\rho\); (ii) \(\nabla_1 f(\mathbf{w}, \mathbf{u})\) is Lipschitz continuous with \(L_1=2\rho\); and (iii) \(\mathbf{u}^*(\mathbf{w})\) is \(2\)-Lipschitz continuous.

Proof

The smoothness of \(F_\rho\) has been proved in Proposition 3.1 with \(\lambda=1/(2\rho)\). The Lipschitz continuity of \(\nabla_1 f(\mathbf{w}, \mathbf{u})= 2\rho(\mathbf{w} - \mathbf{u})\) is obvious. Next, let us prove the Lipschitz continuity of \(\mathbf{u}^*(\mathbf{w})\). The proof is similar to that of Lemma 4.17.

Let us consider \(\mathbf{w}_1, \mathbf{w}_2\). By the optimality condition of \(\mathbf{u}^*(\mathbf{w}_1)\) and \(\mathbf{u}^*(\mathbf{w}_2)\) for a convex function, there exist \(\mathbf{v}(\mathbf{w}_1)\in\partial_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_1)), \mathbf{v}(\mathbf{w}_2)\in\partial_2 f(\mathbf{w}_2, \mathbf{u}^*(\mathbf{w}_2))\) such that \[\begin{align*} \mathbf{v}(\mathbf{w}_1)^{\top}(\mathbf{u} - \mathbf{u}^*(\mathbf{w}_1))\geq 0,\quad \forall \mathbf{u}\\ \mathbf{v}(\mathbf{w}_2)^{\top}(\mathbf{u} - \mathbf{u}^*(\mathbf{w}_2))\geq 0,\quad\forall \mathbf{u}. \end{align*}\] Let \(\mathbf{u}=\mathbf{u}^*(\mathbf{w}_2)\) in the first inequality and \(\mathbf{u}=\mathbf{u}^*(\mathbf{w}_1)\) in the second inequality, and add them together to obtain \[\begin{align*} (\mathbf{v}(\mathbf{w}_2)-\mathbf{v}(\mathbf{w}_1))^{\top}(\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1))\leq 0. \end{align*}\] Since \(f(\mathbf{w}_1,\cdot)\) is \(\rho\)-strongly convex, similar to Lemma 1.6, we have for any \(\mathbf{v}\in\partial_2 f(\mathbf{w}_1, \mathbf{u}^*(\mathbf{w}_2))\), \[\begin{align*} (\mathbf{v}&-\mathbf{v}(\mathbf{w}_1))^{\top}(\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1))\geq \rho\|\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1)\|_2^2. \end{align*}\] Combining these two inequalities we have \[\begin{align*} \rho\|\mathbf{u}^*(\mathbf{w}_2) &- \mathbf{u}^*(\mathbf{w}_1)\|_2^2\leq (\mathbf{v}-\mathbf{v}(\mathbf{w}_2))^{\top}(\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1))\\ &\leq \|\mathbf{v}-\mathbf{v}(\mathbf{w}_2)\|_2\|\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1)\|_2. \end{align*}\] Since there exists \(\mathbf{v}'\in\partial g(\mathbf{u}^*(\mathbf{w}_2))\) such that \(\mathbf{v}(\mathbf{w}_2) =\mathbf{v}'+ 2\rho (\mathbf{u}^*(\mathbf{w}_2) - \mathbf{w}_2)\), we let \(\mathbf{v} = \mathbf{v}' + 2\rho (\mathbf{u}^*(\mathbf{w}_2) - \mathbf{w}_1)\), so that \(\mathbf{v}-\mathbf{v}(\mathbf{w}_2)=2\rho(\mathbf{w}_2-\mathbf{w}_1)\) and \[\begin{align*} \|\mathbf{u}^*(\mathbf{w}_2) - \mathbf{u}^*(\mathbf{w}_1)\|_2\leq 2\|\mathbf{w}_2 - \mathbf{w}_1\|_2. \end{align*}\]

Since \(\partial_2f(\mathbf{w},\mathbf{u})\) is not Lipschitz continuous with respect to \(\mathbf{u}\), Lemma 4.20 is not directly applicable. We develop a similar one below.

Lemma 4.22

Consider the following update: \[\begin{align*} \mathbf{u}_{t+1} = \mathbf{u}_t - \gamma_t(\mathcal{G}(\mathbf{u}_t; \zeta_t) + 2\rho (\mathbf{u}_t - \mathbf{w}_t)). \end{align*}\] If \(\mathbb{E}_\zeta[\|\mathcal{G}(\mathbf{u};\zeta)\|^2_2]\leq G^2\) and \(\gamma_t\rho<1/8\), then we have \[\begin{align*} & \mathbb{E}_t\|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_{t+1})\|^2_2\\ & \leq \left(1-\frac{\gamma_t \rho}{2}\right) \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + 16\gamma_t^2 G^2 + \frac{12}{\gamma_t \rho}\mathbb{E}_t\|\mathbf{w}_{t+1} -\mathbf{w}_t\|^2_2. \end{align*}\]

Proof

Since \(\mathbf{u}_{t+1}\) is a one-step SGD update for \(f(\mathbf{w}_t, \mathbf{u})\), the proof is similar to that of Lemma 3.8 for the non-smooth case. We have \[\begin{align}\label{eqn:u-rec} & \|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_t)\|^2_2 = \|\mathbf{u}_t -\gamma_t \left(\mathcal{G}(\mathbf{u}_t;\zeta_t)+2\rho(\mathbf{u}_t- \mathbf{w}_t)\right)- \mathbf{u}^*(\mathbf{w}_t)\|^2_2\\ & = \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2 + \gamma_t^2\|\mathcal{G}(\mathbf{u}_t;\zeta_t)+2\rho(\mathbf{u}_t- \mathbf{w}_t)\|^2_2\notag\\ &- 2\gamma_t(\mathcal{G}(\mathbf{u}_t;\zeta_t)+2\rho(\mathbf{u}_t-\mathbf{w}_t))^{\top}(\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)).\notag \end{align}\] Note that \(0 \in \partial g(\mathbf{u}^*(\mathbf{w}_t)) + 2\rho(\mathbf{u}^*(\mathbf{w}_t) - \mathbf{w}_t)\). Thus, \(\bar{\mathbf{v}}_t:=2\rho(\mathbf{w}_t - \mathbf{u}^*(\mathbf{w}_t))\in \partial g(\mathbf{u}^*(\mathbf{w}_t))\) and \[\begin{align*} \mathbb{E}_t\|\mathcal{G}(\mathbf{u}_t;\zeta_t)+2\rho(\mathbf{u}_t-\mathbf{w}_t)\|^2_2 & = \mathbb{E}_t\|\mathcal{G}(\mathbf{u}_t;\zeta_t) - \bar{\mathbf{v}}_t +2\rho(\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t))\|^2_2 \\ & \leq 2\mathbb{E}_t\|\mathcal{G}(\mathbf{u}_t;\zeta_t)-\bar{\mathbf{v}}_t\|_2^2 + 8 \rho^2\|\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)\|_2^2\\ & \leq 8G^2 + 8 \rho^2\|\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)\|_2^2, \end{align*}\] where the last inequality uses \(\mathbb{E}_t\|\mathcal{G}(\mathbf{u}_t;\zeta_t)\|_2^2\leq G^2\) and \(\|\bar{\mathbf{v}}_t\|_2\leq G\). For the last term in (\(\ref{eqn:u-rec}\)), let \(\tilde{\mathbf{v}}_{t} = \mathbb{E}[\mathcal{G}(\mathbf{u}_t;\zeta_t)]+2\rho(\mathbf{u}_t- \mathbf{w}_t)\in\partial_2 f(\mathbf{w}_t, \mathbf{u}_t)\); then we have \[\begin{align*} & \mathbb{E}_t(\mathcal{G}(\mathbf{u}_t;\zeta_t)+2\rho(\mathbf{u}_t- \mathbf{w}_t))^{\top}(\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)) = \tilde{\mathbf{v}}_{t}^{\top}(\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t))\\ & = (\tilde{\mathbf{v}}_{t} - \mathbf{v}(\mathbf{w}_t))^{\top}(\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)) \geq \rho \|\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)\|_2^2, \end{align*}\] where \(\mathbf{v}(\mathbf{w}_t):=0\in\partial_2 f(\mathbf{w}_t, \mathbf{u}^*(\mathbf{w}_t))\) and the inequality is due to the strong convexity of \(f\) in terms of \(\mathbf{u}\). Combining the above inequalities we have \[\begin{align*} & \mathbb{E}_t\|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_t)\|^2_2 \leq \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2 + \gamma_t^2(8G^2 + 8\rho^2 \|\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)\|_2^2) - 2\gamma_t\rho \|\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)\|_2^2\\ &=(1-2\gamma_t\rho + 8\gamma_t^2\rho^2)\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + 8\gamma_t^2G^2\\ &\leq (1-\gamma_t\rho )\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + 8\gamma_t^2G^2, \end{align*}\] where the last inequality uses \(\gamma_t\leq \frac{1}{8\rho}\). Since \(\mathbf{u}^*(\mathbf{w})\) is \(2\)-Lipschitz continuous (Lemma 4.21), we have \[\begin{align*} & \mathbb{E}_t\|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_{t+1})\|^2_2\\ & \leq \left(1 + \frac{\gamma_t\rho}{2}\right) \mathbb{E}_t\|\mathbf{u}_{t+1} - \mathbf{u}^*(\mathbf{w}_t)\|^2_2 + \left(1+\frac{2}{\gamma_t\rho}\right) \mathbb{E}_t\|\mathbf{u}^*(\mathbf{w}_{t+1}) - \mathbf{u}^*(\mathbf{w}_t)\|^2_2 \\ & \leq \left(1-\frac{\gamma_t \rho}{2}\right) \|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2 + 16\gamma_t^2 G^2 + \frac{12}{\gamma_t \rho}\mathbb{E}_t\|\mathbf{w}_{t+1} -\mathbf{w}_t\|^2_2, \end{align*}\] where the first inequality uses Young’s inequality and the last inequality uses \(\gamma_t\rho\leq 1\) and the \(2\)-Lipschitz continuity of \(\mathbf{u}^*(\cdot)\).

Lemma 4.23

Let \(\mathbf{z}_t = 2\rho(\mathbf{w}_t - \mathbf{u}_t)\). For the update \(\mathbf{w}_{t+1} = \mathbf{w}_t -\eta_t \mathbf{z}_t\), if \(\eta_t\leq 1/(2L_F)\), we have \[\begin{align}\label{eqn:nasa_starter} F_\rho(\mathbf{w}_{t+1})& \leq F_\rho(\mathbf{w}_t) + \frac{\eta_t}{2} \|\nabla F_\rho(\mathbf{w}_t) - \mathbf{z}_{t}\|_2^2- \frac{\eta_t}{2}\|\nabla F_\rho(\mathbf{w}_t)\|_2^2 - \frac{1}{4\eta_t}\|\mathbf{w}_{t+1} - \mathbf{w}_t\|_2^2, \end{align}\] where \(L_F\) is the smoothness parameter of \(F_\rho(\cdot)\).

Since \(\nabla F_\rho(\mathbf{w}_t) = 2\rho(\mathbf{w}_t - \mathbf{u}^*(\mathbf{w}_t))\), we have \(\|\nabla F_\rho(\mathbf{w}_t) -\mathbf{z}_t\|_2^2 = 4\rho^2\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2\), whose recursion has been established in Lemma 4.22. We can combine these two lemmas and establish a complexity of \(O(1/\epsilon^4)\) for Algorithm 13 in order to find an \(\epsilon\)-stationary solution to \(F_\rho(\cdot)\).

4.5.2.2 Application to weakly-convex strongly-concave min-max problems

The same technique can be applied to solving weakly-convex strongly-concave min-max problems \(\min_{\mathbf{w}}\max_{\mathbf{u}\in\mathcal U}f(\mathbf{w}, \mathbf{u})\) with a single-loop algorithm. In subsection 4.5.1, we assumed the partial gradient \(\nabla_1 f(\mathbf{w}, \mathbf{u})\) is Lipschitz continuous. We replace this assumption with the assumption that \(f(\mathbf{w}, \mathbf{u})\) is \(\rho\)-weakly convex in terms of \(\mathbf{w}\) for any \(\mathbf{u}\in\mathcal U\).

In this case, \(F(\mathbf{w})=\max_{\mathbf{u}\in\mathcal U}f(\mathbf{w}, \mathbf{u})\) is not smooth but weakly convex. Let us consider its Moreau envelope: \[\begin{align*} \min_{\mathbf{w}} F_\rho(\mathbf{w}): = \min_{\mathbf{u}_1} F(\mathbf{u}_1) + \rho \|\mathbf{u}_1 - \mathbf{w}\|_2^2. \end{align*}\] This problem is equivalent to \[\begin{align*} \min_{\mathbf{w}, \mathbf{u}_1} \max_{\mathbf{u}_2\in\mathcal U} f(\mathbf{u}_1, \mathbf{u}_2) + \rho \|\mathbf{u}_1 - \mathbf{w}\|_2^2, \end{align*}\] which is strongly convex in terms of \(\mathbf{u}_1\) and strongly concave in terms of \(\mathbf{u}_2\).

Compared to (\(\ref{eqn:minmin-weak}\)), this problem just adds another layer of inner maximization. However, it can still be mapped to the general framework discussed at the beginning. The gradient of \(F_\rho(\mathbf{w})\) is given by \(\mathcal{M}(\mathbf{w}, \mathbf{u}_1^*(\mathbf{w}))=2\rho (\mathbf{w} - \mathbf{u}_1^*(\mathbf{w}))\). We track \(\mathbf{u}_1^*(\mathbf{w}_t)\) by \(\mathbf{u}_{1,t}\), whose update relies on the subgradient \(\partial_1 f(\mathbf{u}_{1,t}, \mathbf{u}_{2}^*(\mathbf{u}_{1,t})) + 2\rho(\mathbf{u}_{1,t} - \mathbf{w}_t)\). Hence, we just need another variable \(\mathbf{u}_{2,t}\) to track \(\mathbf{u}_2^*(\mathbf{u}_{1,t})\).

We can develop a similar algorithm. First, let us update \(\mathbf{u}_1,\mathbf{u}_2\). Given \(\mathbf{w}_t, \mathbf{u}_{1,t}, \mathbf{u}_{2,t}\), we update \(\mathbf{u}_{1,t+1}, \mathbf{u}_{2,t+1}\) with SGD updates: \[\begin{align} \mathbf{u}_{2,t+1} & = \Pi_{\mathcal U}[\mathbf{u}_{2,t} + \gamma_2 \partial_2 f(\mathbf{u}_{1,t}, \mathbf{u}_{2,t}; \zeta_t)]\\ \mathbf{u}_{1,t+1} &= \mathbf{u}_{1,t} - \gamma_1(\partial_1f(\mathbf{u}_{1,t}, \mathbf{u}_{2,t}; \zeta_t)+2\rho (\mathbf{u}_{1,t} - \mathbf{w}_t)). \end{align}\] Then we update \(\mathbf{w}_{t+1}\) with a GD update: \[\begin{align} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta 2\rho(\mathbf{w}_t - \mathbf{u}_{1,t}) = (1-2\eta\rho) \mathbf{w}_t + 2\eta\rho\mathbf{u}_{1,t}. \end{align}\] This algorithm also enjoys a complexity of \(O(1/\epsilon^4)\) for finding a nearly \(\epsilon\)-stationary solution of \(F(\mathbf{w})\). We refer the readers to Hu et al. (2024) for a convergence analysis of this algorithm.
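A compact sketch of these updates, assuming hypothetical stochastic oracles `g1` and `g2` for the partial subgradients \(\partial_1 f\) and \(\partial_2 f\) and a projection `proj_U` onto \(\mathcal U\) (all illustrative names):

```python
# A minimal single-loop sketch for the weakly-convex strongly-concave case;
# `g1`, `g2`, and `proj_U` are hypothetical user-supplied oracles.
def single_loop_wc_sc(w, u1, u2, g1, g2, proj_U, rho, eta, gamma1, gamma2, T):
    for _ in range(T):
        u2 = proj_U(u2 + gamma2 * g2(u1, u2))                 # stochastic ascent on u_2
        u1 = u1 - gamma1 * (g1(u1, u2) + 2 * rho * (u1 - w))  # stochastic descent on u_1
        w = (1 - 2 * eta * rho) * w + 2 * eta * rho * u1      # GD step on the envelope
    return w
```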

4.5.2.3 Application to Compositional Optimization

We can apply a similar strategy to a compositional function \(F(\mathbf{w})=f_0(g(\mathbf{w}))\), where \(f_0\) is smooth convex and \(g\) is weakly convex. With the conjugate of \(f_0\), we can write \[\min_{\mathbf{w}}f_0(g(\mathbf{w})) = \min_{\mathbf{w}}\max_{\mathbf{u}_2\in\mathcal U}f(\mathbf{w},\mathbf{u}_2) := \mathbf{u}_2^{\top} g(\mathbf{w}) - f_0^*(\mathbf{u}_2).\] Since \(f_0\) is smooth, \(f_0^*\) is strongly convex. If \(g\) is weakly convex and \(\mathcal U\) is bounded (i.e., \(f_0\) is Lipschitz), then \(f(\mathbf{w},\mathbf{u}_2)\) is weakly convex in terms of \(\mathbf{w}\) and strongly concave in terms of \(\mathbf{u}_2\). Optimizing the Moreau envelope of \(f_0(g(\mathbf{w}))\) yields: \[\min_{\mathbf{w}, \mathbf{u}_1}\max_{\mathbf{u}_2\in\mathcal U}\mathbf{u}_2^{\top} g(\mathbf{u}_1) - f_0^*(\mathbf{u}_2) + \rho\|\mathbf{u}_1 - \mathbf{w}\|_2^2,\] which is strongly convex in terms of \(\mathbf{u}_1\) and strongly concave in terms of \(\mathbf{u}_2\). We give the updates below: \[\begin{align*} \mathbf{u}_{2,t+1} & = \Pi_{\mathcal U}[\mathbf{u}_{2,t} + \gamma_2 (g(\mathbf{u}_{1,t}; \zeta_t) - \partial f_0^*(\mathbf{u}_{2,t}))]\\ \mathbf{u}_{1,t+1} &= \mathbf{u}_{1,t} - \gamma_1(\partial g(\mathbf{u}_{1,t}; \zeta_t)\mathbf{u}_{2,t}+2\rho (\mathbf{u}_{1,t} - \mathbf{w}_t))\\ \mathbf{w}_{t+1} &= \mathbf{w}_t - \eta 2\rho(\mathbf{w}_t - \mathbf{u}_{1,t}) = (1-2\eta\rho) \mathbf{w}_t + 2\eta\rho\mathbf{u}_{1,t}. \end{align*}\] Then a similar convergence analysis can be developed, with a complexity of \(O(1/\epsilon^4)\) for finding a nearly \(\epsilon\)-stationary solution to \(F\).

4.5.3 Non-convex Bilevel Optimization

In this section, we discuss the application of the compositional gradient estimation technique to non-convex bilevel optimization defined by

\[\begin{equation}\label{eqn:bi} \begin{aligned} \min_{\mathbf{w}\in\mathbb{R}^d}& f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\\ \mathbf{u}^*(\mathbf{w})&=\arg\min_{\mathbf{u}\in\mathbb{R}^{d'}} g(\mathbf{w}, \mathbf{u}), \end{aligned} \end{equation}\]

where \(g\) is twice differentiable and \(\mu_g\)-strongly convex in terms of \(\mathbf{u}\). Let \(F(\mathbf{w}) =f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})).\) The following lemma states the gradient of the objective \(F(\mathbf{w})\).

Lemma 4.24

We have \[\begin{align*} \nabla F(\mathbf{w}) & = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) - \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\]

Proof

By the optimality condition of \(\mathbf{u}^*(\mathbf{w})\), we have \[\begin{align*} \nabla_2 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))=0. \end{align*}\] By taking derivative on both sides, using the chain rule, and the implicit function theorem, we obtain \[\begin{align*} \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) + \nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\frac{\partial \mathbf{u}^*(\mathbf{w})}{\partial \mathbf{w}}=0. \end{align*}\] Hence \[\begin{align*} \frac{\partial \mathbf{u}^*(\mathbf{w})}{\partial \mathbf{w}}=-(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\] Thus, \[\begin{align*} \nabla F(\mathbf{w}) & = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) + \frac{\partial \mathbf{u}^*(\mathbf{w})}{\partial \mathbf{w}}^{\top}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\\ & = \nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) - \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\]
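The formula in Lemma 4.24 can be checked numerically on a toy bilevel instance. The sketch below assumes \(f(\mathbf{w}, \mathbf{u})=\frac{1}{2}\|\mathbf{u}-\mathbf{b}\|_2^2\) and \(g(\mathbf{w}, \mathbf{u})=\frac{1}{2}\|\mathbf{u}-A\mathbf{w}\|_2^2\), so that \(\mathbf{u}^*(\mathbf{w})=A\mathbf{w}\) and \(\nabla F(\mathbf{w})=A^{\top}(A\mathbf{w}-\mathbf{b})\); the instance is an assumption chosen so the hypergradient is known in closed form.

```python
# A numerical check of the hypergradient formula on an assumed toy instance:
# f(w, u) = ||u - b||^2/2, g(w, u) = ||u - A w||^2/2, u*(w) = A w.
import numpy as np

rng = np.random.default_rng(2)
d, dp = 4, 3
A = rng.normal(size=(dp, d)); b = rng.normal(size=dp); w = rng.normal(size=d)

u_star = A @ w                     # lower-level solution
grad1_f = np.zeros(d)              # nabla_1 f = 0 for this f
grad2_f = u_star - b               # nabla_2 f(w, u*)
g21 = -A                           # nabla_21 g = d(nabla_2 g)/dw, a (d' x d) matrix
g22 = np.eye(dp)                   # nabla_22 g

hyper = grad1_f - g21.T @ np.linalg.solve(g22, grad2_f)
print(np.allclose(hyper, A.T @ (A @ w - b)))   # True
```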

Let us define \[\begin{align*} &\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))=\nabla_1 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) - \nabla_{21} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\nabla_{22} g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))^{-1}\nabla_2 f(\mathbf{w}, \mathbf{u}^*(\mathbf{w})). \end{align*}\]

If we can establish the Lipschitz continuity of \(\mathcal{M}(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\) in terms of the second argument and the Lipschitz continuity of \(\mathbf{u}^*(\mathbf{w})\), then a similar technique can be leveraged. Let \(\mathbf{u}^*(\mathbf{w}_t)\) be tracked by \(\mathbf{u}_t\), which can be updated by SGD: \[\begin{align} \mathbf{u}_{t+1} = \mathbf{u}_t - \gamma_t \nabla_2 g(\mathbf{w}_t, \mathbf{u}_t; \zeta_t). \end{align}\]

With \(\mathbf{u}_t\), the gradient at \(\mathbf{w}_t\) can be estimated by \[\begin{align} \mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top}(\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t))^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align}\] However, another challenge is to handle the Hessian inverse \((\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t))^{-1}\), which itself has a compositional structure. We will discuss three different ways to tackle this challenge. If we have a stochastic estimator of \(\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)\) denoted by \(\mathbf{v}_{t}\), then we update the model parameter by: \[\begin{align} \mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \mathbf{v}_{t}. \end{align}\]

4.5.3.1 Approach 1: The MA Estimator

If the lower-level problem is low-dimensional such that the inverse of the Hessian matrix can be efficiently computed, we can estimate \(\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t)\) by an MA estimator: \[\begin{align*} H_{22,t} = S_{\mu_g}[(1-\beta) H_{22,t-1} + \beta \nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t; \zeta_{2,t})], \end{align*}\] where \(S_{\mu_g}[\cdot]\) projects a matrix onto the set of matrices whose minimum eigenvalue is lower bounded by \(\mu_g\), the lower bound of the eigenvalues of \(\nabla_{22} g(\mathbf{w}, \mathbf{u})\). The projection ensures that \([H_{22,t}]^{-1}\) is Lipschitz continuous with respect to \(H_{22,t}\).
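One natural way to implement the projection \(S_{\mu_g}[\cdot]\) is through an eigendecomposition that clips eigenvalues from below, as in the sketch below; the text only requires the output to have minimum eigenvalue at least \(\mu_g\), so other implementations are possible.

```python
# A sketch of the projection S_{mu_g}: symmetrize, then clip eigenvalues
# from below at mu_g so the inverse of the result stays well conditioned.
import numpy as np

def S_proj(H, mu_g):
    Hs = (H + H.T) / 2                        # work with the symmetric part
    lam, Q = np.linalg.eigh(Hs)               # eigendecomposition of Hs
    return (Q * np.maximum(lam, mu_g)) @ Q.T  # Q diag(max(lam, mu_g)) Q^T
```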

The vanilla stochastic gradient estimator of \(\mathbf{w}_t\) and its MA estimator are computed by

\[\begin{equation}\label{eqn:bio-hma} \begin{aligned} &\mathbf{z}_{t} = \nabla_1 f(\mathbf{w}_t, \mathbf{u}_t; \xi_{t}) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t; \zeta'_{2,t})^{\top}(H_{22,t})^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t; \xi_{t})\\ &\mathbf{v}_{t} =(1-\beta)\mathbf{v}_{t-1} + \beta \mathbf{z}_t. \end{aligned} \end{equation}\]

Convergence Analysis

The proof is largely similar to that of Theorem 4.3. We provide a sketch of the proof below. Recall that \[\begin{align*} &\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top}(\nabla_{22} g(\mathbf{w}_t, \mathbf{u}_t))^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\] Define: \[\begin{align*} \hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top} H_{22,t}^{-1}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\]

First, similar to Lemma 4.9, we have the following: \[\begin{align} F(\mathbf{w}_{t+1})& \leq F(\mathbf{w}_t) + \frac{\eta_t}{2} \|\mathbf{v}_t - \nabla F(\mathbf{w}_t)\|_2^2-\frac{\eta_t}{2}\|\nabla F(\mathbf{w}_t)\|_2^2- \frac{1}{4\eta_t} \|\mathbf{w}_{t+1} - \mathbf{w}_t\|_2^2. \end{align}\]

We establish a recursion of the error \(\|\mathbf{v}_t - \nabla F(\mathbf{w}_t)\|_2^2\) similar to Lemma 4.7 by noting that \(\mathbb E_{\xi_t, \zeta'_{2,t}}[\mathbf{z}_t]=\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)\) and there exists \(\sigma>0\) such that \(\mathbb E_{\xi_t, \zeta'_{2,t}}[\|\mathbf{z}_t - \hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)\|_2^2]\leq \sigma^2\). Thus, Lemma 4.7 implies that

\[\begin{align}\label{eqn:bio-v-err} &\mathbb E_{\xi_t, \zeta'_{2,t}}\left[\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\|_2^2\right] \leq (1-\beta_t)\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\|_2^2 \\ & \quad + \frac{2L_F^2}{\beta_t}\|\mathbf{w}_{t-1}-\mathbf{w}_t\|_2^2 + 4\beta_t\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|_2^2 + \beta_t^2\sigma^2.\notag \end{align}\]

Then, we bound \(\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)-\nabla F(\mathbf{w}_t)\|_2^2\) by \[\begin{align*} \|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)-\nabla F(\mathbf{w}_t)\|_2^2&\leq 2\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) -\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) \|^2_2\\ & +2\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|^2_2\\ & \leq O(\|H_{22,t} - \nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\|_2^2) + O(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2). \end{align*}\]

As a result, we have \[\begin{align*} \mathbb E[\|\mathbf{v}_t - &\nabla F(\mathbf{w}_t)\|_2^2] \leq (1-\beta_t)\|\mathbf{v}_{t-1}- \nabla F(\mathbf{w}_{t-1})\|_2^2 + \frac{2L_F^2}{\beta_t}\|\mathbf{w}_t - \mathbf{w}_{t-1}\|_2^2\\ &+ \beta_t (O(\|H_{22,t} - \nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\|_2^2) + O(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2)) + \beta_t^2O(\sigma^2). \end{align*}\] This result is similar to that in Lemma 4.8.

We can further build the error recursion of \(\|H_{22,t} - \nabla_{22}g(\mathbf{w}_t,\mathbf{u}_t)\|_2^2\) similar to Lemma 4.1, and the error recursion of \(\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|^2_2\) similar to Lemma 4.20. Combining these results, we can establish a complexity of \(O(1/\epsilon^4)\) for finding an \(\epsilon\)-stationary solution of \(F(\cdot)\) in expectation.

4.5.3.2 Approach 2: The Neumann Series (Matrix Taylor Approximation)

If the lower-level problem is high-dimensional such that computing the Hessian inverse is prohibitive, one approach is to leverage the Neumann series: \[\begin{align} A^{-1}= \sum_{i=0}^\infty (I -A)^i, \quad \text{ if }\quad \|I-A\|< 1. \end{align}\] Hence, if \(\|\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\|\leq L_{22}\), we estimate the inverse of \(\frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\), yielding \[\begin{align} \left(\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^{-1}\approx \frac{1}{L_{22}}\sum_{i=0}^{K-1} \left(I - \frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^i. \end{align}\] This can be further estimated by a stochastic route: sampling \(k\) from \(\{0,\ldots,K- 1\}\) uniformly at random, we estimate the Hessian inverse by \[\begin{align} Q_{22,t} = \left\{\begin{array}{ll}\frac{K}{L_{22}}\prod_{i=1}^{k} \left(I - \frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t; \zeta_{i})\right)& \text{ if } k\geq 1\\ \frac{K}{L_{22}}I & \text{ if } k=0\end{array}\right.. \end{align}\] This can be justified by \[\begin{align*} \mathbb{E}[Q_{22,t}] &= \frac{1}{K}\frac{K}{L_{22}}I + \frac{K-1}{K}\mathbb{E}_{k\sim\{1,\ldots, K-1\}}\left[\frac{K}{L_{22}}\prod_{i=1}^{k} \left(I - \frac{1}{L_{22}}\mathbb{E}[\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t; \zeta_{i})]\right)\right]\\ & = \mathbb{E}_{k}\left[\frac{K}{L_{22}}\left(I - \frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^k\right] = \sum_{k=0}^{K-1}\frac{1}{L_{22}}\left(I -\frac{1}{L_{22}}\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t)\right)^k. \end{align*}\]
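The sketch below illustrates the estimator \(Q_{22,t}\) numerically, using a fixed SPD matrix as a noiseless stand-in for the stochastic Hessians \(\nabla_{22}g(\mathbf{w}_t, \mathbf{u}_t; \zeta_i)\): averaging many independent draws of \(Q_{22,t}\) approximately recovers the Hessian inverse, as the expectation computation above predicts. The matrix, \(K\), and \(L_{22}\) are assumptions for illustration.

```python
# Numerical illustration of the stochastic Neumann-series estimator Q_{22,t}.
import numpy as np

rng = np.random.default_rng(3)
d, K, L22 = 4, 100, 10.0
M = rng.normal(size=(d, d))
H = M @ M.T / d + np.eye(d)                # SPD stand-in with ||H|| <= L22

def Q_sample():
    k = rng.integers(0, K)                 # k uniform in {0, ..., K-1}
    Q = (K / L22) * np.eye(d)
    for _ in range(k):                     # product of k factors (noiseless here)
        Q = Q @ (np.eye(d) - H / L22)
    return Q

Q_avg = sum(Q_sample() for _ in range(5000)) / 5000
print(np.linalg.norm(Q_avg - np.linalg.inv(H)))   # small; shrinks with more samples
```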

Then the vanilla gradient estimator of \(\mathbf{w}_t\) and its MA estimator are computed by

\[\begin{equation}\label{eqn:bio-neu} \begin{aligned} &\mathbf{z}_{t} = \nabla_1 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_{1,t}) +\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t; \zeta'_{2,t})^{\top}Q_{22,t}\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t; \zeta_{1,t})\\ &\mathbf{v}_{t} =(1-\beta)\mathbf{v}_{t-1} + \beta \mathbf{z}_t. \end{aligned} \end{equation}\]

Convergence Analysis

We provide a proof sketch below. Observe that \(\mathbf{z}_t\) is an unbiased stochastic estimator of \[\begin{align*} \hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t)=\nabla_1 f(\mathbf{w}_t, \mathbf{u}_t) -\nabla_{21} g(\mathbf{w}_t, \mathbf{u}_t)^{\top}\mathbb{E}[Q_{22,t}]\nabla_2 f(\mathbf{w}_t, \mathbf{u}_t). \end{align*}\] We decompose the estimation error of \(\mathbf{v}_t\) similarly as in (\(\ref{eqn:bio-v-err}\)) and bound \(\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|_2^2\) by \[\begin{align*} \|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \nabla F(\mathbf{w}_t)\|_2^2 &\leq 2\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) -\nabla F(\mathbf{w}_t)\|_2^2 \\ & + 2\|\hat{\mathcal{M}}(\mathbf{w}_t, \mathbf{u}_t) - \mathcal{M}(\mathbf{w}_t,\mathbf{u}_t)\|_2^2 . \end{align*}\] The error recursion of the first term on the right-hand side can be bounded similarly as before. To bound the last error, since \[\begin{equation*} \left[\nabla_{22} g(\mathbf{w}, \mathbf{u})\right]^{-1} = \mathbb{E}[Q_{22}] + \frac{1}{L_{22}} \sum_{i=K}^{\infty} \left[I - \frac{1}{L_{22}} \nabla_{22} g(\mathbf{w}, \mathbf{u})\right]^i, \end{equation*}\] we have \[\begin{align*} &\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \hat{\mathcal{M}}(\mathbf{w}_t,\mathbf{u}_t)\|_2^2\leq O(\|\left[\nabla_{22} g(\mathbf{w}, \mathbf{u})\right]^{-1} - \mathbb{E}[Q_{22}]\|_2^2),\\ &\left\| \left[\nabla_{22} g(\mathbf{w}, \mathbf{u})\right]^{-1} - \mathbb{E}[Q_{22}] \right\|_2 \le \frac{1}{L_{22}} \sum_{i=K}^{\infty} \left\|I - \frac{1}{L_{22}} \nabla_{22} g(\mathbf{w}, \mathbf{u})\right\|_2^i \le \frac{1}{\mu_g} \left(1 - \frac{\mu_g}{L_{22}}\right)^K. \end{align*}\] As a result, if \(K=O(\frac{L_{22}}{\mu_g}\log(1/(\mu_g^2\beta_t\sigma^2)))\), then \(\|\mathcal{M}(\mathbf{w}_t, \mathbf{u}_t) - \hat{\mathcal{M}}(\mathbf{w}_t,\mathbf{u}_t)\|_2^2\leq O(\beta_t\sigma^2)\). Then, similar to the analysis of Approach 1, we can establish a complexity of \(O(1/\epsilon^4)\) for finding an \(\epsilon\)-stationary solution of \(F(\cdot)\) in expectation.

4.5.3.3 Approach 3: The Penalty Method

An alternative approach to avoid computing the Hessian inverse and Jacobian matrices is to reformulate the problem as a constrained optimization problem: \[\begin{align*} \min_{\mathbf{w},\mathbf{u}} &\quad f(\mathbf{w}, \mathbf{u}) \\ \text{s.t.} &\quad g(\mathbf{w}, \mathbf{u}) \leq \min_{\mathbf{y}} g(\mathbf{w}, \mathbf{y}). \end{align*}\] This constrained problem can be addressed using a penalty method (see Section 6.7): \[\begin{align*} \min_{\mathbf{w},\mathbf{u}} \; f(\mathbf{w}, \mathbf{u}) + \lambda \big( g(\mathbf{w}, \mathbf{u}) - \min_{\mathbf{y}} g(\mathbf{w}, \mathbf{y}) \big)_+, \end{align*}\] where \(\lambda > 0\) is a penalty parameter and \((\cdot)_+\) denotes the positive part. Since \(g(\mathbf{w}, \mathbf{u}) \geq \min_{\mathbf{y}} g(\mathbf{w}, \mathbf{y})\), the formulation simplifies to:

\[\begin{align}\label{eqn:bio-minmax} &\min_{\mathbf{w},\mathbf{u}} \; f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - \min_{\mathbf{y}} g(\mathbf{w}, \mathbf{y}) \right)\\ &= \min_{\mathbf{w},\mathbf{u}} \max_{\mathbf{y}} f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right). \end{align}\]

If both \(f\) and \(g\) are smooth and \(g\) is strongly convex in its second argument, the resulting formulation becomes a non-convex strongly-concave min-max problem, which can be effectively addressed using the SMDA algorithm with the following update for \(t\geq 1\):

\[\begin{equation}\label{eqn:bi-smda} \begin{aligned} &\mathbf{y}_{t+1} = \mathbf{y}_{t} - \gamma_t\lambda\nabla_2 g(\mathbf{w}_t, \mathbf{y}_{t}; \xi_t),\\ &\mathbf{z}_t = \nabla f(\mathbf{w}_t, \mathbf{u}_{t}; \zeta_t) + \lambda \left(\nabla g(\mathbf{w}_t, \mathbf{u}_{t}; \xi_t) - \left[\nabla_1 g(\mathbf{w}_t, \mathbf{y}_{t}; \xi_t)\atop 0\right]\right),\\ &\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1} + \beta_t \mathbf{z}_t, \\ &\left[\mathbf{w}_{t+1}\atop \mathbf{u}_{t+1}\right] = \left[\mathbf{w}_t\atop \mathbf{u}_{t}\right] - \eta_t \mathbf{v}_{t}. \end{aligned} \end{equation}\]
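For concreteness, here is a minimal sketch of the updates in (\(\ref{eqn:bi-smda}\)), assuming hypothetical oracles `grad_f` and `grad_g` that draw fresh samples internally and return the pair of stochastic partial gradients with respect to the two arguments; these names and the constant step sizes are illustrative only.

```python
# A sketch of the penalty-based SMDA updates for bilevel optimization;
# `grad_f` and `grad_g` are hypothetical stochastic oracles returning
# (partial_first_arg, partial_second_arg) at the given point.
import numpy as np

def penalty_smda(w, u, y, grad_f, grad_g, lam, eta, gamma, beta, T):
    v_w = np.zeros_like(w); v_u = np.zeros_like(u)
    for _ in range(T):
        _, gy = grad_g(w, y)                      # nabla_2 g(w_t, y_t; xi_t)
        y = y - gamma * lam * gy                  # drive y toward argmin_y g(w, y)
        fw, fu = grad_f(w, u)                     # nabla f(w_t, u_t; zeta_t)
        gw, gu = grad_g(w, u)                     # nabla g(w_t, u_t; xi_t)
        gyw, _ = grad_g(w, y)                     # nabla_1 g(w_t, y_t; xi_t)
        v_w = (1 - beta) * v_w + beta * (fw + lam * (gw - gyw))  # MA (w-part)
        v_u = (1 - beta) * v_u + beta * (fu + lam * gu)          # MA (u-part)
        w = w - eta * v_w
        u = u - eta * v_u
    return w, u
```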

Convergence Analysis

The convergence analysis of (\(\ref{eqn:bi-smda}\)) for the min–max problem (\(\ref{eqn:bio-minmax}\)) follows a similar approach to that of Theorem 4.5 for SMDA. However, a remaining challenge lies in converting the convergence result for the min–max formulation into that of the original problem. To address this, we provide the detailed convergence analysis below. We begin by stating the following assumption.

Assumption 4.9

Regarding the problem (\(\ref{eqn:bi}\)), the following conditions hold:

  1. \(\nabla f\) is \(L_f\)-Lipschitz continuous:

\[\begin{align} \|\nabla f(\mathbf{w}_1, \mathbf{u}_1) - \nabla f(\mathbf{w}_2, \mathbf{u}_2)\|_2 \leq L_f\left\|\left(\mathbf{w}_1\atop \mathbf{u}_1\right) - \left(\mathbf{w}_2\atop \mathbf{u}_2\right)\right\|_2. \end{align}\]

  2. \(\nabla g\) is \(L_g\)-Lipschitz continuous:

\[\begin{align} \|\nabla g(\mathbf{w}_1, \mathbf{u}_1) - \nabla g(\mathbf{w}_2, \mathbf{u}_2)\|_2 \leq L_g\left\|\left(\mathbf{w}_1\atop \mathbf{u}_1\right) - \left(\mathbf{w}_2\atop \mathbf{u}_2\right)\right\|_2. \end{align}\]

  3. There exist \(\sigma_f, \sigma_g\) such that

\[\begin{align} &\mathbb{E}[\|\nabla f(\mathbf{w}, \mathbf{u}; \zeta)- \nabla f(\mathbf{w}, \mathbf{u})\|_2^2]\leq \sigma_f^2,\\ &\mathbb{E}[\|\nabla g(\mathbf{w}, \mathbf{u}; \xi)- \nabla g(\mathbf{w}, \mathbf{u})\|_2^2]\leq \sigma_g^2. \end{align}\]

Let us define \(\bar{\mathbf{w}} : = (\mathbf{w}, \mathbf{u})\) and

\[\begin{align} &\bar f(\bar{\mathbf{w}}, \mathbf{y}): = f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right),\\ &\bar F(\bar{\mathbf{w}}) : = \max_{\mathbf{y}}\bar f(\bar{\mathbf{w}}, \mathbf{y}). \end{align}\]

Then \[\begin{align*} &\nabla_1\bar f(\bar{\mathbf{w}}, \mathbf{y})=\nabla f(\mathbf{w}, \mathbf{u}) + \lambda \left(\nabla g(\mathbf{w}, \mathbf{u}) - \left[\nabla_1 g(\mathbf{w}, \mathbf{y})\atop 0\right]\right),\\ &\nabla_2\bar f(\bar{\mathbf{w}}, \mathbf{y})= -\lambda \nabla_2 g(\mathbf{w}, \mathbf{y}),\\ &\nabla_1\bar f(\bar{\mathbf{w}}, \mathbf{y};\varepsilon)=\nabla f(\mathbf{w}, \mathbf{u}; \zeta) + \lambda \left(\nabla g(\mathbf{w}, \mathbf{u};\xi) - \left[\nabla_1 g(\mathbf{w}, \mathbf{y}; \xi)\atop 0\right]\right),\\ &\nabla_2\bar f(\bar{\mathbf{w}}, \mathbf{y};\xi)= -\lambda \nabla_2 g(\mathbf{w}, \mathbf{y};\xi), \end{align*}\] where \(\varepsilon=(\zeta, \xi)\). We first show \(\bar f(\bar{\mathbf{w}}, \mathbf{y})\) satisfies the conditions in Assumption 4.8.

Lemma 4.25

Under Assumption 4.9, we have: (i) \(\bar f(\bar{\mathbf{w}}, \mathbf{y})\) is \(\lambda\mu_g\)-strongly concave in terms of \(\mathbf{y}\); (ii) \(\nabla_1 \bar f(\bar{\mathbf{w}}, \mathbf{y})\) is \(L_1\)-Lipschitz continuous with \(L_1=L_f+2\lambda L_g\); (iii) \(\nabla_2 \bar f(\bar{\mathbf{w}}, \mathbf{y})\) is Lipschitz continuous with constants \(L_{21}=\lambda L_g\) and \(L_{22}=\lambda L_g\) with respect to the first and the second argument, respectively; (iv) the variances of \(\nabla_1\bar f(\bar{\mathbf{w}}, \mathbf{y};\varepsilon)\) and \(\nabla_2\bar f(\bar{\mathbf{w}}, \mathbf{y};\xi)\) are bounded by \(\sigma_1^2=3\sigma_f^2+6\lambda^2\sigma_g^2\) and \(\sigma_2^2=\lambda^2\sigma_g^2\), respectively; and (v) \(\min_{\bar{\mathbf{w}}}\bar F(\bar{\mathbf{w}})>-\infty\) provided \(f\) is bounded below.

Proof.
(i) is obvious. The Lipschitz continuity of \(\nabla_1 \bar f(\bar{\mathbf{w}}, \mathbf{y})\) follows from that of \(\nabla f(\mathbf{w}, \mathbf{u})\) and \(\nabla g(\mathbf{w}, \mathbf{u})\). For (iii), we have \[\begin{align*} &\|\nabla_2\bar f(\bar{\mathbf{w}}_1, \mathbf{y}_1)- \nabla_2\bar f(\bar{\mathbf{w}}_2, \mathbf{y}_2)\|_2= \lambda \|\nabla_2 g(\mathbf{w}_1, \mathbf{y}_1)-\nabla_2 g(\mathbf{w}_2, \mathbf{y}_2)\|_2\\ &\leq \lambda \|\nabla g(\mathbf{w}_1, \mathbf{y}_1)-\nabla g(\mathbf{w}_2, \mathbf{y}_2)\|_2 \leq \lambda L_g\left\|\left(\mathbf{w}_1\atop \mathbf{y}_1\right) - \left(\mathbf{w}_2\atop \mathbf{y}_2\right)\right\|_2\\ &\leq \lambda L_g(\|\mathbf{w}_1 - \mathbf{w}_2\|_2 + \|\mathbf{y}_1 -\mathbf{y}_2\|_2)\leq \lambda L_g(\|\bar{\mathbf{w}}_1 - \bar{\mathbf{w}}_2\|_2 + \|\mathbf{y}_1 -\mathbf{y}_2\|_2). \end{align*}\] It is straightforward to prove (iv). The last result follows from \(\max_{\mathbf{y}}\bar f(\bar{\mathbf{w}}, \mathbf{y})\geq f(\mathbf{w}, \mathbf{u}) > -\infty\).

Theorem 4.6

Suppose Assumption 4.9 holds. By setting \[\begin{align*} &\beta_t=\beta= \frac{\epsilon^2}{9\sigma_f^2 + 18\lambda^2\sigma_g^2}, \\ &\gamma_t=\gamma=\frac{\mu_g\epsilon^2}{96(L_f+2L_g\lambda)^2\lambda\sigma_g^2},\\ &\eta_t= \\ & \quad\min\left\{\frac{\beta}{\sqrt{8}(L_f+2L_g\lambda)(1+L_g/\mu_g)},\frac{\gamma\mu_g^2\lambda}{16\sqrt{3}(L_f + 2L_g\lambda)L_g},\frac{1}{2(L_f+2L_g\lambda)(1+L_g/\mu_g)}\right\} \end{align*}\] in (\(\ref{eqn:bi-smda}\)), the following holds

\[\begin{align}\label{eq:smda-c} \mathbb{E}\left[\frac{1}{T}\sum_{t=0}^{T-1}\left\{\frac{1}{4}\|\mathbf{v}_{t}\|_2^2 + \|\nabla \bar F(\bar{\mathbf{w}}_t)\|_2^2\right\}\right]\leq \epsilon^2, \end{align}\]

with an iteration complexity of

\[\begin{align}\label{eqn:bi-smda-t} T&=O\left(\max\left\{\frac{C_\Upsilon \lambda}{\epsilon^2}, \frac{C_\Upsilon(\lambda\sigma_f^2+\lambda^3\sigma_g^2)}{\epsilon^4}, \frac{C_\Upsilon \lambda^3\sigma_g^2}{\epsilon^4\mu_g^2}\right\}\right), \end{align}\] where \(C_\Upsilon= 2(\bar F(\bar{\mathbf{w}}_0) - \min_{\bar{\mathbf{w}}}\bar F(\bar{\mathbf{w}})) + \frac{1}{\sqrt{8}\bar L_F}\|\mathbf{v}_0 - \nabla \bar F(\bar{\mathbf{w}}_0)\|_2^2 + \frac{L_1}{\sqrt{3}\kappa}\|\mathbf{y}_0 - \mathbf{y}^*(\mathbf{w}_0)\|_2^2\) and \(\kappa = L_g/\mu_g\) .

Proof.
We map the problem into the setting of Theorem 4.5 with \(L_1=L_f + 2L_g\lambda\), \(L_{21}=L_g\lambda\), \(L_{22}=L_g\lambda\), \(\mu=\mu_g\lambda\), \(\kappa = L_{21}/(\mu_g\lambda)=L_g/\mu_g\), \(L_F = L_1(1+\kappa)=(L_f+2L_g\lambda)(1+L_g/\mu_g)\), \(\sigma_1^2 = 3\sigma_f^2 + 6\lambda^2\sigma_g^2\), \(\sigma_2^2 = \lambda^2\sigma_g^2\). Then, substituting these values into the result of Theorem 4.5, we obtain the desired conclusion, retaining only the dependence on \(\lambda\), as it will be set to a sufficiently large value.

Convergence of the original function

Next, we derive the convergence of the original function in terms of \(\|\nabla F(\mathbf{w})\|_2\). We need the following additional assumption.

Assumption 4.10

  1. \(g\) is twice differentiable, and \(\nabla_{21}g(\mathbf{w}, \mathbf{u})\) and \(\nabla_{22}g(\mathbf{w}, \mathbf{u})\) are \(L_{gg}\)-Lipschitz continuous.

  2. \(\|\nabla_2 f(\mathbf{w}, \mathbf{u})\|_2\leq G_f\).

Lemma 4.26

Let \(\mathbf{u}^*_{\lambda}(\mathbf{w})=\arg\min_{\mathbf{u}}\bar F(\mathbf{w}, \mathbf{u})\), \(\mathbf{u}^*(\mathbf{w})=\arg\min_{\mathbf{u}}g(\mathbf{w}, \mathbf{u})\). Under Assumption 4.10(i), we have \[\begin{align*} \|\nabla F(\mathbf{w}) - \nabla_1\bar F(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))\|_2&\leq L_f(1+\frac{L_g}{\mu_g})\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2 \\ &+ L_{gg}\lambda(1+\frac{L_g}{\mu_g}) \|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|^2_2. \end{align*}\]

Proof.
Let \(\mathbf{u}^* = \mathbf{u}^*(\mathbf{w})\). Then, \[\begin{align*} \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}) &= \nabla_{1} f(\mathbf{w}, \mathbf{u}) + \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*)) \\ \nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) &= \nabla_{2} f(\mathbf{w}, \mathbf{u}) + \lambda \nabla_{2} g(\mathbf{w}, \mathbf{u}). \end{align*}\] Due to Lemma 4.24, we have

\[\begin{equation} \label{eq:31_new} \begin{split} &\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}) = \nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u}) \\ & -\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} f(\mathbf{w}, \mathbf{u}^*) - \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*)). \end{split} \end{equation}\]

We can rearrange the terms \((\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*))\) as follows:

\[\begin{equation} \label{eq:32_new} \begin{split} \nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) &= \nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*) \\ & + \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*). \end{split} \end{equation}\]

To continue, we leverage the following equality: \[\begin{equation*} \begin{split} \mathbf{u} - \mathbf{u}^* =& -\nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) \\ & +\nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*)). \end{split} \end{equation*}\] By the optimality condition for \(\mathbf{u}^*\), \(\nabla_{2} g(\mathbf{w}, \mathbf{u}^*) = 0\), and \(\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u})=\nabla_{2} f(\mathbf{w}, \mathbf{u}) + \lambda \nabla_{2} g(\mathbf{w}, \mathbf{u})\), we can express \(\mathbf{u} - \mathbf{u}^*\) as

\[\begin{equation}\label{eq:33_new} \begin{split} \mathbf{u} - \mathbf{u}^* &= -\nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) \\ & + \frac{1}{\lambda} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) - \nabla_{2} f(\mathbf{w}, \mathbf{u})). \end{split} \end{equation}\]

Combining (\(\ref{eq:33_new}\)) and (\(\ref{eq:32_new}\)) and then plugging the result back into (\(\ref{eq:31_new}\)), we have \[\begin{equation*} \begin{split} &\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}) = \nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u}) \\ & -\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} f(\mathbf{w}, \mathbf{u}^*) \\ &- \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*))\\ & +\lambda \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) \\ & - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) - \nabla_{2} f(\mathbf{w}, \mathbf{u})). \end{split} \end{equation*}\]

As a result, we have \[\begin{equation*} \begin{split} &\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u})+ \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}) \\ & = \nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u}) \\ & -\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{2} f(\mathbf{w}, \mathbf{u})) \\ &- \lambda (\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*))\\ & +\lambda \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} (\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)) . \end{split} \end{equation*}\]

By Assumption 4.10(i) we have \[\begin{align*} &\|\nabla_{1} g(\mathbf{w}, \mathbf{u}) - \nabla_{1} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} (\mathbf{u} - \mathbf{u}^*)\|_2\leq L_{gg}\|\mathbf{u} - \mathbf{u}^*\|^2_2,\\ &\|\nabla_{2} g(\mathbf{w}, \mathbf{u}) - \nabla_{2} g(\mathbf{w}, \mathbf{u}^*) - \nabla_{22} g(\mathbf{w}, \mathbf{u}^*) (\mathbf{u} - \mathbf{u}^*)\|_2 \le L_{gg} \|\mathbf{u} - \mathbf{u}^*\|^2_2. \end{align*}\] By Assumption 4.9 we have \[\begin{align*} &\|\nabla_{1} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{1} f(\mathbf{w}, \mathbf{u})\|_2\leq L_f\|\mathbf{u}^*-\mathbf{u}\|_2,\\ &\|\nabla_{2} f(\mathbf{w}, \mathbf{u}^*) - \nabla_{2} f(\mathbf{w}, \mathbf{u})\|_2\leq L_f\|\mathbf{u}^*-\mathbf{u}\|_2,\\ &\|\nabla_{12} g(\mathbf{w}, \mathbf{u}^*) \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1}\|_2\leq \frac{L_g}{\mu_g}. \end{align*}\] Thus, we have \[\begin{equation*} \begin{split} &\|\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u})+ \nabla_{12} g(\mathbf{w}, \mathbf{u}^*)^{\top} \nabla_{22} g(\mathbf{w}, \mathbf{u}^*)^{-1} \nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u})\|_2 \\ & \leq L_f(1+\frac{L_g}{\mu_g})\|\mathbf{u} - \mathbf{u}^*\|_2 + L_{gg}\lambda(1+\frac{L_g}{\mu_g}) \|\mathbf{u} - \mathbf{u}^*\|^2_2. \end{split} \end{equation*}\] Plugging in \(\mathbf{u}=\mathbf{u}^*_\lambda(\mathbf{w}) = \arg\min_{\mathbf{u}}\bar F(\mathbf{w}, \mathbf{u})\), so that \(\nabla_{2} \bar{F}(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w})) =0\), we obtain \[\begin{equation*} \begin{split} &\|\nabla F(\mathbf{w}) - \nabla_{1} \bar{F}(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))\|_2 \\ & \leq L_f(1+\frac{L_g}{\mu_g})\|\mathbf{u}^*_\lambda(\mathbf{w}) - \mathbf{u}^*\|_2 + L_{gg}\lambda(1+\frac{L_g}{\mu_g}) \|\mathbf{u}^*_\lambda(\mathbf{w}) - \mathbf{u}^*\|^2_2. \end{split} \end{equation*}\]
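To make the lemma concrete, here is a minimal 1-D numerical sketch on toy quadratics. This pair is our own construction, not from the text; it is chosen so that \(L_{gg}=0\) and every quantity has a closed form:

```python
# Toy check of Lemma 4.26:
#   g(w, u) = 0.5 * (u - a*w)^2      =>  mu_g = 1, L_gg = 0, |∇12 g / ∇22 g| = |a|
#   f(w, u) = 0.5 * c * (w^2 + u^2)  =>  L_f = c
# Closed forms: u*(w) = a*w,  u*_lam(w) = lam*a*w/(c + lam),
#   ∇F(w) = c*(1 + a^2)*w,  ∇1 Fbar(w, u*_lam(w)) = c*w + lam*a^2*c*w/(c + lam).
a, c, w = 1.0, 2.0, 3.0
for lam in [10.0, 100.0, 1000.0]:
    u_star = a * w
    u_lam = lam * a * w / (c + lam)
    gap = abs(c * (1 + a**2) * w - (c * w + lam * a**2 * c * w / (c + lam)))
    dist = abs(u_lam - u_star)
    # L_gg = 0 kills the quadratic term; |a| is the realized ratio <= L_g/mu_g
    bound = c * (1 + abs(a)) * dist
    print(f"lam={lam:7.1f}  gap={gap:.5f}  bound={bound:.5f}  ok={gap <= bound}")
# The gap decays like 1/lam, as predicted once Lemma 4.27 below is applied.
```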

Next, we bound \(\|\mathbf{u}^*_\lambda(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2\).

Lemma 4.27

Under Assumption 4.10(ii), we have \(\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2 \le \frac{G_f}{\lambda \mu_g}\).

Proof.
By the definitions of \(\mathbf{u}^*_\lambda(\mathbf{w})\) and \(\mathbf{u}^*(\mathbf{w})\), we have \[\begin{align*} &\mathbf{u}^*_\lambda(\mathbf{w}) = \arg\min_{\mathbf{u}} \frac{1}{\lambda}f(\mathbf{w}, \mathbf{u}) + g(\mathbf{w}, \mathbf{u}), \\ &\mathbf{u}^*(\mathbf{w}) = \arg\min_{\mathbf{u}} g(\mathbf{w}, \mathbf{u}). \end{align*}\] By the optimality conditions, \[\begin{align*} &\frac{1}{\lambda}\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w})) + \nabla_2 g(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))=0,\\ &\nabla_2 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))=0. \end{align*}\] Since \(g(\mathbf{w}, \mathbf{u})\) is \(\mu_g\)-strongly convex w.r.t. \(\mathbf{u}\) for any \(\mathbf{w}\), we have \[\begin{align*} g(\mathbf{w},\mathbf{u}^*_{\lambda}(\mathbf{w}))\geq & g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})) + \nabla_2 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))^{\top}(\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})) \\ & + \frac{\mu_g}{2}\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2^2,\\ g(\mathbf{w},\mathbf{u}^*(\mathbf{w}))\geq & g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w})) + \nabla_2 g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))^{\top}(\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w})) \\ & + \frac{\mu_g}{2}\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2^2. \end{align*}\] Adding these two inequalities yields: \[\begin{align*} \mu_g\|\mathbf{u}^*_{\lambda}(\mathbf{w}) - \mathbf{u}^*(\mathbf{w})\|_2^2&\leq -\nabla_2 g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w}))^{\top}(\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w}))\\ & =\frac{1}{\lambda}\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))^{\top}(\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w}))\\ &\leq \frac{1}{\lambda}\|\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))\|_2\|\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w})\|_2. \end{align*}\] Dividing both sides by \(\|\mathbf{u}^*(\mathbf{w}) - \mathbf{u}^*_{\lambda}(\mathbf{w})\|_2\) and noting \(\|\nabla_2 f(\mathbf{w}, \mathbf{u}_\lambda^*(\mathbf{w}))\|_2\leq G_f\) concludes the proof.
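As a sanity check, the bound can be verified on the same toy quadratics used after Lemma 4.26 (again our own construction, not from the text); a sketch:

```python
# Checking Lemma 4.27 with g(w,u) = 0.5*(u - a*w)^2 (mu_g = 1) and
# f(w,u) = 0.5*c*(w^2 + u^2). In place of the global bound G_f we use the
# realized value |∇2 f(w, u*_lam(w))|, which is all the proof needs.
a, c, w, mu_g = 1.0, 2.0, 3.0, 1.0
for lam in [10.0, 100.0, 1000.0]:
    u_star = a * w                    # argmin_u g(w, u)
    u_lam = lam * a * w / (c + lam)   # argmin_u f(w, u)/lam + g(w, u)
    dist = abs(u_lam - u_star)
    G_f = abs(c * u_lam)              # |∇2 f| evaluated at u*_lam(w)
    print(f"lam={lam:7.1f}  dist={dist:.6f}  G_f/(lam*mu_g)={G_f/(lam*mu_g):.6f}")
# For exactly quadratic g, every inequality in the proof is tight, so the two
# columns coincide here; in general the right-hand side is only an upper bound.
```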

Corollary 4.2

Under the same setting as in Theorem 4.6, with \(\lambda=O(\frac{1}{\epsilon})>2L_f/\mu_g\) and assuming \(\|\mathbf{y}_0-\mathbf{y}^*(\mathbf{w}_0)\|^2_2\leq O(\epsilon)\), the following holds:

\[\begin{align} \mathbb{E}\left[\|\nabla F(\mathbf{w}_\tau)\|_2\right]\leq O(\epsilon), \end{align}\]

with an iteration complexity of

\[\begin{align}\label{eq:bi-smda-t} T&=O\left(\max\left\{\frac{1}{\epsilon^3}, \frac{\sigma_f^2}{\epsilon^5}, \frac{\sigma_g^2}{\epsilon^7}\right\}\right), \end{align}\] where \(\tau\in\{0,\ldots, T-1\}\) is randomly sampled.

Proof.
Combining Lemma 4.26 and Lemma 4.27, we have \[\begin{align*} \|\nabla F(\mathbf{w}_\tau)\|_2 & \leq \|\nabla F(\mathbf{w}_\tau) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau))\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau))\|_2\\ &\leq L_f(1+\frac{L_g}{\mu_g})\frac{G_f}{\mu_g\lambda} + L_{gg}\lambda(1+\frac{L_g}{\mu_g})\frac{G_f^2}{\mu_g^2\lambda^2}\\ & + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau)) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2\\ &\leq \frac{2L_fL_gG_f}{\mu_g^2\lambda} + \frac{2L_{gg}L_gG_f^2}{\mu_g^3\lambda}\\ & + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau)) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2. \end{align*}\] Since \(\bar F(\mathbf{w}, \mathbf{u})\) is \((\lambda\mu_g - L_f)\)-strongly convex w.r.t. \(\mathbf{u}\), Lemma 1.6(c) implies that \[\begin{align*} (\lambda\mu_g - L_f)&\|\mathbf{u}^*_{\lambda}(\mathbf{w}_\tau) - \mathbf{u}_\tau\|_2^2\leq \frac{1}{(\lambda\mu_g - L_f)}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau) - \nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau))\|_2^2\\ & = \frac{1}{(\lambda\mu_g - L_f)}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau) \|^2_2. \end{align*}\] Due to \(\nabla_1 \bar F(\mathbf{w}, \mathbf{u}) = \nabla_1 f(\mathbf{w}, \mathbf{u}) + \lambda ( \nabla_1 g(\mathbf{w}, \mathbf{u}) - \nabla_1 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w})))\), we have \[\begin{align*} &\|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}^*_{\lambda}(\mathbf{w}_\tau)) - \nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 \leq (L_f + \lambda L_g)\|\mathbf{u}^*_{\lambda}(\mathbf{w}_\tau) - \mathbf{u}_\tau\|_2\\ &\leq \frac{(L_f + \lambda L_g)}{(\lambda\mu_g - L_f)}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2\\ &\leq \frac{2(\lambda\mu_g/2+\lambda L_g)}{\lambda\mu_g}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2=\frac{\mu_g+2L_g}{\mu_g}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2, \end{align*}\] where the last inequality uses \(L_f\leq \lambda\mu_g/2\). Combining the above inequalities, we obtain \[\begin{align*} \|\nabla F(\mathbf{w}_\tau)\|_2 &\leq \frac{2L_fL_gG_f}{\mu_g^2\lambda} + \frac{2L_{gg}L_gG_f^2}{\mu_g^3\lambda} \\ &+ \frac{\mu_g+ 2L_g}{\mu_g}\|\nabla_2 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2 + \|\nabla_1 \bar F(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2. \end{align*}\] From Theorem 4.6, we have \[\begin{align*} \mathbb{E}\!\left[\|\nabla_2 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2^2 + \|\nabla_1 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2^2\right] \le \epsilon^2. \end{align*}\] Hence, it follows that \(\mathbb{E}[\|\nabla_2 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2] \le \epsilon\) and \(\mathbb{E}[\|\nabla_1 \bar{F}(\mathbf{w}_\tau, \mathbf{u}_\tau)\|_2] \le \epsilon\). If \(\lambda = O(1/\epsilon)\), then \(\mathbb{E}[\|\nabla F(\mathbf{w}_\tau)\|_2] \le O(\epsilon)\). The iteration complexity can be established by substituting \(\lambda = O(1/\epsilon)\) into Theorem 4.6 and noting that \(C_\Upsilon = O(1)\) when \(\|\mathbf{y}_0 - \mathbf{y}^*(\mathbf{w}_0)\|_2^2 \le O(\epsilon)\).

Critical

The complexity of \(O(1/\epsilon^7)\) is not the state-of-the-art sample complexity achievable under the same assumptions. Indeed, a double-loop large-batch method, similar to the one presented in Section 4.5.1.1, for solving the min-max problem \(\min_{\mathbf{w},\mathbf{u}} \max_{\mathbf{y}} f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right)\) can yield a superior sample complexity of \(O(1/\epsilon^6)\) for achieving the stationarity condition \(\mathbb{E}[\|\nabla F(\mathbf{w})\|_2]\leq\epsilon\).

To see this, we apply an argument similar to the one in Section 4.5.1.1. Let \(F_\lambda(\mathbf{w}):=\min_{\mathbf{u}} \max_{\mathbf{y}} f(\mathbf{w}, \mathbf{u}) + \lambda \left( g(\mathbf{w}, \mathbf{u}) - g(\mathbf{w}, \mathbf{y}) \right)\). Then \[ \nabla F_\lambda(\mathbf{w})=\nabla_1 f(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w})) + \lambda \left(\nabla_1 g(\mathbf{w}, \mathbf{u}^*_\lambda(\mathbf{w})) - \nabla_1 g(\mathbf{w}, \mathbf{u}^*(\mathbf{w}))\right). \]

Given \(\mathbf{w}_t\), let \(\mathbf{u}_t, \mathbf{u}_{\lambda, t}\) denote approximations of \(\mathbf{u}^*(\mathbf{w}_t), \mathbf{u}^*_\lambda(\mathbf{w}_t)\) and let \[\begin{align*} &\mathbf{v}_t = \frac{1}{B}\sum_{i=1}^B\nabla_1 f(\mathbf{w}_t, \mathbf{u}_{\lambda,t}; \zeta_i) + \lambda \frac{1}{B}\sum_{i=1}^B\left(\nabla_1 g(\mathbf{w}_t, \mathbf{u}_{\lambda,t}; \xi_i) - \nabla_1 g(\mathbf{w}_t, \mathbf{u}_t; \xi_i)\right),\\ &\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \mathbf{v}_t. \end{align*}\]

Then, we have \[ \|\mathbf{v}_t - \nabla F_\lambda(\mathbf{w}_t)\|_2^2 \leq \frac{O(\lambda^2)}{B} + O\left(\lambda^2\|\mathbf{u}_{\lambda,t}-\mathbf{u}^*_\lambda(\mathbf{w}_t)\|_2^2\right) + O\left(\lambda^2\|\mathbf{u}_t-\mathbf{u}^*(\mathbf{w}_t)\|_2^2\right). \]
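To sketch the reasoning behind this bound (a step the text leaves implicit): since \(\mathbb{E}[\mathbf{v}_t] = \nabla_1 f(\mathbf{w}_t, \mathbf{u}_{\lambda,t}) + \lambda(\nabla_1 g(\mathbf{w}_t, \mathbf{u}_{\lambda,t}) - \nabla_1 g(\mathbf{w}_t, \mathbf{u}_t))\), we can decompose \[\begin{align*} \mathbf{v}_t - \nabla F_\lambda(\mathbf{w}_t) &= (\mathbf{v}_t - \mathbb{E}[\mathbf{v}_t]) + \big(\nabla_1 f(\mathbf{w}_t, \mathbf{u}_{\lambda,t}) - \nabla_1 f(\mathbf{w}_t, \mathbf{u}^*_\lambda(\mathbf{w}_t))\big)\\ &\quad + \lambda\big(\nabla_1 g(\mathbf{w}_t, \mathbf{u}_{\lambda,t}) - \nabla_1 g(\mathbf{w}_t, \mathbf{u}^*_\lambda(\mathbf{w}_t))\big) - \lambda\big(\nabla_1 g(\mathbf{w}_t, \mathbf{u}_t) - \nabla_1 g(\mathbf{w}_t, \mathbf{u}^*(\mathbf{w}_t))\big). \end{align*}\] The first term contributes \(O(\lambda^2)/B\) to the second moment, since the variance of each summand scales with \(\lambda^2\); the remaining terms are controlled by the Lipschitz continuity of \(\nabla_1 f\) and \(\nabla_1 g\), which gives the two \(O(\lambda^2\|\cdot\|_2^2)\) terms.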

Following Corollary 3.9, we apply SGD to solve \(\min_{\mathbf{u}} f(\mathbf{w}, \mathbf{u}) + \lambda g(\mathbf{w}, \mathbf{u})\) for \[ T_0=\mathcal{O}\left(\max\left\{\frac{L_f+\lambda L_g}{\lambda \mu_g\sqrt{\epsilon}}, \frac{\sigma_f^2 + \lambda^2 \sigma_g^2}{(\lambda\mu_g)^2 (\epsilon/\lambda)^2}\right\}\right)=O\left(\frac{\lambda^2}{\epsilon^2}\right) \] iterations, which guarantees \(\mathbb{E}[\|\mathbf{u}_{\lambda,t} - \mathbf{u}^*_{\lambda}(\mathbf{w}_t)\|_2^2] \leq \frac{\epsilon^2}{\lambda^2}\).

Similarly, applying SGD to solve \(\min_{\mathbf{u}} g(\mathbf{w}, \mathbf{u})\) for \[ T_1=\mathcal{O}\left(\max\left\{\frac{L_g}{\mu_g\sqrt{\epsilon}}, \frac{\sigma_g^2}{\mu_g^2 (\epsilon/\lambda)^2}\right\}\right)=O\left(\frac{\lambda^2}{\epsilon^2}\right) \] iterations guarantees \(\mathbb{E}[\|\mathbf{u}_t - \mathbf{u}^*(\mathbf{w}_t)\|_2^2] \leq \frac{\epsilon^2}{\lambda^2}\).

Hence, with \(B=O(\lambda^2/\epsilon^2)\), we have \(\mathbb{E}[\|\mathbf{v}_t - \nabla F_\lambda(\mathbf{w}_t)\|_2^2] \leq O(\epsilon^2)\).
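Putting these pieces together, here is a minimal sketch of the double-loop large-batch procedure. All oracles, step sizes, and iteration counts are placeholders to be set as described above and below; this is an illustration under our assumed interface, not the text's reference implementation:

```python
import numpy as np

def sgd_inner(grad, u0, steps, lr):
    """Plain SGD on a strongly convex inner problem; grad(u) returns a
    stochastic gradient at u."""
    u = u0.copy()
    for _ in range(steps):
        u = u - lr * grad(u)
    return u

def double_loop_large_batch(w0, u0, grad1_f, grad2_f, grad1_g, grad2_g,
                            lam, eta, T, T0, T1, B, inner_lr):
    """Each outer step: approximate u*(w) and u*_lam(w) by warm-started inner
    SGD, then form the large-batch estimator v_t and update w. The grad*_
    arguments are stochastic partial-gradient oracles with signature
    grad(w, u)."""
    w, u, u_lam = w0.copy(), u0.copy(), u0.copy()
    for _ in range(T):
        # u_t ~ argmin_u g(w, u): T1 = O(lam^2 / eps^2) SGD steps
        u = sgd_inner(lambda v: grad2_g(w, v), u, T1, inner_lr)
        # u_{lam,t} ~ argmin_u f(w, u) + lam * g(w, u): T0 steps; the step
        # size shrinks with lam since the modulus scales like lam * mu_g
        u_lam = sgd_inner(lambda v: grad2_f(w, v) + lam * grad2_g(w, v),
                          u_lam, T0, inner_lr / lam)
        # large-batch estimator of grad F_lam(w), B = O(lam^2 / eps^2) samples
        v = np.mean([grad1_f(w, u_lam)
                     + lam * (grad1_g(w, u_lam) - grad1_g(w, u))
                     for _ in range(B)], axis=0)
        w = w - eta * v   # outer step, eta = O(1 / L_{F_lam})
    return w
```

With \(T_0, T_1, B = O(\lambda^2/\epsilon^2)\) and the outer settings discussed next, this matches the sample-complexity accounting that follows.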

Then, with \(\eta=O(1/L_{F_\lambda})\) and \(T=O(L_{F_\lambda}/\epsilon^2)\) iterations, we can prove that \(\mathbb{E}[\|\nabla F_{\lambda}(\mathbf{w}_\tau)\|_2]\leq O(\epsilon)\) following Lemma 4.9. Since Lemma 4.26 and Lemma 4.27 indicate that \(\|\nabla F_{\lambda}(\mathbf{w}_\tau) - \nabla F(\mathbf{w}_\tau)\|_2 \leq O(1/\lambda)\), with \(\lambda=O(1/\epsilon)\), we have \(\mathbb{E}[\|\nabla F(\mathbf{w}_\tau)\|_2]\leq O(\epsilon)\).

The sample complexity is \[ BT + T_0T + T_1T = O\left(\frac{\lambda^2 L_{F_\lambda}}{\epsilon^4}\right) = O\left(\frac{L_{F_\lambda}}{\epsilon^6}\right). \]

We can establish \(L_{F_\lambda} = O(1)\) independent of \(\lambda\) (Chen et al. 2025, Lemma B.7). Hence the total sample complexity is \(O(1/\epsilon^6)\).

It remains an open problem to develop a single-loop stochastic algorithm that achieves \(O(1/\epsilon^6)\) complexity without requiring a large batch size or assuming mean-square smoothness (see next section for more discussion).
