Section 4.3 Stochastic Compositional Momentum Methods
In this section, we present a method that matches the sample complexity of the large mini-batch approach without using large mini-batches under the smoothness conditions of \(f\) and \(F\). The idea is to design a gradient estimator such that its error can be reduced gradually. It turns out this technique, related to the momentum methods for standard stochastic optimization, is more widely applicable to other problems discussed later in this chapter. Furthermore, we introduce advanced methods to further improve the complexity to \(O(1/\epsilon^3)\) under stronger conditions.
It is worth noting that the results in this section apply to the standard stochastic optimization problem under the smoothness assumption of \(g(\mathbf{w})\) by setting \(f(g) =g\) and \(L_1=0\) in the complexity results and removing the \(\mathbf{u}\) update in the algorithm.
4.3.1 Moving-Average Gradient Estimator
The first method is to use the following moving-average gradient estimator: \[\begin{align}\label{eqn:mag} \mathbf{v}_{t} & = (1-\beta_t)\mathbf{v}_{t-1} + \beta_t \nabla g(\mathbf{w}_t; \zeta_t')\nabla f(\mathbf{u}_{t}; \xi_t), \end{align}\] where \(0 \leq \beta_t < 1\). With \(\mathbf{v}_{t}\), the model parameter is updated by: \[\begin{align} \mathbf{w}_{t+1} & = \mathbf{w}_t - \eta_t \mathbf{v}_{t}. \end{align}\]
We present the full steps in Algorithm 10 and refer to it as SCMA.
To understand this method, we can view \(\mathbf{v}_{t}\) as a refined estimator of the gradient whose estimation error decreases gradually over iterations, a property we will prove shortly. This accounts for the enhanced stability of momentum-based methods observed in practice.
Algorithm 10: SCMA
- Input: learning rate and averaging parameter schedules \(\{\eta_t\}_{t=0}^{T}\), \(\{\gamma_t\}_{t=1}^{T}\), \(\{\beta_t\}_{t=1}^{T}\); starting points \(\mathbf{w}_0\), \(\mathbf{u}_0, \mathbf{v}_0\)
- Let \(\mathbf{w}_1 = \mathbf{w}_0 - \eta_0\mathbf{v}_0\)
- For \(t=1,\dotsc,T\)
- Sample \(\zeta_t, \zeta'_t\) and \(\xi_t\)
- Compute the inner function value estimator \(\mathbf{u}_{t} = (1-\gamma_t)\mathbf{u}_{t-1} + \gamma_t g(\mathbf{w}_t; \zeta_t)\)
- Compute the vanilla gradient estimator \(\mathbf{z}_t = \nabla g(\mathbf{w}_t; \zeta_t')\nabla f(\mathbf{u}_{t}; \xi_t)\)
- Update the MA gradient estimator \(\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1} + \beta_t \mathbf{z}_t\)
- Update the model \(\mathbf{w}\) by \(\mathbf{w}_{t+1} = \mathbf{w}_t - \eta_t \mathbf{v}_{t}\)
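The loop above can be sketched in a few lines of NumPy. The toy compositional objective below (linear inner map \(g\), quadratic outer \(f\), additive Gaussian noise, and the constant step sizes) is our own illustrative assumption, not part of the algorithm's specification.

```python
import numpy as np

rng = np.random.default_rng(0)
d, p = 5, 3
A = np.eye(p, d) + 0.1 * rng.standard_normal((p, d))  # well-conditioned map (assumed)

# Toy compositional objective (our assumption): g(w; zeta) = A w + 0.01 zeta and
# f(u) = 0.5 ||u||^2, so F(w) = 0.5 ||A w||^2 and stationarity means A w = 0.
def g_sample(w):
    return A @ w + 0.01 * rng.standard_normal(p)

eta, gamma, beta = 0.02, 0.1, 0.1     # constant schedules (assumed)
w = rng.standard_normal(d)
u = g_sample(w)                       # u_0
v = A.T @ u                           # v_0 = ∇g(w; zeta')∇f(u), since ∇f(u) = u here
w = w - eta * v                       # w_1 = w_0 - eta_0 v_0

for t in range(3000):
    u = (1 - gamma) * u + gamma * g_sample(w)   # moving-average inner estimator u_t
    z = A.T @ u                                 # vanilla gradient estimator z_t
    v = (1 - beta) * v + beta * z               # moving-average gradient estimator v_t
    w = w - eta * v                             # model update

print(float(np.linalg.norm(A @ w)))             # near 0 at a stationary point
```

Note that only one inner sample is drawn per iteration; no large mini-batch is needed.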
This method is analogous to applying the stochastic momentum method to the ERM problem, using the term \(\nabla g(\mathbf{w}_t; \zeta_t')\nabla f(\mathbf{u}_{t}; \xi_t)\) as a surrogate for the true stochastic gradient. This connection is revealed by reformulating the update into a canonical momentum form: \[\begin{align}\label{eqn:mom} \mathbf{w}_{t+1} & = \mathbf{w}_t - \eta'_t\nabla g(\mathbf{w}_t; \zeta_t')\nabla f(\mathbf{u}_{t}; \xi_t) + \beta'_t (\mathbf{w}_t - \mathbf{w}_{t-1}), \end{align}\] where the effective step size and momentum parameters are \(\eta_t' = \eta_t \beta_t\) and \(\beta'_t = \eta_t(1-\beta_t)/\eta_{t-1}\), respectively. The term \(\beta'_t(\mathbf{w}_t - \mathbf{w}_{t-1})\) is the momentum term.
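The algebra behind this reformulation can be checked numerically. The short schedule and random surrogate gradients below are arbitrary values chosen only for the check.

```python
import numpy as np

rng = np.random.default_rng(1)
d = 4
etas = [0.1, 0.08, 0.05]                         # eta_0, eta_1, eta_2 (arbitrary)
betas = [0.3, 0.2]                               # beta_1, beta_2 (arbitrary)
zs = [rng.standard_normal(d) for _ in range(2)]  # surrogate gradients z_1, z_2
v0 = rng.standard_normal(d)

# Form 1: moving-average estimator followed by a plain descent step.
ws = [np.zeros(d)]
ws.append(ws[0] - etas[0] * v0)                  # w_1 = w_0 - eta_0 v_0
v = v0
for t in range(2):
    v = (1 - betas[t]) * v + betas[t] * zs[t]
    ws.append(ws[-1] - etas[t + 1] * v)

# Form 2: canonical momentum update with
# eta'_t = eta_t * beta_t and beta'_t = eta_t * (1 - beta_t) / eta_{t-1}.
w2 = [ws[0], ws[1]]
for t in range(2):
    eta_p = etas[t + 1] * betas[t]
    beta_p = etas[t + 1] * (1 - betas[t]) / etas[t]
    w2.append(w2[-1] - eta_p * zs[t] + beta_p * (w2[-1] - w2[-2]))

print(np.allclose(ws, w2))  # True: the two forms generate identical iterates
```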
In the special case where \(f\) is the identity function, the update is identical to the classical stochastic momentum method, also known as the stochastic heavy-ball method, renowned for its accelerated performance on quadratic functions relative to plain gradient descent. Hence, the convergence analysis presented below also applies to the stochastic momentum method for ERM by setting \(L_1=0\).
First, we prove a generic lemma that establishes the error recursion of \(\mathbf{v}_{t}\).
Lemma 4.7 Let \(\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1} +\beta_t \mathbf{z}_t\), where \(\mathbb{E}_t[\mathbf{z}_t]=\mathcal{M}_t\). If \(\mathbb{E}_t[\|\mathbf{z}_t - \mathcal{M}_t\|_2^2]\leq \sigma^2\), then we have \[\begin{align}\label{eq:nasa_grad} &\mathbb{E}_{t}\left[\left\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\right\|_2^2\right] \leq (1-\beta_t)\left\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\right\|_2^2 + \beta_t^2\sigma^2\\\nonumber & \quad + \frac{2L_F^2}{\beta_t}\left\|\mathbf{w}_{t-1}-\mathbf{w}_t\right\|_2^2 + 4\beta_t\left\|\mathcal{M}_{t} - \nabla F(\mathbf{w}_t)\right\|_2^2. \end{align}\]
Proof
Due to the update formula \(\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1} + \beta_t \mathbf{z}_t\), we have \[\begin{align*} & \mathbb{E}_{t}\left[\left\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\right\|_2^2\right] \\ & = \mathbb{E}_{t}\left[\left\|(1-\beta_t)\mathbf{v}_{t-1} + \beta_t \mathbf{z}_t - \nabla F(\mathbf{w}_t)\right\|_2^2\right]\\ & = \mathbb{E}_{t}\bigg[\|\underbrace{(1-\beta_t)\mathbf{v}_{t-1}- \nabla F(\mathbf{w}_t) + \beta_t \mathcal{M}_t}_{\mathbf{a}_t}+\underbrace{\beta_t (\mathbf{z}_t - \mathcal{M}_t)}_{\mathbf{b}_t}\|_2^2\bigg]. \end{align*}\]
Note that \(\mathbb{E}_{t}[\mathbf{a}_t^{\top}\mathbf{b}_t] = 0\). Besides, we have \(\mathbb{E}_t[\|\mathbf{b}_t\|_2^2]\leq \beta_t^2 \sigma^2\). Due to Young’s inequality, we have \(\|a+b\|_2^2 \leq (1+\alpha)\|a\|_2^2 + (1+1/\alpha)\|b\|_2^2\) for any \(\alpha>0\). Hence, \[\begin{align*} &\|\mathbf{a}_t\|_2^2=\|(1-\beta_t)(\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})) + (1-\beta_t)(\nabla F(\mathbf{w}_{t-1}) - \nabla F(\mathbf{w}_t))\\ &\quad \quad\quad \quad + \beta_t (\mathcal{M}_t - \nabla F(\mathbf{w}_t))\|_2^2\\ &\leq (1-\beta_t)^2(1+\beta_t)\|(\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1}))\|_2^2\\ & + (1+\frac{1}{\beta_t})\|(1-\beta_t)(\nabla F(\mathbf{w}_{t-1}) - \nabla F(\mathbf{w}_t)) + \beta_t (\mathcal{M}_t - \nabla F(\mathbf{w}_t))\|_2^2\\ & \leq (1-\beta_t)\left\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\right\|_2^2 + \frac{2(1+\beta_t)(1-\beta_t)^2}{\beta_t}\left\|\nabla F(\mathbf{w}_{t-1}) - \nabla F(\mathbf{w}_t)\right\|_2^2 \\ & + \frac{2(1+\beta_t)\beta_t^2}{\beta_t}\left\|\mathcal{M}_t - \nabla F(\mathbf{w}_t)\right\|_2^2\\ & \leq (1-\beta_t)\left\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\right\|_2^2 + \frac{2L_F^2}{\beta_t}\left\|\mathbf{w}_{t-1}-\mathbf{w}_t\right\|_2^2 + 4\beta_t\left\|\mathcal{M}_{t} - \nabla F(\mathbf{w}_t)\right\|_2^2. \end{align*}\]
Combining the above results, we finish the proof.■
With the above lemma, we are able to establish the error recursion of \(\mathbf{v}_{t}\) of SCMA.
Lemma 4.8 Under Assumption 4.1, Assumption 4.2, Assumption 4.3, and Assumption 4.4, for \(t\geq 1\) SCMA satisfies that \[\begin{align} &\mathbb{E}_{\xi_t, \zeta'_t}\left[\left\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\right\|_2^2\right] \leq (1-\beta_t)\left\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\right\|_2^2 \\\notag & \quad + \frac{2L_F^2}{\beta_t}\left\|\mathbf{w}_{t-1}-\mathbf{w}_t\right\|_2^2 + 4G_2^2L_1^2\beta_t\left\|\mathbf{u}_{t} - g(\mathbf{w}_t)\right\|_2^2 + \beta_t^2\sigma^2,\label{eqn:nasa_grad} \end{align}\] where \(\sigma^2=G_1^2\sigma_2^2+G_2^2\sigma_1^2\).
💡 Why it matters
The above lemma establishes a recursion for the error of the stochastic gradient estimator \(\mathbf{v}_{t}\). It is the key to showing that the average estimation error of \(\mathbf{v}_{t}\) converges to zero.
Proof
We denote by \(\mathbb{E}_{t}[\cdot]=\mathbb{E}_{\xi_t, \zeta'_t}[\cdot]\). Let \(\mathbf{z}_t = \nabla g(\mathbf{w}_t;\zeta'_t)\nabla f(\mathbf{u}_{t};\xi_t)\) and \(\mathcal{M}_t = \mathbb{E}_t[\mathbf{z}_t] = \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_{t})\). Lemma 4.4 proves that \[\begin{align} \mathbb{E}_t[\|\mathbf{z}_t -\mathcal{M}_t\|_2^2]&\leq G_2^2\sigma_1^2 + G_1^2\sigma_2^2,\label{eqn:sgf} \end{align}\] and \[\begin{align*} \|\mathcal{M}_t - \nabla F(\mathbf{w}_t)\|_2^2& = \|\nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_{t}) -\nabla g(\mathbf{w}_t)\nabla f(g(\mathbf{w}_{t})) \|_2^2\\ & \leq G_2^2L_1^2\|\mathbf{u}_t - g(\mathbf{w}_t)\|_2^2. \end{align*}\]
Plugging these two results into Lemma 4.7, we finish the proof.■
If we use the same random sample \(\zeta_t\) to compute \[\mathbf{z}_t=\nabla g(\mathbf{w}_t;\zeta_t)\nabla f(\mathbf{u}_{t};\xi_t),\] then \(\mathcal{M}_t=\mathbb{E}_{\xi_t,\zeta_t}[\mathbf{z}_t]\) is not equal to \(\nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t)\). However, we just need to assume that \(\mathbb{E}_{\zeta_t,\xi_t}[\|\mathbf{z}_t -\mathcal{M}_t\|_2^2]\) is bounded and \(\|\nabla g(\mathbf{w}_t; \zeta_t)\|_2^2 \le G_2^2\). Then \[\begin{align*} \|\mathcal{M}_t - \nabla F(\mathbf{w}_t)\|_2^2& = \|\mathbb{E}_{\zeta_t}\nabla g(\mathbf{w}_t; \zeta_t)\nabla f(\mathbf{u}_{t}) -\mathbb{E}_{\zeta_t}\nabla g(\mathbf{w}_t;\zeta_t)\nabla f(g(\mathbf{w}_{t})) \|_2^2\\ & \leq \mathbb{E}_{\zeta_t} \|\nabla g(\mathbf{w}_t; \zeta_t)\nabla f(\mathbf{u}_{t}) -\nabla g(\mathbf{w}_t;\zeta_t)\nabla f(g(\mathbf{w}_{t})) \|_2^2 \\ &\leq \mathbb{E}_{\zeta_t} [\|\nabla g(\mathbf{w}_t; \zeta_t)\|_2^2\|\nabla f(\mathbf{u}_{t}) -\nabla f(g(\mathbf{w}_{t})) \|_2^2]\\ &\leq \mathbb{E}_{\zeta_t}\!\left[ G_2^2 L_1^2 \, \|\mathbf{u}_t - g(\mathbf{w}_t)\|^2_2 \right]. \end{align*}\] The following analysis proceeds in the same manner.
To exploit the above recursion of the gradient estimator's error, we state the following lemma, which is a variant of the standard descent lemma for gradient descent.
Lemma 4.9 For the update \(\mathbf{w}_{t+1} = \mathbf{w}_t -\eta_t \mathbf{v}_{t}, t\geq 0\), if \(\eta_t\leq 1/(2L_F)\), we have \[\begin{align}\label{eq:nasa_starter} F(\mathbf{w}_{t+1})& \leq F(\mathbf{w}_t) + \frac{\eta_t}{2} \left\|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\right\|_2^2- \frac{\eta_t}{2}\left\|\nabla F(\mathbf{w}_t)\right\|_2^2 - \frac{1}{4\eta_t} \left\|\mathbf{w}_{t+1} - \mathbf{w}_t\right\|_2^2. \end{align}\]
💡 Why it matters
This lemma ensures that if the stochastic gradient error satisfies \(\mathbb{E}\left[\frac{1}{T}\sum_{t=1}^T \|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\|_2^2\right] \rightarrow 0\), then the convergence of \(\mathbb{E}\left[\frac{1}{T}\sum_{t=1}^T \|\nabla F(\mathbf{w}_t)\|_2^2\right]\) to zero is guaranteed.
Proof
Due to the smoothness of \(F\), we have \[\begin{align*} & F(\mathbf{w}_{t+1}) \leq F(\mathbf{w}_t) + \mathbf{\nabla} F(\mathbf{w}_t)^{\top}(\mathbf{w}_{t+1} - \mathbf{w}_t) + \frac{L_F}{2}\left\|\mathbf{w}_{t+1}-\mathbf{w}_t\right\|_2^2\\ & = F(\mathbf{w}_t) + (\nabla F(\mathbf{w}_t) - \mathbf{v}_{t})^{\top}(\mathbf{w}_{t+1} - \mathbf{w}_t) + \mathbf{v}_{t}^{\top}(\mathbf{w}_{t+1} - \mathbf{w}_t)+ \frac{L_F}{2}\left\|\mathbf{w}_{t+1}-\mathbf{w}_t\right\|_2^2\\ & = F(\mathbf{w}_t) - \eta_t (\nabla F(\mathbf{w}_t) - \mathbf{v}_{t})^{\top}\mathbf{v}_{t} - \left(\frac{1}{\eta_t} - \frac{L_F}{2}\right)\left\|\mathbf{w}_{t+1} - \mathbf{w}_t\right\|_2^2\\ & = F(\mathbf{w}_t) + \eta_t \left\|(\nabla F(\mathbf{w}_t) - \mathbf{v}_{t})\right\|_2^2 - \eta_t (\nabla F(\mathbf{w}_t) - \mathbf{v}_{t})^{\top}\nabla F(\mathbf{w}_t) \\ &- \left(\frac{1}{\eta_t} - \frac{L_F}{2}\right)\left\|\mathbf{w}_{t+1} - \mathbf{w}_t\right\|_2^2. \end{align*}\]
Since \((\nabla F(\mathbf{w}_t) - \mathbf{v}_{t})^{\top}\nabla F(\mathbf{w}_t)=\frac{1}{2}\left(\left\|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\right\|_2^2 + \left\|\nabla F(\mathbf{w}_t)\right\|_2^2 - \left\|\mathbf{v}_{t}\right\|_2^2\right)\), then we have \[\begin{align*} & F(\mathbf{w}_{t+1}) \leq F(\mathbf{w}_t) + \eta_t \left\|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\right\|_2^2- \left(\frac{1}{\eta_t} - \frac{L_F}{2}\right)\left\|\mathbf{w}_{t+1} - \mathbf{w}_t\right\|_2^2 \\ &- \frac{\eta_t}{2}\left(\left\|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\right\|_2^2 + \left\|\nabla F(\mathbf{w}_t)\right\|_2^2 - \left\|\mathbf{v}_{t}\right\|_2^2 \right)\\ & = F(\mathbf{w}_t) + \frac{\eta_t}{2} \left\|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\right\|_2^2 - \frac{\eta_t}{2}\left\|\nabla F(\mathbf{w}_t)\right\|_2^2 - \left(\frac{1}{2\eta_t} - \frac{L_F}{2}\right)\left\|\mathbf{w}_{t+1} - \mathbf{w}_t\right\|_2^2. \end{align*}\]■
To prove the final convergence of SCMA, we present a useful lemma.
Lemma 4.10 Assume \(\eta_t\leq 1/L\) and that there exist non-negative sequences \(A_t, B_t, \Gamma_t, \Delta_t, \delta_t, t\geq 0\) satisfying: \[\begin{align*} (*)&A_{t+1}\leq A_t + \eta_t \Delta_t - \eta_t B_t - \eta_t \Gamma_t\\ (\sharp)&\Delta_{t+1} \leq (1-\beta_{t+1})\Delta_{t} + C_1\beta_{t+1}\delta_{t+1} + \frac{C_2\eta_{t}^2}{\beta_{t+1}}\Gamma_t + \beta_{t+1}^2\sigma^2,\\ (\diamond)&\delta_{t+1} \leq (1-\gamma_{t+1}) \delta_{t} + \frac{C_3\eta_{t}^2}{\gamma_{t+1}} \Gamma_t + \gamma_{t+1}^2{\sigma'}^2. \end{align*}\] Let \(\Upsilon_{t}=A_{t} + \frac{\eta_{t-1}}{\beta_{t}}\Delta_{t} + \frac{C_1\eta_{t-1}}{\gamma_{t}} \delta_{t}\). If \(\frac{\eta_t}{\beta_{t+1}}\leq \frac{\eta_{t-1}}{\beta_{t}}\), \(\frac{\eta_t}{\gamma_{t+1}}\leq \frac{\eta_{t-1}}{\gamma_{t}}\), \(\eta_t\leq \min(\frac{\beta_{t+1}}{\sqrt{4C_2}}, \frac{\gamma_{t+1}}{\sqrt{8C_1C_3}})\), and \(\Upsilon_t\geq A_*\), then we have \[\begin{align*} &\frac{1}{\sum_{t=0}^{T-1}\eta_t}\sum_{t=0}^{T-1}\left(\eta_t B_t + \frac{1}{2}\eta_t \Gamma_{t}\right)\leq \frac{C_\Upsilon}{\sum_{t=0}^{T-1}\eta_t} + \frac{\sum_{t=0}^{T-1}\left(\eta_t\beta_{t+1}\sigma^2 + 2C_1\eta_t\gamma_{t+1}{\sigma'}^2\right)}{\sum_{t=0}^{T-1}\eta_t}, \end{align*}\] where \(C_\Upsilon = \Upsilon_0 - A_* \leq A_0 - A_* + \frac{1}{2\sqrt{C_2}}\Delta_{0} + \sqrt{\frac{C_1}{8C_3}} \delta_{0}\).
If \(\beta =\frac{\epsilon^2}{3\sigma^2}, \gamma = \frac{\epsilon^2}{6C_1\sigma'^2}, \eta = \min(\frac{1}{L}, \frac{\beta}{\sqrt{4C_2}}, \frac{\gamma}{\sqrt{8C_1C_3}})\), then in order to guarantee \[ \frac{1}{T}\sum_{t=0}^{T-1}\left(B_t + \frac{1}{2}\Gamma_t\right)\leq \epsilon^2, \] the iteration complexity is on the order of \[ T=O\left(\max\left\{\frac{C_\Upsilon L}{\epsilon^2}, \frac{C_\Upsilon\sigma^2\sqrt{C_2}}{\epsilon^4}, \frac{C_\Upsilon\sqrt{C_1C_3}C_1\sigma'^2}{\epsilon^4}\right\}\right). \]
Proof
The proof is constructive. The idea is to construct a telescoping series of \(A_{t}+a_t\Delta_{t}+b_t \delta_{t}\) with some appropriate sequences of \(a_t, b_t\). First, we have \[\begin{align*} &A_{t+1} + a_{t+1}\Delta_{t+1} + b_{t+1} \delta_{t+1}\leq A_t + \eta_t \Delta_t - \eta_t B_t - \eta_t \Gamma_t\\ & +a_{t+1}(1-\beta_{t+1})\Delta_{t} + a_{t+1}C_1\beta_{t+1}\delta_{t+1} + a_{t+1}\frac{C_2\eta_{t}^2}{\beta_{t+1}}\Gamma_t + a_{t+1}\beta_{t+1}^2\sigma^2\\ &+b_{t+1}(1-\gamma_{t+1}) \delta_{t} + b_{t+1}\frac{C_3\eta_{t}^2}{\gamma_{t+1}} \Gamma_{t} + b_{t+1}\gamma_{t+1}^2{\sigma'}^2. \end{align*}\] Let \(a_{t+1} = \eta_t/\beta_{t+1}\leq \eta_{t-1}/\beta_{t}\) and \(b_{t+1} = C_1\eta_t(1+\gamma_{t+1})/\gamma_{t+1}\), we have \[\begin{align*} &A_{t+1} + \frac{\eta_t}{\beta_{t+1}}\Delta_{t+1} +(C_1\eta_t\frac{1+\gamma_{t+1}}{\gamma_{t+1}}- C_1\eta_t) \delta_{t+1}\leq A_t - \eta_t B_t - \eta_t \Gamma_t \\ & +\left(\eta_t+\frac{\eta_t}{\beta_{t+1}}(1-\beta_{t+1})\right)\Delta_{t} + \frac{C_2\eta_{t}^3}{\beta_{t+1}^2}\Gamma_t + \eta_t\beta_{t+1}\sigma^2\\ &+C_1\eta_t\frac{1+\gamma_{t+1}}{\gamma_{t+1}}(1-\gamma_{t+1})\delta_{t} + \frac{C_3C_1\eta_{t}^3(1+\gamma_{t+1})}{\gamma_{t+1}^2} \Gamma_t +C_1\eta_t(1+\gamma_{t+1})\gamma_{t+1}{\sigma'}^2. \end{align*}\] Thus, \[\begin{align*} &A_{t+1} + \frac{\eta_t}{\beta_{t+1}}\Delta_{t+1} + \frac{C_1\eta_t}{\gamma_{t+1}} \delta_{t+1}\leq A_t +\frac{\eta_t}{\beta_{t+1}}\Delta_{t}+\frac{C_1\eta_t}{\gamma_{t+1}}\delta_{t}\\ &- \eta_t B_t - \left(\eta_t - \frac{C_2\eta_{t}^3}{\beta_{t+1}^2} - \frac{C_3C_1\eta_{t}^3(1+\gamma_{t+1})}{\gamma_{t+1}^2}\right) \Gamma_t\\ &+ \eta_t\beta_{t+1}\sigma^2 + C_1\eta_t(1+\gamma_{t+1})\gamma_{t+1}{\sigma'}^2. 
\end{align*}\] Since \(\eta_t/\beta_{t+1}\leq \eta_{t-1}/\beta_{t}\) and \(\eta_t/\gamma_{t+1}\leq \eta_{t-1}/\gamma_{t}\) and \(\gamma_{t+1}\leq 1\), we have \[\begin{align*} &A_{t+1} + \frac{\eta_t}{\beta_{t+1}}\Delta_{t+1} + \frac{C_1\eta_t}{\gamma_{t+1}} \delta_{t+1}\leq A_t +\frac{\eta_{t-1}}{\beta_{t}}\Delta_{t}+\frac{C_1\eta_{t-1}}{\gamma_{t}}\delta_{t}\\ &- \eta_t B_t - \left(\eta_t - \frac{C_2\eta_{t}^3}{\beta_{t+1}^2} - \frac{2C_3C_1\eta_{t}^3}{\gamma_{t+1}^2}\right) \Gamma_t\\ &+ \eta_t\beta_{t+1}\sigma^2 + 2C_1\eta_t\gamma_{t+1}{\sigma'}^2. \end{align*}\] Since \(C_2\eta_{t}^3/\beta_{t+1}^2\leq \eta_t/4\) because \(\eta_t\leq \beta_{t+1}/\sqrt{4C_2}\) and \(2C_3C_1\eta_{t}^3/\gamma_{t+1}^2\leq \eta_t/4\) because \(\eta_t\leq \gamma_{t+1}/\sqrt{8C_1C_3}\), we have \[\begin{align*} &A_{t+1} + \frac{\eta_t}{\beta_{t+1}}\Delta_{t+1} + \frac{C_1\eta_t}{\gamma_{t+1}} \delta_{t+1}\leq A_t +\frac{\eta_{t-1}}{\beta_{t}}\Delta_{t}+\frac{C_1\eta_{t-1}}{\gamma_{t}}\delta_{t}\\ &- \eta_t B_t - \frac{1}{2}\eta_t \Gamma_t+ \eta_t\beta_{t+1}\sigma^2 + 2C_1\eta_t\gamma_{t+1}{\sigma'}^2. \end{align*}\] Define \(\Upsilon_{t+1}=A_{t+1} + \frac{\eta_t}{\beta_{t+1}}\Delta_{t+1} + \frac{C_1\eta_t}{\gamma_{t+1}} \delta_{t+1}\), we have \[\begin{align*} &\eta_t B_t + \frac{1}{2}\eta_t \Gamma_t\leq \Upsilon_{t} - \Upsilon_{t+1}+ \eta_t\beta_{t+1}\sigma^2 + 2C_1\eta_t\gamma_{t+1}{\sigma'}^2. \end{align*}\] Hence \[\begin{align*} &\sum_{t=0}^{T-1}(\eta_t B_t + \frac{1}{2}\eta_t \Gamma_t)\leq \Upsilon_{0} - A_* + \sum_{t=0}^{T-1}\left(\eta_t\beta_{t+1}\sigma^2 + 2C_1\eta_t\gamma_{t+1}{\sigma'}^2\right). \end{align*}\]
Next, let us consider \(\eta_t=\eta, \beta_t=\beta, \gamma_t = \gamma\). Then we have \[\begin{align*} &\frac{1}{T}\sum_{t=0}^{T-1}(B_t + \frac{1}{2}\Gamma_t)\leq \frac{C_\Upsilon}{\eta T} + \left(\beta\sigma^2 + 2C_1\gamma{\sigma'}^2\right). \end{align*}\] In order to ensure the RHS is at most \(\epsilon^2\), it suffices to make each term at most \(\epsilon^2/3\), i.e., \[\begin{align*} \beta =\frac{\epsilon^2}{3\sigma^2},\quad \gamma = \frac{\epsilon^2}{6C_1{\sigma'}^2}, \quad T= \frac{3C_\Upsilon}{\epsilon^2\eta}. \end{align*}\] Since \[\begin{align*} \eta =\min\bigg(\frac{1}{L}, \frac{\beta}{\sqrt{4C_2}}, \frac{\gamma}{\sqrt{8C_1C_3}}\bigg), \end{align*}\] the order of \(T\) becomes \[\begin{align*} T&= O\left(\max\left\{\frac{C_\Upsilon L}{\epsilon^2}, \frac{C_\Upsilon\sqrt{C_2}}{\epsilon^2\beta}, \frac{C_\Upsilon\sqrt{C_1C_3}}{\gamma\epsilon^2}\right\}\right)\\ &=O\left(\max\left\{\frac{C_\Upsilon L}{\epsilon^2}, \frac{C_\Upsilon\sigma^2\sqrt{C_2}}{\epsilon^4}, \frac{C_\Upsilon\sqrt{C_1C_3}C_1{\sigma'}^2}{\epsilon^4}\right\}\right), \end{align*}\] where \[\begin{align*} C_\Upsilon&=A_0 - A_* + \frac{\eta}{\beta}\Delta_{0} + \frac{C_1\eta}{\gamma} \delta_{0}\leq A_0 - A_* + \frac{1}{2\sqrt{C_2}}\Delta_{0} + \frac{\sqrt{C_1}}{\sqrt{8C_3}} \delta_{0}. \end{align*}\]■
Finally, let us prove the convergence of SCMA.
Theorem 4.3 Suppose Assumption 4.1, Assumption 4.2, Assumption 4.3, and Assumption 4.4 hold. For the SCMA algorithm, set the parameters as follows: \(\beta = \frac{\epsilon^2}{3\sigma^2}, \gamma = \frac{\epsilon^2}{6C_1\sigma_0^2}\), and \(\eta = \min\left(\frac{1}{2L_F}, \frac{\beta}{\sqrt{4C_2}}, \frac{\gamma}{\sqrt{8C_1C_3}}\right)\), where \(\sigma^2 = G_2^2\sigma_1^2 + G_1^2\sigma_2^2, C_1 = 4G_2^2L_1^2, C_2 = 4L_F^2, C_3 = 2G_2^2\). Then, the following \[\begin{align*} \mathbb E\left[\frac{1}{T}\sum_{t=0}^{T-1}\left\{\frac{1}{4}\left\|\mathbf{v}_{t}\right\|_2^2 + \left\|\nabla F(\mathbf{w}_t)\right\|_2^2\right\}\right] \leq \epsilon^2 \end{align*}\] holds with an iteration complexity of \[ T = O\left(\max\left\{\frac{C_\Upsilon L_F}{\epsilon^2}, \frac{C_\Upsilon\sigma^2L_F}{\epsilon^4}, \frac{C_\Upsilon L_1^3\sigma_0^2}{\epsilon^4}\right\}\right), \] where \(C_\Upsilon := 2\left(F(\mathbf{w}_0) - F_*\right) + \frac{1}{4L_F} \left\|\nabla F(\mathbf{w}_0) - \mathbf{v}_0\right\|_2^2 + \frac{L_1}{2}\left\|\mathbf{u}_0 - g(\mathbf{w}_0)\right\|_2^2\).
💡 Why it matters
\(\mathbf{Insights\ 1:}\) Theorem 4.3 indicates that SCMA enjoys the same complexity of \(O(1/\epsilon^4)\) for finding an \(\epsilon\)-stationary solution as SGD for ERM. In addition, the averaged estimation error of the moving-average gradient estimator \(\mathbf{v}_t\), i.e., \(\mathbb{E}[\frac{1}{T}\sum_{t=0}^{T-1}\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\|_2^2]\), converges to zero as \(T\rightarrow \infty\).
\(\mathbf{Insights\ 2:}\) We can apply the above result to the Momentum method for solving the standard stochastic optimization \(\min_\mathbf{w} F(\mathbf{w}): = \mathbb{E}_\zeta [g(\mathbf{w}; \zeta)]\) by setting \(L_1=0\). The complexity of the Momentum method is \[ T = O\left(\max\left\{\frac{(F(\mathbf{w}_0)-F_*)L_F}{\epsilon^2}, \frac{(F(\mathbf{w}_0)-F_*)\sigma^2L_F}{\epsilon^4}, \frac{ \left\|\nabla F(\mathbf{w}_0) - \mathbf{v}_0\right\|_2^2\sigma^2}{\epsilon^4}\right\}\right), \] which is no worse than that of SGD in Theorem 3.3. The key advantage of the Momentum method over SGD is that it ensures the averaged estimation error of the moving-average gradient estimator \(\mathbf{v}_t\) converges to zero.
The convergence bound also suggests that it is better to initialize \(\mathbf{v}_0\) in a way such that \(\left\|\nabla F(\mathbf{w}_0) - \mathbf{v}_0\right\|_2^2\) is small, e.g., using the mini-batch gradient at \(\mathbf{w}_0\) instead of initializing it to zero.
Proof
The three inequalities in Lemma 4.8, Lemma 4.9 and Lemma 4.1 that we have proved so far are \[\begin{align*} (*)\;&F(\mathbf{w}_{t+1}) \leq F(\mathbf{w}_t) + \frac{\eta_t}{2}\left\|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\right\|_2^2- \frac{\eta_t}{2}\left\|\nabla F(\mathbf{w}_t)\right\|_2^2- \frac{\eta_t}{4} \left\|\mathbf{v}_{t}\right\|_2^2, t\geq 0\\ (\sharp)\;&\mathbb E\left[\left\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\right\|_2^2\right] \leq \mathbb E[(1-\beta_t)\left\|\mathbf{v}_{t-1} - \nabla F(\mathbf{w}_{t-1})\right\|_2^2]\\ &\quad+ \mathbb E\left[4G_2^2L_1^2\beta_t\left\|\mathbf{u}_{t} - g(\mathbf{w}_t)\right\|_2^2 + \frac{2L_F^2\eta_{t-1}^2}{\beta_t}\left\|\mathbf{v}_{t-1}\right\|_2^2 + \beta_t^2\sigma^2\right],\\ (\diamond)\;&\mathbb E\left[\left\|\mathbf{u}_{t} - g(\mathbf{w}_t)\right\|_2^2\right] \leq \mathbb E[(1-\gamma_t) \left\|\mathbf{u}_{t-1} - g(\mathbf{w}_{t-1})\right\|_2^2]\\ & \quad + \mathbb E\left[\frac{G_2^2\eta_{t-1}^2}{\gamma_t} \left\|\mathbf{v}_{t-1}\right\|_2^2 + \gamma_t^2\sigma_0^2\right]. \end{align*}\]
Define \(A_t= 2(F(\mathbf{w}_t) - F_*)\) and \(B_t = \|\nabla F(\mathbf{w}_t)\|_2^2\), \(\Gamma_t = \left\|\mathbf{v}_{t}\right\|_2^2/2\), \(\Delta_t= \left\|\nabla F(\mathbf{w}_t) - \mathbf{v}_{t}\right\|_2^2\), \(\delta_{t}=\left\|\mathbf{u}_{t} - g(\mathbf{w}_t)\right\|_2^2\), and \(\Upsilon_{t}=A_{t} + \frac{\eta_{t-1}}{\beta_{t}}\Delta_{t} + \frac{C_1\eta_{t-1}}{\gamma_{t}} \delta_{t}.\)
Then the three inequalities satisfy those in Lemma 4.10 with \(C_1=4G_2^2L_1^2, C_2=4L_F^2, C_3=2G_2^2, \sigma^2 = G_1^2\sigma_2^2+G_2^2\sigma_1^2, {\sigma'}^2=\sigma_0^2\). Then \(\eta_t, \beta_t, \gamma_t\) satisfy \[ \frac{\eta_t}{\beta_{t+1}}\leq \frac{\eta_{t-1}}{\beta_{t}}, \frac{\eta_t}{\gamma_{t+1}}\leq \frac{\eta_{t-1}}{\gamma_{t}}, \eta_t\leq \min(\frac{\beta_{t+1}}{\sqrt{4C_2}}, \frac{\gamma_{t+1}}{\sqrt{8C_1C_3}}). \] Then we have \[\begin{align*} &\mathbb E\left[\frac{1}{\sum_{t=0}^{T-1}\eta_t}\sum_{t=0}^{T-1}\left(\eta_t \|\nabla F(\mathbf{w}_t)\|_2^2 + \frac{\eta_t}{4}\left\|\mathbf{v}_{t}\right\|_2^2\right)\right]\\ & \leq \frac{C_\Upsilon}{\sum_{t=0}^{T-1}\eta_t} + \frac{\sum_{t=0}^{T-1}\left(\eta_t\beta_{t+1}\sigma^2 + 2C_1\eta_t\gamma_{t+1}\sigma_0^2\right)}{\sum_{t=0}^{T-1}\eta_t}. \end{align*}\]
Since the setting of \(\eta, \gamma, \beta\) satisfy that in Lemma 4.10, the order of \(T\) becomes \[\begin{align*} T &=O\left(\max\left\{\frac{C_\Upsilon L_F}{\epsilon^2}, \frac{C_\Upsilon\sigma^2\sqrt{C_2}}{\epsilon^4}, \frac{C_\Upsilon\sqrt{C_1C_3}C_1\sigma_0^2}{\epsilon^4}\right\}\right)\\ &=O\left(\max\left\{\frac{C_\Upsilon L_F}{\epsilon^2}, \frac{C_\Upsilon\sigma^2L_F}{\epsilon^4}, \frac{C_\Upsilon L_1^3\sigma_0^2}{\epsilon^4}\right\}\right), \end{align*}\] where \[\begin{align*} C_\Upsilon&=2(F(\mathbf{w}_0) - F_*) + \frac{1}{2\sqrt{C_2}}\|\mathbf{v}_0 - \nabla F(\mathbf{w}_0)\|_2^2 + \frac{\sqrt{C_1}}{\sqrt{8C_3}}\|\mathbf{u}_0 - g(\mathbf{w}_0)\|_2^2\\ & =2(F(\mathbf{w}_0)-F_*) +\frac{1}{4L_F} \left\| \mathbf{v}_{0} - \nabla F(\mathbf{w}_{0})\right\|_2^2+\frac{L_1}{2}\left\|\mathbf{u}_{0} - g(\mathbf{w}_{0})\right\|_2^2. \end{align*}\]■
4.3.2 STORM Estimators
We can further reduce the error of the gradient estimator by using advanced variance-reduction techniques under the following stronger assumptions.
Assumption 4.6
There exists \(L_1,G_1>0\) such that
- \(\mathbb E[\|\nabla f(g;\xi)-\nabla f(g';\xi)\|_2^2]\leq L_1^2\|g-g'\|_2^2,\forall g,g'\);
- \(\mathbb E[\|\nabla f(g;\xi)\|_2^2]\leq G_1^2,\forall g\).
Assumption 4.7
There exists \(L_2,G_2>0\) such that
- \(\mathbb E[\|\nabla g(\mathbf{w};\zeta)-\nabla g(\mathbf{w}';\zeta)\|_2^2]\leq L_2^2\|\mathbf{w}-\mathbf{w}'\|_2^2,\forall \mathbf{w},\mathbf{w}'\);
- \(\mathbb E[\|\nabla g(\mathbf{w};\zeta)\|_2^2]\leq G_2^2,\forall \mathbf{w}\).
Due to Jensen’s inequality, Assumption 4.6(i) implies the Lipschitz continuity assumption of \(\nabla f\) in Assumption 4.1(i). Similarly, Assumption 4.7(i) implies that in Assumption 4.2(i). Hence, Assumption 4.6(i) and Assumption 4.7(i) are stronger, and they are referred to as the mean-square smoothness conditions of \(f\) and \(g\).
Let us first discuss a generic STORM estimator, an improved variant of the moving average estimator. Without loss of generality, we consider estimating a sequence of mappings \(\{\mathcal{M}(\mathbf{w}_t)\}_{t=1}^T\) through their stochastic values at each iteration \(\{\mathcal{M}(\mathbf{w}_t;\zeta_t)\}_{t=1}^T\), where \(\mathbb E_{\zeta_t}[\mathcal{M}(\mathbf{w}_t;\zeta_t)]=\mathcal{M}(\mathbf{w}_t)\in\mathbb{R}^{d'}\). We assume the mapping \(\mathcal{M}\) satisfies:
\[ \mathbb E_{\zeta}[\|\mathcal{M}(\mathbf{w};\zeta)-\mathcal{M}(\mathbf{w}';\zeta)\|_2^2]\leq G^2\|\mathbf{w}-\mathbf{w}'\|_2^2,\forall \mathbf{w},\mathbf{w}'. \]
The STORM estimator is given by a sequence of \(\mathcal{U}_1,\ldots,\mathcal{U}_T\), where
\[\begin{align} \mathcal{U}_{t} &= (1-\gamma_t)\mathcal{U}_{t-1}+\gamma_t\mathcal{M}(\mathbf{w}_t;\zeta_t)+(1-\gamma_t)(\mathcal{M}(\mathbf{w}_t;\zeta_t)-\mathcal{M}(\mathbf{w}_{t-1};\zeta_t)), \end{align}\]
and \(\gamma_t\in(0,1)\).
It augments the moving-average estimator by adding an extra term \((1-\gamma_t)(\mathcal{M}(\mathbf{w}_t;\zeta_t)-\mathcal{M}(\mathbf{w}_{t-1};\zeta_t))\), which can be viewed as an error correction term.
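The effect of the correction term can be seen by tracking a drifting mean with both the moving-average and STORM estimators. The scalar drifting target, noise level, and \(\gamma\) below are illustrative assumptions; with additive noise, \(\mathcal{M}(\mathbf{w}_t;\zeta_t)-\mathcal{M}(\mathbf{w}_{t-1};\zeta_t)\) cancels the drift exactly, leaving only averaged noise in the STORM error.

```python
import numpy as np

rng = np.random.default_rng(2)
gamma, noise, drift, T = 0.1, 0.01, 0.05, 300   # illustrative assumptions

# Track M(w_t) = w_t while w_t drifts; samples are M(w_t; zeta) = w_t + zeta.
u_ma = u_storm = 0.0
w_prev = 0.0
err_ma, err_storm = [], []
for t in range(1, T + 1):
    w = drift * t
    zeta = noise * rng.standard_normal()
    # moving-average estimator: lags behind the drifting target
    u_ma = (1 - gamma) * u_ma + gamma * (w + zeta)
    # STORM estimator: the correction term removes the drift lag
    u_storm = ((1 - gamma) * u_storm + gamma * (w + zeta)
               + (1 - gamma) * ((w + zeta) - (w_prev + zeta)))
    err_ma.append(abs(u_ma - w))
    err_storm.append(abs(u_storm - w))
    w_prev = w

print(np.mean(err_ma[-100:]), np.mean(err_storm[-100:]))
# the STORM error is dominated by averaged noise; the MA error by the drift lag
```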
Applying the STORM estimator to estimating the sequence of \(\{g(\mathbf{w}_t)\}_{t\geq 1}\), we have the following sequence:
\[\begin{align} \mathbf{u}_{t} = (1-\gamma_t)\mathbf{u}_{t-1}+\gamma_t g(\mathbf{w}_t;\zeta_t)+(1-\gamma_t)(g(\mathbf{w}_t;\zeta_t)-g(\mathbf{w}_{t-1};\zeta_t)). \end{align}\]
Given \(\mathbf{u}_{t}\), we can compute a moving-average gradient estimator (1) similar to SCMA. However, this will not yield an improved rate compared with SCMA. To reduce the estimation error of the gradient, we apply another STORM estimator to estimate \(\mathcal{M}_{t}=\nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_{t})\). This is computed by the following sequence:
\[\begin{align}\label{eqn:cstorm_grad} \mathbf{v}_{t} & = (1-\beta_t)\mathbf{v}_{t-1}+\beta_t\nabla g(\mathbf{w}_t;\zeta_t')\nabla f(\mathbf{u}_{t};\xi_t)\\ &\quad +(1-\beta_t)(\nabla g(\mathbf{w}_t;\zeta_t')\nabla f(\mathbf{u}_{t};\xi_t)-\nabla g(\mathbf{w}_{t-1};\zeta_t')\nabla f(\mathbf{u}_{t-1};\xi_t)).\notag \end{align}\]
With \(\mathbf{v}_{t}\), we update the model parameters by
\[\begin{align*} \mathbf{w}_{t+1} = \mathbf{w}_t-\eta\mathbf{v}_{t}. \end{align*}\]
The full steps of this method are presented in Algorithm 11, which is referred to as SCST.
In the special case where \(f\) is the identity function, the update is identical to the classical variance-reduced method (also known as STORM) for non-convex optimization \(\min_{\mathbf{w}}\mathbb E_{\zeta}[g(\mathbf{w};\zeta)]\), i.e.,
\[\begin{equation}\label{eqn:storm} \begin{aligned} &\mathbf{v}_{t} = (1-\beta_t)\mathbf{v}_{t-1}+\beta_t\nabla g(\mathbf{w}_t;\zeta_t')+(1-\beta_t)(\nabla g(\mathbf{w}_t;\zeta_t')-\nabla g(\mathbf{w}_{t-1};\zeta_t')),\\ &\mathbf{w}_{t+1} = \mathbf{w}_t-\eta_t\mathbf{v}_{t}. \end{aligned} \end{equation}\]
It is renowned for achieving a complexity of \(O(1/\epsilon^3)\), improving upon the \(O(1/\epsilon^4)\) complexity of SGD for finding an \(\epsilon\)-stationary solution.
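This single-level STORM recursion can be sketched on a noisy quadratic. The objective \(g(\mathbf{w};\zeta)=\frac{1}{2}\|\mathbf{w}-\zeta\|_2^2\) with \(\mathbb E[\zeta]=0\), the noise scale, and the constant step sizes below are our own assumptions; the minimizer of \(\mathbb E_\zeta[g(\mathbf{w};\zeta)]\) is \(\mathbf{w}=0\).

```python
import numpy as np

rng = np.random.default_rng(3)
d = 5
eta, beta = 0.1, 0.2                      # constant schedules (assumed)

def grad_g(w, zeta):                      # ∇g(w; zeta) for g(w; zeta) = 0.5||w - zeta||^2
    return w - zeta

w_prev = rng.standard_normal(d)
v = grad_g(w_prev, 0.1 * rng.standard_normal(d))   # v_0 from one stochastic gradient
w = w_prev - eta * v                               # w_1 = w_0 - eta_0 v_0

for t in range(500):
    zeta = 0.1 * rng.standard_normal(d)            # the same zeta_t at both points
    v = ((1 - beta) * v + beta * grad_g(w, zeta)
         + (1 - beta) * (grad_g(w, zeta) - grad_g(w_prev, zeta)))
    w_prev, w = w, w - eta * v

print(float(np.linalg.norm(w)))           # close to the minimizer w = 0
```

Note that the correction term \(\nabla g(\mathbf{w}_t;\zeta_t')-\nabla g(\mathbf{w}_{t-1};\zeta_t')\) reuses the same sample at both iterates, which is what allows the estimator error to contract.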
Algorithm 11: SCST
- Input: learning rate and averaging parameter schedules \(\{\eta_t\}_{t=0}^{T}\), \(\{\gamma_t\}_{t=1}^{T}\), \(\{\beta_t\}_{t=1}^{T}\); starting points \(\mathbf{w}_0\), \(\mathbf{u}_0\), \(\mathbf{v}_0\)
- Let \(\mathbf{w}_1=\mathbf{w}_0-\eta_0\mathbf{v}_0\)
- For \(t=1,\dotsc,T\)
- Sample \(\zeta_t,\zeta'_t\) and \(\xi_t\)
- Update the inner function value estimator \(\mathbf{u}_{t}=(1-\gamma_t)\mathbf{u}_{t-1}+\gamma_t g(\mathbf{w}_t;\zeta_t)+(1-\gamma_t)(g(\mathbf{w}_t;\zeta_t)-g(\mathbf{w}_{t-1};\zeta_t))\)
- Compute the vanilla gradient estimator \(\mathbf{z}_t=\nabla g(\mathbf{w}_t;\zeta_t')\nabla f(\mathbf{u}_{t};\xi_t)\)
- Compute \(\tilde{\mathbf{z}}_{t-1}=\nabla g(\mathbf{w}_{t-1};\zeta_t')\nabla f(\mathbf{u}_{t-1};\xi_t)\)
- Update the STORM gradient estimator \(\mathbf{v}_{t}=(1-\beta_t)\mathbf{v}_{t-1}+\beta_t\mathbf{z}_t+(1-\beta_t)(\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1})\)
- Update the model by \(\mathbf{w}_{t+1}=\mathbf{w}_t-\eta_t\mathbf{v}_{t}\)
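Analogously to the SCMA sketch, SCST can be illustrated on a toy compositional objective; the linear inner map, quadratic outer function, noise scale, and constant step sizes below are again our own assumptions.

```python
import numpy as np

rng = np.random.default_rng(4)
d, p = 5, 3
A = np.eye(p, d) + 0.1 * rng.standard_normal((p, d))  # well-conditioned map (assumed)

# Toy objective (our assumption): g(w; zeta) = A w + 0.01 zeta and
# f(u) = 0.5 ||u||^2, so F(w) = 0.5 ||A w||^2 and stationarity means A w = 0.
def g_sample(w, zeta):
    return A @ w + 0.01 * zeta

def z_est(w, u):            # ∇g(w; zeta')∇f(u) with ∇g = A^T and ∇f(u) = u here
    return A.T @ u

eta, gamma, beta = 0.05, 0.2, 0.2        # constant schedules (assumed)
w_prev = rng.standard_normal(d)
u_prev = g_sample(w_prev, rng.standard_normal(p))
v = z_est(w_prev, u_prev)
w = w_prev - eta * v                     # w_1 = w_0 - eta_0 v_0

for t in range(3000):
    zeta = rng.standard_normal(p)        # the same zeta_t evaluates both points
    u = ((1 - gamma) * u_prev + gamma * g_sample(w, zeta)
         + (1 - gamma) * (g_sample(w, zeta) - g_sample(w_prev, zeta)))
    z = z_est(w, u)                      # vanilla gradient estimator z_t
    z_tilde = z_est(w_prev, u_prev)      # its counterpart at (w_{t-1}, u_{t-1})
    v = (1 - beta) * v + beta * z + (1 - beta) * (z - z_tilde)
    w_prev, u_prev = w, u
    w = w - eta * v                      # model update

print(float(np.linalg.norm(A @ w)))      # near 0 at a stationary point
```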
Convergence Analysis
We first prove a general result for the STORM estimator that applies to both \(\mathbf{u}_t\) and \(\mathbf{v}_t\).
Lemma 4.11
Consider \(\mathbf{v}_{t}=(1-\beta_t)\mathbf{v}_{t-1}+\beta_t\mathbf{z}_t+(1-\beta_t)(\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1})\), where \(\beta_t\in(0,1)\). Let \(\mathbb E_t\) denote the expectation over the randomness associated with \(\mathbf{z}_t,\tilde{\mathbf{z}}_{t-1}\) conditioned on the randomness before the \(t\)-th iteration. Assume \(\mathbb E_t[\mathbf{z}_t]=\mathcal{M}_t\) and \(\mathbb E_t[\tilde{\mathbf{z}}_{t-1}]=\mathcal{M}_{t-1}\). If \(\mathbb E_t[\|\mathbf{z}_t-\mathcal{M}_t\|_2^2]\leq \sigma^2\), then we have
\[\begin{align*} \mathbb E_t\left[\|\mathbf{v}_{t}-\mathcal{M}_t\|^2\right] \leq (1-\beta_t)\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2+\mathbb E_t[2\|\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1}\|_2^2]+2\beta_t^2\sigma^2. \end{align*}\]
Proof
\[\begin{align*} &\mathbb E_{t}\left[\|\mathbf{v}_{t}-\mathcal{M}_t\|^2\right]=\mathbb E_{t}\left[\|(1-\beta_t)\mathbf{v}_{t-1}-\mathcal{M}_t+\beta_t\mathbf{z}_t+(1-\beta_t)(\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1})\|^2\right]\\ &=\mathbb E_{t}\left[\|(1-\beta_t)(\mathbf{v}_{t-1}-\mathcal{M}_{t-1})+(1-\beta_t)((\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1})-(\mathcal{M}_t-\mathcal{M}_{t-1}))+\beta_t(\mathbf{z}_t-\mathcal{M}_t)\|_2^2\right]. \end{align*}\]
Note that
\[\begin{align*} \mathbb E_{t}&\left[\langle(1-\beta_t)(\mathbf{v}_{t-1}-\mathcal{M}_{t-1}), (1-\beta_t)((\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1})-(\mathcal{M}_t-\mathcal{M}_{t-1}))+\beta_t(\mathbf{z}_t-\mathcal{M}_t)\rangle\right]=0. \end{align*}\]
Then,
\[\begin{align*} &\mathbb E_t\left[\|\mathbf{v}_{t}-\mathcal{M}_t\|^2\right]\leq (1-\beta_t)^2\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2\\ &\quad +\mathbb E_t\left[\|(1-\beta_t)((\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1})-(\mathcal{M}_t-\mathcal{M}_{t-1}))+\beta_t(\mathbf{z}_t-\mathcal{M}_t)\|_2^2\right]\\ &\stackrel{(\diamond)}{\leq} (1-\beta_t)^2\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2\\ &\quad +2(1-\beta_t)^2\mathbb E_t[\|(\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1})-(\mathcal{M}_t-\mathcal{M}_{t-1})\|_2^2]+2\beta_t^2\mathbb E_t[\|\mathbf{z}_t-\mathcal{M}_t\|_2^2]\\ &\stackrel{(*)}{\leq} (1-\beta_t)^2\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2+2(1-\beta_t)^2\mathbb E_t[\|\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1}\|_2^2]+2\beta_t^2\sigma^2, \end{align*}\]
where \((\diamond)\) uses Young’s inequality, and \((*)\) uses the facts that \(\mathbb E[\|a-\mathbb E[a]\|_2^2]\leq \mathbb E[\|a\|_2^2]\) and \(\mathbb E_t[\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1}]=\mathcal{M}_t-\mathcal{M}_{t-1}\).
■
Let us now prove an error recursion of \(\mathbf{u}_{t}\) in the lemma below.
Lemma 4.12
Under Assumption 4.7(ii), we have:
\[\begin{equation*} \begin{split} &\mathbb E_{\zeta_t}\left[\|\mathbf{u}_{t}-g(\mathbf{w}_t)\|^2 \right] \leq (1-\gamma_t)\|\mathbf{u}_{t-1}-g(\mathbf{w}_{t-1})\|_2^2+2\gamma_t^2\sigma_0^2+2G_2^2\|\mathbf{w}_{t}-\mathbf{w}_{t-1}\|_2^2\\ &\mathbb E_{\zeta_t}\left[\|\mathbf{u}_{t}-\mathbf{u}_{t-1}\|_2^2\right] \leq 2\gamma_{t}^2\sigma_0^2+4\gamma_{t}^2\|\mathbf{u}_{t-1}-g(\mathbf{w}_{t-1})\|_2^2+6G_2^2\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2. \end{split} \end{equation*}\]
💡 Why it matters
Compared with the error recursion for \(\mathbf{u}_{t}\) in Lemma 4.1, the improvement comes from the last term, which is reduced from \(\frac{2G_2^2\|\mathbf{w}_{t}-\mathbf{w}_{t-1}\|_2^2}{\gamma_t}\) to \(2G_2^2\|\mathbf{w}_{t}-\mathbf{w}_{t-1}\|_2^2\).
Proof
The first part follows directly from Lemma 4.11 by noting the mean-Lipschitz continuity of \(g(\mathbf{w};\zeta)\).
To prove the second part, we proceed as follows:
\[\begin{equation*} \begin{split} \mathbb E_{t}&\left[\|\mathbf{u}_{t}-\mathbf{u}_{t-1}\|_2^2 \right]\\ =&\mathbb E_{t}\left[\left\|\gamma_{t}\left(g(\mathbf{w}_{t};\zeta_{t})-\mathbf{u}_{t-1}\right)+(1-\gamma_{t})\left(g(\mathbf{w}_{t};\zeta_{t})-g(\mathbf{w}_{t-1};\zeta_{t})\right)\right\|_2^{2}\right]\\ \leq &\mathbb E_{t}\left[2\gamma_{t}^2\left\|g(\mathbf{w}_{t};\zeta_{t})-\mathbf{u}_{t-1}\right\|_2^2+2(1-\gamma_{t})^2\left\|g(\mathbf{w}_{t};\zeta_{t})-g(\mathbf{w}_{t-1};\zeta_{t})\right\|_2^{2}\right]\\ \leq &\mathbb E_{t}\left[2\gamma_{t}^2\left\|g(\mathbf{w}_{t};\zeta_{t})-\mathbf{u}_{t-1}\right\|_2^2\right]+2(1-\gamma_{t})^2G_2^2\|\mathbf{w}_{t}-\mathbf{w}_{t-1}\|_2^{2}. \end{split} \end{equation*}\]
Next, we bound the first term on the RHS as
\[\begin{align*} &\mathbb E_{t}\left[2\gamma_{t}^2\left\|g(\mathbf{w}_{t};\zeta_{t})-\mathbf{u}_{t-1}\right\|_2^2\right]\\ &=\mathbb E_{t}\left[2\gamma_{t}^2\left\|g(\mathbf{w}_{t};\zeta_{t})-g(\mathbf{w}_t)+g(\mathbf{w}_t)-\mathbf{u}_{t-1}\right\|_2^2\right]\\ &\leq 2\gamma_{t}^2\sigma_0^2+2\gamma_{t}^2\|g(\mathbf{w}_t)-\mathbf{u}_{t-1}\|_2^2\\ &\leq 2\gamma_{t}^2\sigma_0^2+2\gamma_{t}^2\|g(\mathbf{w}_t)-g(\mathbf{w}_{t-1})+g(\mathbf{w}_{t-1})-\mathbf{u}_{t-1}\|_2^2\\ &\leq 2\gamma_{t}^2\sigma_0^2+4\gamma_{t}^2\|g(\mathbf{w}_{t-1})-\mathbf{u}_{t-1}\|_2^2+4\gamma_{t}^2G_2^2\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2, \end{align*}\]
where the first inequality uses the fact that \(\mathbb E[g(\mathbf{w}_{t};\zeta_{t})-g(\mathbf{w}_{t})]=0\), so the cross term vanishes in expectation. Combining the above results finishes the proof.
■
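For intuition, the \(\mathbf{u}\)-update whose error recursion we just bounded can be simulated in a few lines. The following minimal sketch (the scalar map \(g(w;\zeta)=w+\zeta\), the constants, and all names are illustrative assumptions, not part of the algorithm's formal statement) highlights the mechanism behind the improved last term: evaluating \(g\) at both \(\mathbf{w}_t\) and \(\mathbf{w}_{t-1}\) with the same sample \(\zeta_t\) lets the tracking error contract even while the iterate keeps moving.

```python
import random

def update_u(u_prev, g_new, g_old, gamma):
    # u_t = u_{t-1} + gamma*(g(w_t; zeta_t) - u_{t-1})
    #               + (1 - gamma)*(g(w_t; zeta_t) - g(w_{t-1}; zeta_t))
    return u_prev + gamma * (g_new - u_prev) + (1.0 - gamma) * (g_new - g_old)

random.seed(0)
gamma, sigma0 = 0.2, 0.1
w, u = 0.0, 5.0                        # u starts far from g(w) = w
for _ in range(300):
    w_old, w = w, w - 0.01             # the iterate keeps moving
    zeta = random.gauss(0.0, sigma0)   # ONE sample, reused at both iterates
    u = update_u(u, w + zeta, w_old + zeta, gamma)
print(abs(u - w))                      # tracking error has shrunk to the noise floor
```

Because the same \(\zeta_t\) appears at both points, the movement term cancels in the error and the tracking error obeys roughly \(e_t=(1-\gamma)e_{t-1}+\gamma\zeta_t\) in this toy setup, with no \(\|\mathbf{w}_t-\mathbf{w}_{t-1}\|/\gamma_t\) blow-up.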
Next, we build an error recursion of \(\|\mathbf{v}_{t}-\mathcal{M}_t\|_2^2\).
Lemma 4.13
Let \(\sigma^2=G_2^2\sigma_1^2+G_1^2\sigma_2^2\). Under Assumption 4.6 and Assumption 4.7, (\(\ref{eqn:cstorm_grad}\)) satisfies that
\[\begin{align}\label{eq:ncstorm_grad} &\mathbb E_{\zeta'_t,\xi_t}\left[\|\mathbf{v}_{t}-\mathcal{M}_t\|^2\right]\leq (1-\beta_t)\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2\\ &\quad +16G_2^2L_1^2\gamma_{t}^2\|\mathbf{u}_{t-1}-g(\mathbf{w}_{t-1})\|_2^2+(24G_2^4L_1^2+4G_1^2L_2^2)\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2\notag\\ &\quad +2\beta_t^2\sigma^2+8G_2^2L_1^2\gamma_{t}^2\sigma_0^2.\notag \end{align}\]
Proof
First, (6) gives \(\mathbb E_t[\|\mathbf{z}_t-\mathcal{M}_t\|_2^2]\leq \sigma^2\).
Second,
\[\begin{equation*} \begin{aligned} &\mathbb E_t\left[\|\mathbf{z}_t-\tilde{\mathbf{z}}_{t-1}\|_2^2\right]\\ &=\mathbb E_t[\|\nabla g(\mathbf{w}_t;\zeta_t')\nabla f(\mathbf{u}_{t};\xi_t)-\nabla g(\mathbf{w}_{t-1};\zeta_t')\nabla f(\mathbf{u}_{t-1};\xi_t)\|_2^2]\\ &=\mathbb E_t[\|\nabla g(\mathbf{w}_t;\zeta_t')\nabla f(\mathbf{u}_{t};\xi_t)-\nabla g(\mathbf{w}_t;\zeta_t')\nabla f(\mathbf{u}_{t-1};\xi_t)\\ &\quad\quad\quad +\nabla g(\mathbf{w}_t;\zeta_t')\nabla f(\mathbf{u}_{t-1};\xi_t)-\nabla g(\mathbf{w}_{t-1};\zeta_t')\nabla f(\mathbf{u}_{t-1};\xi_t)\|_2^2]\\ &\stackrel{(\triangle)}{\leq} 2G_2^2L_1^2\|\mathbf{u}_{t}-\mathbf{u}_{t-1}\|_2^2+2G_1^2L_2^2\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2, \end{aligned} \end{equation*}\]
where \((\triangle)\) uses Assumption 4.6(i) and Assumption 4.7(i).
Plugging these two bounds into Lemma 4.11 yields:
\[\begin{align*} \mathbb E_t\left[\|\mathbf{v}_{t}-\mathcal{M}_t\|^2\right] &\leq (1-\beta_t)^2\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2\\ &\quad +4G_2^2L_1^2\|\mathbf{u}_{t}-\mathbf{u}_{t-1}\|_2^2+4G_1^2L_2^2\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2+2\beta_t^2\sigma^2. \end{align*}\]
By using the second inequality of Lemma 4.12, i.e.,
\[\begin{equation*} \mathbb E_{\zeta_t}\left[\|\mathbf{u}_{t}-\mathbf{u}_{t-1}\|_2^2\right]\leq 2\gamma_{t}^2\sigma_0^2+4\gamma_{t}^2\|\mathbf{u}_{t-1}-g(\mathbf{w}_{t-1})\|_2^2+6G_2^2\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2, \end{equation*}\]
we have
\[\begin{align*} &\mathbb E_t\left[\|\mathbf{v}_{t}-\mathcal{M}_t\|^2\right]\leq (1-\beta_t)\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2+16G_2^2L_1^2\gamma_{t}^2\|\mathbf{u}_{t-1}-g(\mathbf{w}_{t-1})\|_2^2\\ &\quad +(24G_2^4L_1^2+4G_1^2L_2^2)\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2+2\beta_t^2\sigma^2+8G_2^2L_1^2\gamma_{t}^2\sigma_0^2. \end{align*}\]
■
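The corrected gradient estimator analyzed in Lemma 4.13 can likewise be exercised on a toy composition. In the sketch below (all choices illustrative), \(g(w)=w\) and \(f(u)=u^2/2\), so \(\nabla F(w)=w\); multiplicative and additive Gaussian noise stand in for the stochastic \(\nabla g(\mathbf{w};\zeta')\) and \(\nabla f(\mathbf{u};\xi)\), and for simplicity the inner tracker is taken to be exact, \(\mathbf{u}_t=g(\mathbf{w}_t)\).

```python
import random

def update_v(v_prev, z_new, z_old, beta):
    # v_t = (1 - beta) * (v_{t-1} + z_t - z~_{t-1}) + beta * z_t
    return (1.0 - beta) * (v_prev + z_new - z_old) + beta * z_new

random.seed(1)
beta, eta = 0.1, 0.05
# toy composition: g(w) = w, f(u) = u^2/2, so F(w) = w^2/2 and grad F(w) = w
w_prev, w, v = 2.0, 2.0, 2.0
for _ in range(500):
    ng = random.gauss(0.0, 0.1)          # noise in grad g(w; zeta')
    nf = random.gauss(0.0, 0.1)          # noise in grad f(u; xi)
    z_new = (1.0 + ng) * (w + nf)        # z_t at the current pair (w_t, u_t)
    z_old = (1.0 + ng) * (w_prev + nf)   # z~_{t-1}: SAME samples, previous pair
    v = update_v(v, z_new, z_old, beta)
    w_prev, w = w, w - eta * v           # w_{t+1} = w_t - eta * v_t
print(abs(w))                            # near the stationary point w = 0
```

As in the lemma, the difference \(z_t-\tilde{\mathbf{z}}_{t-1}\) is small when consecutive iterates are close, so the estimator error contracts at rate \(1-\beta_t\) up to an \(O(\beta_t^2\sigma^2)\) noise term.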
Similar to Lemma 4.9, we have the following descent lemma.
Lemma 4.14
For the update \(\mathbf{w}_{t+1}=\mathbf{w}_t-\eta_t\mathbf{v}_{t},t\geq 0\), if \(\eta_t\leq 1/(2L_F)\) we have
\[\begin{align}\label{eq:csstorm_starter} F(\mathbf{w}_{t+1})&\leq F(\mathbf{w}_t)+\eta_tG_2^2L_1^2\|\mathbf{u}_{t}-g(\mathbf{w}_t)\|_2^2+\eta_t\|\mathbf{v}_{t}-\mathcal M_t\|_2^2\notag\\ &\quad -\frac{\eta_t}{2}\|\nabla F(\mathbf{w}_t)\|_2^2-\frac{1}{4\eta_t}\|\mathbf{w}_{t+1}-\mathbf{w}_t\|_2^2. \end{align}\]
This lemma can be proved following that of Lemma 4.9 by bounding \(\|\mathbf{v}_{t}-\nabla F(\mathbf{w}_t)\|_2^2\leq 2\|\mathbf{v}_{t}-\mathcal{M}_t\|_2^2+2\|\mathcal{M}_t-\nabla F(\mathbf{w}_t)\|_2^2\leq 2\|\mathbf{v}_{t}-\mathcal{M}_t\|_2^2+2G_2^2L_1^2\|\mathbf{u}_{t}-g(\mathbf{w}_t)\|_2^2\).
Lemma 4.15
Suppose that, for \(\eta_t\leq 1/L\), the non-negative sequences \(A_t,B_t,\Gamma_t,\Delta_t,\delta_t\), \(t\geq 0\), satisfy:
\[\begin{align*} (*)\;&A_{t+1}\leq A_t+\eta_t\Delta_t+\eta_t\delta_t-\eta_t B_t-\eta_t\Gamma_t\\ (\sharp)\;&\Delta_{t+1}\leq (1-\beta_{t+1})\Delta_{t}+C_1\gamma_{t+1}^2\delta_{t}+C_2\eta_{t}^2\Gamma_t+\beta_{t+1}^2\sigma^2+\gamma_{t+1}^2\sigma'^2,\\ (\diamond)\;&\delta_{t+1}\leq (1-\gamma_{t+1})\delta_{t}+C_3\eta_{t}^2\Gamma_t+\gamma_{t+1}^2{\sigma''}^2. \end{align*}\]
Let \(\Upsilon_{t+1}=A_{t+1}+\frac{c}{\eta_t}\Delta_{t+1}+\frac{c'}{\eta_{t}}\delta_{t+1}\) and assume \(\Upsilon_{t+1}\geq A_*\) for all \(t\). Suppose \(c,c',\eta_t,\gamma_t,\beta_t\) satisfy:
\[\begin{equation} \label{eqn:lrr} \begin{aligned} &C_2c+C_3c'\leq \frac{1}{2},\quad \eta_t+\frac{c}{\eta_t}(1-\beta_{t+1})\leq \frac{c}{\eta_{t-1}},\\ &\eta_t+\frac{c}{\eta_t}C_1\gamma_{t+1}^2+\frac{c'}{\eta_t}(1-\gamma_{t+1})\leq \frac{c'}{\eta_{t-1}}. \end{aligned} \end{equation}\]
Then,
\[\begin{align}\label{eqn:3seq-1} &\sum_{t=0}^{T-1}(\eta_t B_t+\frac{1}{2}\eta_t\Gamma_t)\leq C_\Upsilon+\sum_{t=0}^{T-1}\left(\frac{c\beta^2_{t+1}}{\eta_t}\sigma^2+\frac{c\gamma^2_{t+1}}{\eta_t}\sigma'^2+\frac{c'\gamma^2_{t+1}}{\eta_t}\sigma''^2\right). \end{align}\]
If we set \(c=\frac{1}{4C_2}\), \(c'=\frac{1}{4C_3}\), \(\beta_t=\frac{\epsilon\eta\sqrt{C_2}}{\sigma}\), \(\gamma_t=\min\left(\frac{\epsilon\eta\sqrt{C_2}}{\sigma'},\frac{\epsilon\eta\sqrt{C_3}}{\sigma''},\frac{C_2}{2C_3C_1}\right)\), and \(\eta_t=\eta=\min\left(\frac{1}{L},\frac{\epsilon}{4\sqrt{C_2}\sigma},\frac{\epsilon\sqrt{C_2}}{8C_3\sigma'},\frac{\epsilon}{8\sqrt{C_3}\sigma''},\frac{\sqrt{C_2}}{4C_3\sqrt{C_1}}\right)\), then in order to guarantee
\[\begin{align}\label{eqn:3seq-2} \sum_{t=0}^{T-1}\frac{1}{T}(B_t+\frac{1}{2}\Gamma_t)\leq \epsilon^2, \end{align}\]
the iteration complexity is of the order
\[ T=O\left(\max\left\{\frac{C_\Upsilon L}{\epsilon^2},\frac{C_\Upsilon C_3\sqrt{C_1/C_2}}{\epsilon^2},\frac{C_\Upsilon\sigma\sqrt{C_2}}{\epsilon^3},\frac{C_\Upsilon C_3\sigma'}{\epsilon^3\sqrt{C_2}},\frac{C_\Upsilon\sigma''\sqrt{C_3}}{\epsilon^3}\right\}\right) \]
where \(C_\Upsilon=\Upsilon_0-A_*=A_{0}+\frac{1}{4C_2\eta}\Delta_{0}+\frac{1}{4C_3\eta}\delta_{0}-A_*\).
If \((*),(\sharp),(\diamond)\) hold in expectation, then the inequalities (\(\ref{eqn:3seq-1}\)) and (\(\ref{eqn:3seq-2}\)) hold in expectation.
Proof
The proof is constructive. The idea is to multiply the second inequality by \(a_{t+1}\) and the third inequality by \(b_{t+1}\) such that we can construct a telescoping series of \(A_{t}+a_t\Delta_{t}+b_t\delta_{t}\). First, we have
\[\begin{align*} &A_{t+1}+a_{t+1}\Delta_{t+1}+b_{t+1}\delta_{t+1}\leq A_t+\eta_t\Delta_t+\eta_t\delta_t-\eta_t B_t-\eta_t\Gamma_t\\ &\quad +a_{t+1}(1-\beta_{t+1})\Delta_{t}+a_{t+1}C_1\gamma^2_{t+1}\delta_{t}+a_{t+1}C_2\eta_{t}^2\Gamma_t+a_{t+1}\beta_{t+1}^2\sigma^2+a_{t+1}\gamma_{t+1}^2{\sigma'}^2\\ &\quad +b_{t+1}(1-\gamma_{t+1})\delta_{t}+b_{t+1}C_3\eta_{t}^2\Gamma_{t}+b_{t+1}\gamma_{t+1}^2{\sigma''}^2. \end{align*}\]
Letting \(a_{t+1}=c/\eta_t\) and \(b_{t+1}=c'/\eta_t\), we have
\[\begin{align*} &A_{t+1}+\frac{c}{\eta_{t}}\Delta_{t+1}+\frac{c'}{\eta_t}\delta_{t+1}\leq A_t-\eta_t B_t-\eta_t \Gamma_t\\ &\quad +\left(\eta_t+\frac{c}{\eta_t}(1-\beta_{t+1})\right)\Delta_{t}+C_2c\eta_t\Gamma_t+\frac{c\beta^2_{t+1}}{\eta_t}\sigma^2+\frac{c\gamma^2_{t+1}}{\eta_t}\sigma'^2\\ &\quad +\left(\eta_t+\frac{c}{\eta_t}C_1\gamma_{t+1}^2+\frac{c'}{\eta_t}(1-\gamma_{t+1})\right)\delta_{t}+C_3c'\eta_{t}\Gamma_t+\frac{c'\gamma^2_{t+1}}{\eta_t}\sigma''^2. \end{align*}\]
With (\(\ref{eqn:lrr}\)) we have
\[\begin{align*} &A_{t+1}+\frac{c}{\eta_{t}}\Delta_{t+1}+\frac{c'}{\eta_t}\delta_{t+1}\leq A_t+\frac{c}{\eta_{t-1}}\Delta_{t}+\frac{c'}{\eta_{t-1}}\delta_{t}-\eta_t B_t-\frac{1}{2}\eta_t\Gamma_t\\ &\quad +\frac{c\beta^2_{t+1}}{\eta_t}\sigma^2+\frac{c\gamma^2_{t+1}}{\eta_t}\sigma'^2+\frac{c'\gamma^2_{t+1}}{\eta_t}\sigma''^2. \end{align*}\]
Defining \(\Upsilon_{t+1}=A_{t+1}+\frac{c}{\eta_t}\Delta_{t+1}+\frac{c'}{\eta_{t}}\delta_{t+1}\), we have
\[\begin{align*} \eta_t B_t+\frac{1}{2}\eta_t\Gamma_t\leq \Upsilon_{t}-\Upsilon_{t+1}+\frac{c\beta^2_{t+1}}{\eta_t}\sigma^2+\frac{c\gamma^2_{t+1}}{\eta_t}\sigma'^2+\frac{c'\gamma^2_{t+1}}{\eta_t}\sigma''^2. \end{align*}\]
Hence
\[\begin{align*} &\sum_{t=0}^{T-1}(\eta_t B_t+\frac{1}{2}\eta_t\Gamma_t)\leq \Upsilon_{0}-A_*+\sum_{t=0}^{T-1}\left(\frac{c\beta^2_{t+1}}{\eta_t}\sigma^2+\frac{c\gamma^2_{t+1}}{\eta_t}\sigma'^2+\frac{c'\gamma^2_{t+1}}{\eta_t}\sigma''^2\right). \end{align*}\]
Next, let us consider \(\eta_t=\eta,\beta_t=\beta,\gamma_t=\gamma\). Then we have
\[\begin{align*} &\sum_{t=0}^{T-1}\frac{1}{T}(B_t+\frac{1}{2}\Gamma_t)\leq \frac{\Upsilon_{0}-A_*}{\eta T}+\left(\frac{c\beta^2}{\eta^2}\sigma^2+\frac{c\gamma^2}{\eta^2}\sigma'^2+\frac{c'\gamma^2}{\eta^2}\sigma''^2\right). \end{align*}\]
In order to ensure the RHS is less than \(\epsilon^2\), it suffices to have
\[\begin{align*} &\beta=\frac{\epsilon\eta}{2\sqrt{c}\sigma},\quad \gamma=\min\left(\frac{\epsilon\eta}{2\sqrt{c}\sigma'},\frac{\epsilon\eta}{2\sqrt{c'}\sigma''}\right),\quad T=\frac{4C_\Upsilon}{\epsilon^2\eta}, \end{align*}\] so that each of the four terms on the RHS is at most \(\epsilon^2/4\).
To ensure (\(\ref{eqn:lrr}\)), it suffices to have
\[\begin{align*} \eta^2\leq c\beta,\quad C_1c\gamma\leq c'/2,\quad \eta^2\leq c'\gamma/2,\quad c=\frac{1}{4C_2},\quad c'=\frac{1}{4C_3}. \end{align*}\]
As a result, if we set
\[\begin{align*} \eta &=\min\left(\frac{1}{L},\frac{\epsilon\sqrt{c}}{2\sigma},\frac{\epsilon c'}{4\sqrt{c}\sigma'},\frac{\epsilon \sqrt{c'}}{4\sigma''},\frac{c'}{2\sqrt{cC_1}}\right)\\ &=\min\left(\frac{1}{L},\frac{\epsilon}{4\sqrt{C_2}\sigma},\frac{\epsilon \sqrt{C_2}}{8C_3\sigma'},\frac{\epsilon}{8\sqrt{C_3}\sigma''},\frac{\sqrt{C_2}}{4C_3\sqrt{C_1}}\right)\\ \beta&=\frac{\epsilon\eta\sqrt{C_2}}{\sigma},\quad \gamma=\min\left(\frac{\epsilon\eta\sqrt{C_2}}{\sigma'},\frac{\epsilon\eta\sqrt{C_3}}{\sigma''},\frac{C_2}{2C_3C_1}\right), \end{align*}\]
we have
\[\begin{equation*} \sum_{t=0}^{T-1}\frac{1}{T}(B_t+\frac{1}{2}\Gamma_t)\leq \epsilon^2. \end{equation*}\]
Plugging the value of \(\eta\) into the requirement on \(T\) yields the stated order of \(T\).
■
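Since Lemma 4.15 is purely a statement about scalar sequences, its conclusion can be checked numerically. The sketch below (constants chosen arbitrarily, \(A_*=0\), and \(B_t,\Gamma_t\) held at fixed illustrative values) evolves \((*),(\sharp),(\diamond)\) with equality under the parameter choices from the proof and verifies the telescoped bound (\(\ref{eqn:3seq-1}\)).

```python
import math

# illustrative constants; A_* = 0
C1, C2, C3, L = 1.0, 4.0, 4.0, 1.0
sigma = sigma_p = sigma_pp = 1.0
c, cp = 1.0 / (4 * C2), 1.0 / (4 * C3)
eps = 0.5
eta = min(1.0 / L,
          eps * math.sqrt(c) / (2 * sigma),
          eps * cp / (4 * math.sqrt(c) * sigma_p),
          eps * math.sqrt(cp) / (4 * sigma_pp),
          cp / (2 * math.sqrt(c * C1)))
beta = eps * eta / (2 * math.sqrt(c) * sigma)
gamma = min(eps * eta / (2 * math.sqrt(c) * sigma_p),
            eps * eta / (2 * math.sqrt(cp) * sigma_pp),
            cp / (2 * C1 * c))
T = 200
A, Delta, delta = 5.0, 1.0, 1.0      # A_0, Delta_0, delta_0 >= 0
B, Gam = 0.3, 0.2                    # fixed non-negative B_t, Gamma_t
lhs = 0.0
rhs = A + (c / eta) * Delta + (cp / eta) * delta   # Upsilon_0 - A_*
for _ in range(T):
    lhs += eta * B + 0.5 * eta * Gam
    rhs += (c * beta**2 * sigma**2 + c * gamma**2 * sigma_p**2
            + cp * gamma**2 * sigma_pp**2) / eta
    # evolve (*), (sharp), (diamond) with equality
    A = A + eta * Delta + eta * delta - eta * B - eta * Gam
    Delta = ((1 - beta) * Delta + C1 * gamma**2 * delta
             + C2 * eta**2 * Gam + beta**2 * sigma**2 + gamma**2 * sigma_p**2)
    delta = (1 - gamma) * delta + C3 * eta**2 * Gam + gamma**2 * sigma_pp**2
print(lhs <= rhs and A >= 0)   # True: the bound holds along the run
```

With these constants one can also check directly that the three stepsize conditions (\(\ref{eqn:lrr}\)) hold, so the run is a genuine instance of the lemma rather than a vacuous one.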
Theorem 4.4
Suppose that Assumption 4.3, Assumption 4.6, and Assumption 4.7 hold. For SCST, in order to guarantee
\[\begin{align*} \mathbb E\left[\frac{1}{T}\sum_{t=0}^{T-1}\left\{\frac{1}{4}\|\mathbf{v}_{t}\|^2+\|\nabla F(\mathbf{w}_t)\|^2\right\}\right]\leq \epsilon^2, \end{align*}\]
we can set the parameters as \(\eta=\min\{O(\frac{1}{L_F}),O(\frac{\epsilon}{L_1\sigma}),O(\frac{\epsilon}{L_1^2\sigma_0})\}\), \(\beta=O(\frac{\epsilon\eta L_1}{\sigma})\), and \(\gamma=\min\{O(\frac{\epsilon\eta}{\sigma_0}),1\}\), and the iteration complexity is
\[ T=O\left(\max\left(\frac{C_\Upsilon L_1(\sigma_1+\sigma_2)}{\epsilon^3},\frac{C_\Upsilon \sigma_0L_1^2}{\epsilon^3},\frac{C_\Upsilon L_F}{\epsilon^2}\right)\right), \]
where \(C_\Upsilon=O(F(\mathbf{w}_0)-F_*+\frac{1}{L_1^2\eta}\|\nabla g(\mathbf{w}_0)\nabla f(\mathbf{u}_0)-\mathbf{v}_0\|_2^2+\frac{1}{L_1^2\eta}\|g(\mathbf{w}_0)-\mathbf{u}_0\|_2^2)\).
💡 Why it matters
We only explicitly maintain the dependence on \(L_1\), which will have implications when we handle non-smooth \(f\) in the next chapter.
The above theorem can help us establish an improved iteration complexity of \(O(1/\epsilon^3)\). First, we need to ensure \(C_\Upsilon=O(1)\), which can be satisfied by using a large initial batch size. In particular, we can set \(\mathbf{u}_0=\frac{1}{B_0}\sum_{i=1}^{B_0}g(\mathbf{w}_0;\zeta_i)\), \(\mathbf{v}_0=\frac{1}{B_0}\sum_{i=1}^{B_0}\nabla g(\mathbf{w}_0;\zeta'_i)\nabla f(\mathbf{u}_0;\xi_i)\), where \(\{\zeta_i,\zeta'_i,\xi_i\}_{i=1}^{B_0}\) are independent random variables. Thus, we have \(\mathbb E[\|\mathbf{u}_0-g(\mathbf{w}_0)\|_2^2]\leq O(\frac{1}{B_0})\) and \(\mathbb E[\|\mathbf{v}_0-\nabla g(\mathbf{w}_0)\nabla f(\mathbf{u}_0)\|_2^2]\leq O(\frac{1}{B_0})\). Hence, if we set \(B_0=O(\max(\frac{\sigma}{L_1\epsilon},\frac{\sigma_0}{\epsilon}))\), we have \(C_\Upsilon=O(1)\). This initial batch-size requirement can be removed by using decreasing parameters \(\eta_t=O(1/t^{1/3}),\beta_t=O(1/t^{2/3}),\gamma_t=O(1/t^{2/3})\).
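The \(O(1/B_0)\) bounds above are just the usual rate for averages of \(B_0\) i.i.d. samples; a quick Monte Carlo check on a toy scalar estimate of \(g(\mathbf{w}_0)\) (all constants illustrative):

```python
import random

random.seed(2)

def mse_of_mean(batch, trials=2000, sigma0=1.0, truth=0.0):
    # mean squared error of a batch-average estimate of `truth`
    total = 0.0
    for _ in range(trials):
        avg = sum(truth + random.gauss(0.0, sigma0) for _ in range(batch)) / batch
        total += (avg - truth) ** 2
    return total / trials

ratio = mse_of_mean(1) / mse_of_mean(64)
print(ratio)   # close to 64: the error of u_0 (and v_0) scales like O(1/B_0)
```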
Compared with the result for SCMA in Theorem 4.3, SCST admits a larger-order step size \(\eta\) and a lower-order iteration complexity.
Proof
Let us recall the three inequalities in Lemma 4.14, Lemma 4.13 and Lemma 4.12:
\[\begin{align*} (*)\;&F(\mathbf{w}_{t+1})\leq F(\mathbf{w}_t)+\eta_tG_2^2L_1^2\|\mathbf{u}_{t}-g(\mathbf{w}_t)\|_2^2+\eta_t\|\mathbf{v}_{t}-\mathcal{M}_t\|^2-\frac{\eta_t}{2}\|\nabla F(\mathbf{w}_t)\|_2^2\\ &\quad -\frac{1}{4\eta_t}\|\mathbf{w}_{t+1}-\mathbf{w}_t\|_2^2,\\ (\sharp)\;&\mathbb E\left[\|\mathbf{v}_{t}-\mathcal{M}_t\|^2\right]\leq \mathbb E[(1-\beta_t)\|\mathbf{v}_{t-1}-\mathcal{M}_{t-1}\|_2^2+16G_2^2L_1^2\gamma_{t}^2\|\mathbf{u}_{t-1}-g(\mathbf{w}_{t-1})\|_2^2] \\ &\quad +\mathbb E[(24G_2^4L_1^2+4G_1^2L_2^2)\|\mathbf{w}_t-\mathbf{w}_{t-1}\|_2^2+2\beta_t^2\sigma^2+8G_2^2L_1^2\gamma_{t}^2\sigma_0^2],\\ (\diamond)\;&\mathbb E_{\zeta_t}\left[\|\mathbf{u}_{t}-g(\mathbf{w}_t)\|^2 \right]\leq (1-\gamma_t)\|\mathbf{u}_{t-1}-g(\mathbf{w}_{t-1})\|_2^2\\ &\quad +2G_2^2\|\mathbf{w}_{t}-\mathbf{w}_{t-1}\|_2^2+2\gamma_t^2\sigma_0^2. \end{align*}\]
Define
\[\begin{align*} A_t &= 2(F(\mathbf{w}_t)-F_*),\quad B_t=\|\nabla F(\mathbf{w}_t)\|_2^2,\\ \Gamma_t&= \|\mathbf{v}_{t}\|_2^2/2,\quad \Delta_t=2\|\mathbf{v}_{t}-\mathcal M_t\|_2^2,\quad \delta_t=2L_1^2G_2^2\|\mathbf{u}_{t}-g(\mathbf{w}_t)\|_2^2. \end{align*}\]
They satisfy the three inequalities marked by \(*,\sharp,\diamond\) in Lemma 4.15 with \(C_1=O(1),C_2=O(G_2^4L_1^2+G_1^2L_2^2),C_3=O(L_1^2G_2^4),\sigma^2=O(G_2^2\sigma_1^2+G_1^2\sigma_2^2),\sigma'^2=O(L_1^2G_2^2\sigma_0^2),\sigma''^2=O(L_1^2G_2^2\sigma_0^2)\). Plugging these into Lemma 4.15, we can finish the proof.
■