
Section 4.1 Stochastic Compositional Optimization

We have seen several advanced machine learning frameworks in Chapter 2, including DRO, GDRO, EXM, and COCE. Unfortunately, existing stochastic gradient methods such as SGD are not directly applicable to these new problems. The reason will become clear shortly. To address this challenge, we need new optimization tools.

In this chapter, we will consider a family of stochastic optimization problems called stochastic compositional optimization (SCO), whose objective is given by

\[\begin{align}\label{eqn:sco} \min_{\mathbf{w}\in\mathbb{R}^d} F(\mathbf{w}) : = \mathbb{E}_{\xi}f(\mathbb{E}_{\zeta}g(\mathbf{w}; \zeta); \xi), \end{align}\]

where \(\xi\) and \(\zeta\) are random variables, \(g(\cdot; \zeta):\mathbb{R}^d\rightarrow\mathbb{R}^{d'}\) is the inner random function, and \(f(\cdot; \xi):\mathbb{R}^{d'}\rightarrow\mathbb{R}\) is the outer random function. We assume both \(f\) and \(g\) are differentiable. Let \(f(\cdot)=\mathbb{E}_{\xi}f(\cdot; \xi)\) and \(g(\cdot) = \mathbb{E}_{\zeta}g(\cdot; \zeta)\). Then the objective function \(F(\mathbf{w}) = f(g(\mathbf{w}))\) is a composition of two functions.

Examples

Example 4.1

The KL-regularized DRO is a special case of SCO by setting \(f(\cdot) = \tau \log(\cdot)\) and \(g(\mathbf{w}) = \frac{1}{n}\sum_{i=1}^n\exp(\ell(\mathbf{w}; \mathbf{x}_i, y_i)/\tau)\).
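As a small sketch of Example 4.1, assume the per-sample losses \(\ell(\mathbf{w}; \mathbf{x}_i, y_i)\) at the current \(\mathbf{w}\) have already been computed into an array. The helper names below are hypothetical. The sketch evaluates the objective as \(f(g(\mathbf{w}))\) and contrasts it with the average of per-sample compositions \(f(g(\mathbf{w}; i))\) with \(g(\mathbf{w}; i)=\exp(\ell_i/\tau)\). That average collapses to the mean loss, which shows that the inner expectation cannot be moved outside \(f\).

```python
import numpy as np

def inner_g(losses, tau):
    # g(w) = (1/n) * sum_i exp(l_i / tau), given per-sample losses l_i = l(w; x_i, y_i)
    return np.mean(np.exp(losses / tau))

def outer_f(u, tau):
    # deterministic outer function f(u) = tau * log(u)
    return tau * np.log(u)

def kl_dro_objective(losses, tau):
    # F(w) = f(g(w)); exp may overflow when tau is small relative to the losses
    return outer_f(inner_g(losses, tau), tau)

def kl_dro_objective_stable(losses, tau):
    # same value, shifted by the max loss (log-sum-exp trick) to avoid overflow
    m = np.max(losses)
    return m + tau * np.log(np.mean(np.exp((losses - m) / tau)))

def mean_of_compositions(losses, tau):
    # (1/n) * sum_i f(g(w; i)) with g(w; i) = exp(l_i / tau); this reduces to the mean loss
    return np.mean(outer_f(np.exp(losses / tau), tau))

losses = np.array([0.2, 0.5, 1.0, 3.0])
F = kl_dro_objective(losses, tau=1.0)
naive = mean_of_compositions(losses, tau=1.0)
print(F, naive)
```

Because \(\log\) is concave, Jensen's inequality gives \(f(g(\mathbf{w}))\geq\frac{1}{n}\sum_i f(g(\mathbf{w}; i))\). Averaging per-sample compositions therefore underestimates the DRO objective. As \(\tau\) decreases, the objective moves toward the maximum loss.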

Example 4.2

The KL-constrained DRO is a special case of SCO with optimization variable \((\mathbf{w}, \tau)\), obtained by setting \(\bar g=(g_1, g_2)\), \(f(\bar g) = g_1\log(g_2) + g_1\rho\), \(g_1(\mathbf{w}, \tau)=\tau\), and \(g_2(\mathbf{w}, \tau) = \frac{1}{n}\sum_{i=1}^n\exp(\ell(\mathbf{w}; \mathbf{x}_i, y_i)/\tau)\).
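The sketch below again assumes precomputed per-sample losses at a fixed \(\mathbf{w}\), and its helper names are hypothetical. It evaluates Example 4.2 as \(f(\bar g)\) and differentiates with respect to \(\tau\) only, using the chain rule \(\partial F/\partial\tau = \nabla f(\bar g)^\top \partial \bar g/\partial \tau\) with \(\nabla f(\bar g) = (\log g_2 + \rho,\; g_1/g_2)\). A central finite difference serves as a sanity check.

```python
import numpy as np

def inner_gbar(losses, tau):
    # g_bar(w, tau) = (g1, g2) = (tau, (1/n) * sum_i exp(l_i / tau)) at a fixed w
    return np.array([tau, np.mean(np.exp(losses / tau))])

def outer_f(gbar, rho):
    # f(g_bar) = g1 * log(g2) + g1 * rho
    g1, g2 = gbar
    return g1 * np.log(g2) + g1 * rho

def outer_f_grad(gbar, rho):
    # grad f(g_bar) = (log(g2) + rho, g1 / g2)
    g1, g2 = gbar
    return np.array([np.log(g2) + rho, g1 / g2])

def inner_gbar_dtau(losses, tau):
    # d g1 / d tau = 1,  d g2 / d tau = (1/n) * sum_i exp(l_i / tau) * (-l_i / tau^2)
    return np.array([1.0, np.mean(np.exp(losses / tau) * (-losses / tau**2))])

def F_of_tau(losses, tau, rho):
    return outer_f(inner_gbar(losses, tau), rho)

def dF_dtau(losses, tau, rho):
    # chain rule: dF/dtau = grad f(g_bar)^T (d g_bar / d tau)
    return outer_f_grad(inner_gbar(losses, tau), rho) @ inner_gbar_dtau(losses, tau)

losses = np.array([0.2, 0.5, 1.0, 3.0])
tau, rho, h = 1.5, 0.1, 1e-5
fd = (F_of_tau(losses, tau + h, rho) - F_of_tau(losses, tau - h, rho)) / (2 * h)
print(dF_dtau(losses, tau, rho), fd)
```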

Example 4.3

The compositional objective for AUC maximization contains a term of the form \(f(g(\mathbf{w}))\), where the inner function \(g(\mathbf{w})\) is stochastic and the outer function \(f\) is deterministic.

4.1.1 Optimization Challenge

The main difficulty in solving SCO lies in estimating the gradient \(\nabla F(\mathbf{w}) = \nabla g(\mathbf{w})\nabla f(g(\mathbf{w}))\), where \(\nabla g(\mathbf{w})\in\mathbb{R}^{d\times d'}\) denotes the transpose of the Jacobian matrix of \(g\) at \(\mathbf{w}\) and \(\nabla f(g)\in\mathbb{R}^{d'}\) is the gradient of \(f\) at \(g\).
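The shapes in this chain rule can be checked numerically. The following sketch uses hypothetical deterministic choices, \(g(\mathbf{w})=\tanh(A\mathbf{w})\in\mathbb{R}^{d'}\) and \(f(\mathbf{u})=\log(1+\|\mathbf{u}\|_2^2)\), picked only for illustration. It builds \(\nabla g(\mathbf{w})\) as the \(d\times d'\) transposed Jacobian and compares \(\nabla g(\mathbf{w})\nabla f(g(\mathbf{w}))\) with central finite differences of \(F\).

```python
import numpy as np

def g(w, A):
    # inner map g: R^d -> R^{d'}, here g(w) = tanh(A w) with A of shape (d', d)
    return np.tanh(A @ w)

def grad_g(w, A):
    # transpose of the Jacobian, shape (d, d'); column k is the gradient of g_k
    s = 1.0 - np.tanh(A @ w) ** 2  # derivative of tanh at each coordinate, shape (d',)
    return A.T * s  # scales column k of A^T by s[k]

def f(u):
    # outer function f: R^{d'} -> R, f(u) = log(1 + ||u||^2)
    return np.log1p(u @ u)

def grad_f(u):
    return 2.0 * u / (1.0 + u @ u)

def grad_F(w, A):
    # chain rule: grad F(w) = grad g(w) @ grad f(g(w)), shape (d,)
    return grad_g(w, A) @ grad_f(g(w, A))

def finite_difference(func, w, h=1e-6):
    # central differences, one coordinate at a time
    e = np.eye(w.size)
    return np.array([(func(w + h * e[j]) - func(w - h * e[j])) / (2 * h) for j in range(w.size)])

rng = np.random.default_rng(0)
d, d_prime = 3, 2
A = rng.normal(size=(d_prime, d))
w = rng.normal(size=d)
print(grad_F(w, A), finite_difference(lambda v: f(g(v, A)), w))
```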

A natural way to estimate the gradient is to plug in stochastic samples, i.e., \(G(\mathbf{w}; \xi, \zeta, \zeta') = \nabla g(\mathbf{w}; \zeta)\nabla f(g(\mathbf{w}; \zeta'); \xi)\), where \(\xi, \zeta, \zeta'\) are independent random samples. Mini-batches of random samples can also be used to compute this estimator. However, \(G(\mathbf{w}; \xi, \zeta, \zeta')\) is in general a biased estimator when \(f\) is nonlinear (more precisely, when \(\nabla f\) is not affine), i.e., \(\mathbb{E}_{\xi,\zeta, \zeta'}G(\mathbf{w}; \xi, \zeta, \zeta')\neq \nabla F(\mathbf{w})\), because \(\mathbb{E}_{\zeta'}[\nabla f(g(\mathbf{w}; \zeta'))]\neq \nabla f(\mathbb{E}_{\zeta'}[g(\mathbf{w}; \zeta')])\) in general. This breaks the unbiasedness assumption on stochastic gradients made in the convergence analysis of Chapter 3. Consequently, SGD with this estimator may fail to converge unless a large batch size is used to estimate \(g(\mathbf{w})\), as discussed below.
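To see the bias concretely, here is a toy sketch in which every modeling choice is an assumption made for illustration. Take \(d=d'=1\) and \(g(w; \zeta) = w + \zeta\) with \(\zeta\sim\mathcal{N}(0, s^2)\), so \(g(w)=w\) and \(\nabla g(w;\zeta)=1\). Let the outer function be the deterministic \(f(u)=e^u\). Then \(\nabla F(w)=e^w\), while the plug-in estimator built from a mini-batch of \(B\) samples \(\zeta'\) has mean \(\exp(w + s^2/(2B))\). The bias vanishes only as \(B\) grows.

```python
import numpy as np

def plug_in_estimates(w, s, B, n_trials, rng):
    # The mean of B i.i.d. N(0, s^2) draws of zeta' is N(0, s^2 / B), so sample it directly.
    u = w + rng.normal(0.0, s / np.sqrt(B), size=n_trials)  # mini-batch estimate of g(w) = w
    return 1.0 * np.exp(u)  # grad g(w; zeta) = 1 times grad f(u) = exp(u)

def plug_in_mean(w, s, B):
    # closed-form expectation of the plug-in estimator (mean of a log-normal variable)
    return np.exp(w + s**2 / (2 * B))

def true_gradient(w):
    return np.exp(w)  # F(w) = f(g(w)) = exp(w)

rng = np.random.default_rng(0)
w, s = 0.0, 1.0
for B in (1, 10, 100):
    mc = plug_in_estimates(w, s, B, 200_000, rng).mean()
    print(B, mc, plug_in_mean(w, s, B), true_gradient(w))
```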

4.1.2 A Straightforward Approach with Large Batch Sizes

Let us consider a straightforward approach that uses large batch sizes for estimating the gradient. In particular, we update the model parameter by the following: \[\begin{align} &\mathbf{u}_t = \frac{1}{B}\sum_{j=1}^B g(\mathbf{w}_t; \zeta_{j,t}),\quad \mathbf{v}_{t} = \frac{1}{B}\sum_{i=1}^B \nabla g(\mathbf{w}_t; \zeta'_{i,t})\nabla f\left(\mathbf{u}_t; \xi_{i,t}\right)\\ &\mathbf{w}_{t+1} =\mathbf{w}_t - \eta_t \mathbf{v}_{t}, \end{align}\] where \(\zeta_{1,t},\ldots, \zeta_{B,t}, \zeta'_{1,t},\ldots, \zeta'_{B,t},\xi_{1,t},\ldots, \xi_{B,t}\) are independent samples. The idea of this approach is that if the batch size \(B\) is sufficiently large, then \(\mathbf{u}_t\) is an accurate estimator of \(g(\mathbf{w}_t)\) and \(\mathbf{v}_t\) is an accurate estimator of \(\nabla F(\mathbf{w}_t)\).
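A minimal sketch of this update for the KL-regularized DRO of Example 4.1 follows. The loss, data, and constants are assumptions for illustration: a linear model with logistic loss, synthetic Gaussian data, \(\tau=1\), \(B=256\), and \(\eta=0.1\). Each \(\zeta\) is taken to be a uniformly sampled data index. Because \(f(u)=\tau\log u\) is deterministic, there is no \(\xi\), and \(\nabla f(\mathbf{u}_t)=\tau/\mathbf{u}_t\). The exact gradient is included only as a reference.

```python
import numpy as np

def logistic_losses_and_grads(w, X, y):
    # Per-sample loss l_i(w) = log(1 + exp(-y_i * x_i^T w)) and its gradient (one row per sample).
    m = -y * (X @ w)
    losses = np.logaddexp(0.0, m)
    sig = 0.5 * (1.0 + np.tanh(0.5 * m))  # numerically stable sigmoid(m)
    grads = (sig * -y)[:, None] * X
    return losses, grads

def large_batch_estimate(w, X, y, tau, B, rng):
    # Returns v_t for the KL-regularized DRO objective; f is deterministic, so there is no xi.
    n = X.shape[0]
    idx_u = rng.integers(0, n, size=B)  # zeta_{1,t}, ..., zeta_{B,t}: uniform indices
    idx_v = rng.integers(0, n, size=B)  # zeta'_{1,t}, ..., zeta'_{B,t}: drawn independently
    losses_u, _ = logistic_losses_and_grads(w, X[idx_u], y[idx_u])
    u = np.mean(np.exp(losses_u / tau))  # u_t, an estimate of g(w_t)
    losses_v, grads_v = logistic_losses_and_grads(w, X[idx_v], y[idx_v])
    # Row i is grad g(w_t; zeta'_i) = exp(l_i / tau) * grad l_i / tau (here d' = 1).
    grad_g = (np.exp(losses_v / tau) / tau)[:, None] * grads_v
    return grad_g.mean(axis=0) * (tau / u)  # multiply by grad f(u_t) = tau / u_t

def large_batch_step(w, X, y, tau, B, eta, rng):
    # w_{t+1} = w_t - eta * v_t
    return w - eta * large_batch_estimate(w, X, y, tau, B, rng)

def full_gradient(w, X, y, tau):
    # Exact grad F(w) = sum_i p_i * grad l_i(w) with p = softmax(l / tau), for reference.
    losses, grads = logistic_losses_and_grads(w, X, y)
    p = np.exp((losses - losses.max()) / tau)
    p /= p.sum()
    return p @ grads

rng = np.random.default_rng(0)
n, d = 500, 5
X = rng.normal(size=(n, d))
y = np.sign(X @ rng.normal(size=d) + 0.5 * rng.normal(size=n))
w = np.zeros(d)
for t in range(100):
    w = large_batch_step(w, X, y, tau=1.0, B=256, eta=0.1, rng=rng)
print(np.linalg.norm(full_gradient(w, X, y, tau=1.0)))
```

Two independent index batches are drawn per iteration so that \(\mathbf{u}_t\) is independent of the samples used for \(\nabla g\), which matches the independence assumption in the update above.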

Next, we briefly discuss the complexity of this approach for finding an \(\epsilon\)-stationary solution. The key is to show that, with a sufficiently large \(B\), the error of the gradient estimator \(\mathbf{v}_t\) is small. We adopt Assumption 4.1, Assumption 4.2, Assumption 4.3, and Assumption 4.4, which will be introduced in the next section. Let \(\zeta_t=(\zeta_{1,t},\ldots, \zeta_{B,t})\), \(\zeta'_t=(\zeta'_{1,t},\ldots, \zeta'_{B,t})\), and \(\xi_t=(\xi_{1,t},\ldots, \xi_{B,t})\). Since \(\zeta'_t\) and \(\xi_t\) are independent of each other and of \(\mathbf{u}_t\), we have \(\mathbb{E}_{\zeta'_t, \xi_t}[\mathbf{v}_t] = \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t)\). We add and subtract \(\nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t)\) and expand the square. The cross term vanishes because, conditioned on \(\mathbf{w}_t\) and \(\mathbf{u}_t\), the first difference has zero mean while the second is fixed. It follows that

\[\begin{align*} &\mathbb{E}[\| \mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\|_2^2]= \mathbb{E}\bigg[\bigg\| \mathbf{v}_t - \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t)+ \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t) - \nabla F(\mathbf{w}_t)\bigg\|_2^2\bigg]\\ &=\mathbb{E}\bigg[\bigg\| \frac{1}{B}\sum_{i=1}^B \nabla g(\mathbf{w}_t; \zeta'_{i,t})\nabla f(\mathbf{u}_t; \xi_{i,t}) - \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t)\bigg\|_2^2\bigg]\\ & \quad\quad+ \mathbb{E}\bigg[\bigg\|\nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t) - \nabla F(\mathbf{w}_t)\bigg\|_2^2\bigg]. \end{align*}\]

Since the \(B\) summands are i.i.d. given \(\mathbf{u}_t\), Lemma 4.4 bounds the first term by

\[\begin{align*} &\mathbb{E}\bigg[\bigg\| \frac{1}{B}\sum_{i=1}^B \nabla g(\mathbf{w}_t; \zeta'_{i,t})\nabla f(\mathbf{u}_t; \xi_{i,t}) - \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t)\bigg\|^2_2\bigg]\\ &\leq \frac{1}{B}\mathbb{E}\bigg[\bigg\|\nabla g(\mathbf{w}_t; \zeta'_{1,t})\nabla f(\mathbf{u}_t; \xi_{1,t}) - \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t)\bigg\|^2_2\bigg]\leq \frac{G_2^2\sigma_1^2+G_1^2\sigma_2^2}{B}. \end{align*}\]

Following Assumption 4.1, Assumption 4.2, and Assumption 4.3, the second term can be bounded by

\[\begin{align*} &\mathbb{E}\bigg[\bigg\| \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t) - \nabla F(\mathbf{w}_t)\bigg\|_2^2\bigg]=\mathbb{E}\bigg[\bigg\| \nabla g(\mathbf{w}_t)\nabla f(\mathbf{u}_t) - \nabla g(\mathbf{w}_t)\nabla f(g(\mathbf{w}_t))\bigg\|_2^2\bigg]\\ &\leq \mathbb{E}[G_2^2L_1^2\|\mathbf{u}_t - g(\mathbf{w}_t)\|_2^2]\leq\frac{G_2^2L_1^2\sigma_0^2}{B}. \end{align*}\]

As a result, if \(B=O(\max(L_1^2\sigma_0^2/\epsilon^2, (\sigma_1^2+\sigma_2^2)/\epsilon^2))\), then \(\mathbb{E}[\|\mathbf{v}_{t} - \nabla F(\mathbf{w}_t)\|_2^2]\leq O\left(\frac{L_1^2\sigma_0^2}{B} + \frac{\sigma_1^2+\sigma_2^2}{B}\right)\leq O(\epsilon^2)\), where \(G_1\) and \(G_2\) are treated as constants. Following Lemma 4.9, setting \(\eta=O(1/L_F)\) and \(T=O(L_F/\epsilon^2)\) guarantees that

\[ \mathbb{E}\left[\frac{1}{T}\sum_{t=1}^T\|\nabla F(\mathbf{w}_t)\|_2^2\right]\leq O(\epsilon^2). \]

Overall, this yields a sample complexity of

\[\begin{align}\label{eqn:com-largebatch} BT=O\left(\max\bigg(\frac{L_FL_1^2\sigma_0^2}{\epsilon^4}, \frac{L_F(\sigma_1^2+\sigma_2^2)}{\epsilon^4}\bigg)\right). \end{align}\]
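The key step above is that the mean squared error of \(\mathbf{v}_t\) decays as \(O(1/B)\). The following toy check uses choices that are assumptions made only for illustration. Take \(d=d'=1\), \(g(w;\zeta)=w+\zeta\) with \(\zeta\sim\mathcal{N}(0,s^2)\), and \(f(u;\xi)=e^u+\xi u\) with \(\xi\sim\mathcal{N}(0,s_f^2)\). Then \(\nabla F(w)=e^w\) and \(\mathbf{v}_t = e^{u_t} + \frac{1}{B}\sum_i \xi_{i,t}\), and the MSE has the closed form \(e^{2w}(e^{2s^2/B} - 2e^{s^2/(2B)} + 1) + s_f^2/B \approx (e^{2w}s^2 + s_f^2)/B\). This matches the shape of the two bounds with \(G_2=1\), \(\sigma_0=s\), and a local smoothness constant \(L_1\approx e^w\). Since \(\exp\) is only locally smooth, the match is heuristic rather than an instance of the assumptions.

```python
import numpy as np

def mse_closed_form(w, s, s_f, B):
    a = s**2 / B  # variance of u_t - w
    # E[(exp(u_t) - exp(w))^2] for u_t = w + z, z ~ N(0, a)
    inner_err = np.exp(2 * w) * (np.exp(2 * a) - 2 * np.exp(a / 2) + 1)
    # the cross term vanishes: the xi-average has mean 0 and is independent of u_t
    return inner_err + s_f**2 / B

def mse_monte_carlo(w, s, s_f, B, n_trials, rng):
    u = w + rng.normal(0.0, s / np.sqrt(B), size=n_trials)     # u_t
    xi_bar = rng.normal(0.0, s_f / np.sqrt(B), size=n_trials)   # (1/B) * sum_i xi_{i,t}
    v = np.exp(u) + xi_bar                                      # v_t, since grad g = 1
    return np.mean((v - np.exp(w)) ** 2)

rng = np.random.default_rng(0)
w, s, s_f = 0.0, 0.5, 0.5
for B in (1, 10, 100):
    exact = mse_closed_form(w, s, s_f, B)
    print(B, exact, mse_monte_carlo(w, s, s_f, B, 200_000, rng), B * exact)
```

The last printed column, \(B\) times the MSE, should approach \(e^{2w}s^2+s_f^2\) as \(B\) grows.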

This approach has two drawbacks: (i) the required batch size depends on the accuracy level \(\epsilon\), which is difficult to determine in practice; and (ii) supporting such large batch sizes demands substantial computational resources (e.g., a large number of GPUs). In the following sections, we will introduce approaches that do not require large batch sizes.
