Paper 2412.11660

Non-Convex Optimization in Federated Learning via Variance Reduction and Adaptive Learning

Authors: Dipanwita Thakur, Antonella Guzzo, Giancarlo Fortino, Sajal K. Das

Published: 2024-12-16

Abstract:

This paper proposes a novel federated algorithm that leverages momentum-based variance reduction with adaptive learning to address non-convex settings across heterogeneous data. We intend to minimize communication and computation overhead, thereby fostering a sustainable federated learning system. We aim to overcome challenges related to gradient variance, which hinders the model's efficiency, and the slow convergence resulting from learning rate adjustments with heterogeneous data. The experimental results on the image classification tasks with heterogeneous data reveal the effectiveness of our suggested algorithms in non-convex settings with an improved communication complexity of $\mathcal{O}(\epsilon^{-1})$ to converge to an $\epsilon$-stationary point - compared to the existing communication complexity $\mathcal{O}(\epsilon^{-2})$ of most prior works. The proposed federated version maintains the trade-off between the convergence rate, number of communication rounds, and test accuracy while mitigating the client drift in heterogeneous settings. The experimental results demonstrate the efficiency of our algorithms in image classification tasks (MNIST, CIFAR-10) with heterogeneous data.

Paper Content:
Dipanwita Thakur (1), Antonella Guzzo (1), Giancarlo Fortino (1), Sajal K. Das (2)
(1) University of Calabria, Italy
(2) Missouri University of Science & Technology, USA
dipanwita.thakur@unical.it, antonella.guzzo@unical.it, giancarlo.fortino@unical.it, sdas@mst.edu

Introduction

Federated learning (FL) is a distributed machine learning technique that allows multiple devices to train a model collaboratively without sharing their data with a central server (McMahan et al. 2017, Initial version posted on arXiv in February 2016). A traditional federated learning architecture consists of a global server and several local devices. The global server, also known as the federated server, generates a global model.
The global model is sent to the local clients, which process their locally generated data. The parameters of the local models are sent back to the global server, which aggregates them and returns the updated global model to the local devices to improve the performance of the local models on locally generated data. Usually, the devices are heterogeneous.

However, traditional federated learning has several issues. First, a traditional FL framework needs many global rounds to converge on non-independent and identically distributed (non-IID) client datasets, with a high communication cost per round, which increases the communication complexity. Second, the number of clients participating in each iteration also affects the performance of traditional FL. Not every client can participate in each round, because tuning the learning rate slows convergence. Tuning the learning rate also increases the computational complexity. Fewer participating clients means more global rounds, because with non-IID data each local device processes its own personalized data. Third, the training data are dispersed widely among many devices, and the connection between each device and the central server is complicated and slow. This slow communication inspired FL algorithms with high communication efficiency. Fourth, non-convex optimization is among the most vital techniques in contemporary machine learning (e.g., neural network training), and convergence in non-convex settings remains challenging in FL with non-IID data.

Copyright ©2025, Accepted the FLUID workshop@Association for the Advancement of Artificial Intelligence (www.aaai.org). All rights reserved.

Cross-silo and cross-device are two different settings of FL. The cross-silo environment is associated with a limited number of dependable clients, usually businesses like banks or medical facilities.
On the other hand, the cross-device federated learning environment can involve a very large number of clients, for instance all 6.84 billion active smartphones. Because of this, training in that scenario might never go through all of the clients' data. The cross-device environment is further characterized by resource-poor clients communicating over a very unstable network. The combination of these elements creates special difficulties that do not exist in the cross-silo setting. In this paper, we consider the more difficult cross-device scenario. Notably, recent developments in FL optimization, including FedDyn (Acar et al. 2021), SCAFFOLD (Karimireddy et al. 2020b), and the method of (Wu et al. 2023), are not applicable here because they were created for cross-silo environments.

The computation involved in a federated learning design, regardless of the model used for training, is directly related to energy consumption. Hence, we need an optimized FL algorithm that gives convergence guarantees. Due to the issues mentioned above, designing an optimized FL algorithm that converges is challenging. Several FL algorithms proposed so far focus mainly on non-IID data and communication efficiency. To derive convergence guarantees for an FL algorithm, researchers have assumed either IID data or the participation of all the devices¹. However, these assumptions are not feasible in real scenarios. In the real world, it is necessary to handle heterogeneous (non-IID) data to mitigate the "client-drift" problem. Several optimization algorithms have been proposed in the FL literature to mitigate client drift. However, very few consider both the non-convex and non-IID settings and provide the corresponding convergence analysis.

arXiv:2412.11660v1 [cs.LG] 16 Dec 2024

This paper aims to decrease the number of communication rounds, which ultimately reduces energy consumption. According to (McMahan et al.
2017, Initial version posted on arXiv in February 2016), FL settings are more prone to communication overhead than computation overhead. The objective is therefore to spend more computation in order to minimize the number of communication rounds needed to train a model. This can be achieved either by increasing parallelism or by increasing the computation on each device. Participation of more clients with higher computation can reduce the number of communication rounds (McMahan et al. 2023). Most FL algorithms use a fixed learning rate for all clients in each iteration. However, not all parameters benefit from this uniform learning rate, and the computation spent tuning it can slow convergence. While certain parameters may require smaller adjustments to avoid overshooting the optimum value, others may require more frequent updates to speed up convergence. Algorithms with adaptive learning rates were created to tackle this. These methods traverse the optimization landscape more quickly by modifying the learning rate of each parameter according to its past gradients.

Even though several FL techniques have been put forth, their undesirable convergence behavior prevents them from significantly lowering communication costs by avoiding frequent transmission between the central server and local worker nodes. Several factors contribute, including (1) client drift (Karimireddy et al. 2020b), in which local client models approach local rather than global optima; (2) lack of adaptivity in stochastic gradient descent (SGD) based updates (Reddi et al. 2021); and (3) the large number of training iterations needed to converge as model and training dataset sizes keep growing. Even with these new developments, most research is devoted to fixing client drift (Karimireddy et al. 2020b; Khanduri et al. 2021; Xu et al. 2022). The FL systems currently in place therefore do not resolve all of these problems.
Several optimization techniques are available for gradient descent to reduce the loss function and help the algorithm converge. Momentum-based updates and adaptive learning rates are two popular ones. The deep learning models normally used in FL frameworks lead to non-convex objectives. Momentum-based updates are used to escape local minima and accelerate convergence, and adaptive learning rates are used to improve the convergence speed and accuracy of the algorithm. Traditional optimizers, such as Adagrad, RMSProp, and Adam, are used in adaptive federated learning. However, designing adaptive FL systems is challenging because the local devices move in diverse directions and the global server cannot be updated often in the FL scenario. A flawed design of the adaptive FL method may cause convergence problems (Chen, Li, and Li 2020). FedAdagrad, FedYogi, and FedAdam are the federated versions of adaptive optimizers first proposed in (Reddi et al. 2021). However, their analysis is only valid when $\beta_1 = 0$, so it cannot make use of momentum. To overcome this issue, MimeAdam was proposed in (Karimireddy et al. 2020a), which applies server statistics locally. However, MimeAdam needs to calculate the entire local gradient, which may not be feasible in practice.

¹ Throughout the article, we refer to network things like nodes, clients, sensors, or organizations as "devices."

FedAMS is a more recent proposal (Wang, Lin, and Chen 2023) that offers a complete proof but does not improve the rate of convergence. FedAdagrad, FedYogi, FedAdam, and FedAMS have sample convergence rates of $\mathcal{O}(\epsilon^{-4})$ overall, which is no better than FedAvg. They also need an additional global learning rate to be tuned simultaneously. In (Rostami and Kia 2023), the authors used a stochastic variance-reduced technique called Stochastic Variance Reduced Gradient (SVRG) for local updates.
This method calculates the full-batch gradient of each participating device at each iteration, which ultimately increases the computational complexity. In (Wu et al. 2023), the authors proposed an effective adaptive FL algorithm (FAFED) for cross-silo settings using a momentum-based variance-reduction technique. FAFED shows fast convergence with the best-known sample complexity $\mathcal{O}(\epsilon^{-3})$ and communication complexity $\mathcal{O}(\epsilon^{-2})$. However, the non-convex setting required for deep neural network (DNN) algorithms is not addressed. Moreover, DNNs require more GPU memory when analyzing images. Hence, large batch sizes or the calculation of gradient checkpoints increase the computational complexity and slow convergence.

Based on the issues mentioned above, we propose a new FL algorithm that uses an adaptive learning rate with variance reduction of momentum SGD in non-convex settings. Specifically, we suggest a general energy-efficient, optimized FL framework built on reducing communication rounds, in which (1) the devices train over several epochs using an adaptive learning rate optimizer to minimize the loss on their local data and (2) the server updates its global model by applying a momentum-based gradient optimizer to the average of the clients' model updates. The proposed method combines momentum-based variance reduction updates and adaptive learning rate optimization without increasing client storage or communication costs for local and global updates.

Contributions. Our major contributions are as follows.
• We design a federated learning algorithm where momentum-based variance reduction is used for global updates in addition to the momentum-based local updates with adaptive learning.
The design of the proposed FL algorithm addresses two main issues in order to speed up the convergence of FL, and a third issue to mitigate the high computation in the non-IID setting for non-convex problems. The first two are the high variances associated with (i) the noise from local client-level stochastic gradients, and (ii) the global server aggregation step due to varying client characteristics when multiple local updates are present. The third, the high computation, is associated with tuning the learning rate. We tackle these issues separately by reducing the variance in the global server update and the local client updates using both global and local momentum with adaptive learning.
• We give a convergence analysis in generic non-convex settings and design new federated optimization methods compatible with multiple devices using the proposed FL algorithm. The analysis shows that the proposed FL algorithm converges to an $\epsilon$-stationary point for smooth non-convex problems in a non-IID setting, with an improved communication complexity of $\mathcal{O}(\epsilon^{-1})$ compared with state-of-the-art approaches.
• Experimental results with non-IID data in non-convex settings show the effectiveness of our suggested algorithm, which encourages communication efficiency and subtly reduces client drift under heterogeneous data distribution.

Related Works

While several FL algorithms have been proposed, most of them are related to communication efficiency and data privacy. Our work focuses on communication-efficient and early convergent FL algorithms. Federated Averaging (FedAvg) (McMahan et al. 2017, Initial version posted on arXiv in February 2016) is the first and most widely used FL algorithm. However, FedAvg lacks theoretical guarantees with non-IID data in the convex optimization setting.
FedAvg does not completely address the fundamental issues related to heterogeneity, even though it has proven empirically successful in diverse environments. Under systems heterogeneity, FedAvg typically drops devices that are unable to compute $E$ epochs within a given time window, rather than allowing participating devices to perform varying amounts of local work based on their underlying system constraints. To address this issue, several regularization approaches (Acar et al. 2021; Gao et al. 2022; Karargyris et al. 2023; Mendieta et al. 2022) are used to enforce local optimization. FedProx (Li et al. 2020) is another FL algorithm, one that does not assume IID data or the participation of all client devices. Specifically, FedProx exhibits noticeably more accurate and stable convergence behavior than FedAvg in extremely heterogeneous situations. FedProx relies on a gradient similarity assumption and adds a proximal term to each local objective to prove convergence. However, several assumptions are essential to implement FedProx in realistic federated contexts. The SCAFFOLD approach (Karimireddy et al. 2020b) achieves convergence rates independent of the degree of heterogeneity by using control variates to lessen client drift. Although the approach works well for cross-silo FL, it cannot be used for cross-device FL because it requires clients to keep their states across rounds. A similar approach, FedDyn (Acar et al. 2021), was proposed to reduce communication rounds using linear and quadratic penalty terms. Despite these efforts, FedAvg's performance is not entirely understood. However, the works above show that client drift is one of the major causes of performance degradation in FL. SCAFFOLD introduced control variates to adjust local updates in order to address the client drift issue.
Furthermore, MOON (Li, He, and Song 2021) suggested model contrastive learning, which uses the similarity between local and global model representations to correct the local training. Various studies attempt to enhance FedAvg in other ways. For example, in (Wang et al. 2021), the authors explained that objective inconsistency is the reason behind slow convergence and also provided a convergence analysis of the proposed and previous methods. On the other hand, FedMA (Wang et al. 2020) was proposed to improve the convergence rate by matching and averaging hidden elements while also reducing the communication cost.

Compared with existing works on FL (McMahan et al. 2017, Initial version posted on arXiv in February 2016), (Li et al. 2020), (Karimireddy et al. 2020b), this paper uses two different optimization methods to reduce the communication rounds and make the proposed FL algorithm energy-efficient. In contrast to state-of-the-art FL algorithms, the client devices optimize their local models using an adaptive learning rate optimizer to minimize the loss on their local data, and the server updates its global model by applying a momentum-based gradient optimizer to the average of the clients' model updates. None of the existing works studies variance reduction with adaptive learning in the federated environment. Table 1 compares our proposed method with state-of-the-art methods that consider non-IID settings and non-convex functions.

Algorithms 1 and 2 depict the federated version for client and server updates using momentum-based variance reduction with the adaptive learning rate.

Preliminaries

Focusing on non-convex settings, our objective is to solve the optimization problem

$$\min_{\omega \in \mathbb{R}^d} f(\omega), \quad \text{where } f(\omega) = \frac{1}{N} \sum_{i=1}^{N} F_i(\omega) \tag{1}$$

Here $F_i(\omega) = \mathbb{E}_{x \sim \mathcal{D}_i} L_i(x; \omega)$ gives the loss of the prediction on a data sample $(x_i, y_i)$ with model parameters $\omega \in \mathcal{Z}$, and $\mathcal{D}_i$ denotes the data distribution for the $i$-th client.
When the distributions $\mathcal{D}_i$ differ among the clients, this is known as the heterogeneous data setting. Each $F_i(\omega)$ is differentiable, possibly non-convex, and has an $L$-Lipschitz continuous gradient (i.e., is $L$-smooth) for some parameter $L > 0$. For each client $i$ and $\omega$, the true gradient $\nabla F_i(\omega)$ is assumed to have an unbiased stochastic gradient $g_i(\omega)$. The loss on an example is denoted by $f(x, \omega)$, and the model's training loss is denoted by $F$. We can derive sample functions $f(\cdot, x)$ such that $\mathbb{E}[f(\cdot, x)] = F(\cdot)$, where $x$ denotes a batch of data samples. Problem 1 encompasses a wide range of machine learning and image processing problems, such as neural network training, non-convex loss ERM (empirical risk minimization), and many more. Following the standard practice in non-convex optimization, our attention is directed towards algorithms that can effectively identify an approximate stationary point $x$ that satisfies $\|\nabla f(x)\|^2 \le \epsilon$.

| Reference | Client Participation | Compression? | BCD? | ALR? | T |
|---|---|---|---|---|---|
| (Koloskova et al. 2020) | Full | No | Yes | No | $\mathcal{O}\left(\frac{1}{N\epsilon^{2}}\right)$ |
| (Haddadpour et al. 2020) | Full | Yes | Yes | No | $\mathcal{O}\left(\frac{1}{N\epsilon^{2}}\right)$ |
| (Khanduri et al. 2021) | Full | No | Yes | No | $\mathcal{O}\left(\frac{1}{N\epsilon^{1.5}}\right)$ |
| (Karimireddy et al. 2020a) | Partial | No | Yes | No | $\mathcal{O}\left(\frac{1}{\sqrt{R}\epsilon^{1.5}}\right)$ |
| (Karimireddy et al. 2020b) | Partial | No | Yes | No | $\mathcal{O}\left(\frac{1}{R\epsilon^{1.5}}\right)$ |
| (Das et al. 2022) | Partial | Yes | No | No | $\mathcal{O}\left(\max\left(\sqrt{\frac{\alpha}{N}}, \frac{1}{\sqrt{R}}\right)\frac{1}{\epsilon^{1.5}}\right)$ ∗1 |
| Ours | Partial | No | No | Yes | $\mathcal{O}\left(\frac{1}{\epsilon}\right)$ |

Table 1: Parameter description in the SOTA and our proposed FL algorithms. Here, $N$ is the total number of clients and $P$ is the number of participating clients in each round, with either $N = P$ (Full) or $P < N$ (Partial). Compression indicates whether communication is compressed, BCD denotes the bounded client dissimilarity assumption, ALR indicates whether an adaptive learning rate is used, and $T$ denotes the number of gradient updates needed to achieve $\mathbb{E}[\|\nabla f(\omega)\|^2] \le \epsilon$ in smooth non-convex settings.

Furthermore, we also assume the following.

Assumption 1 (Smoothness).
The function $F_i$ is $L$-smooth if its gradient is $L$-Lipschitz, that is, $\|\nabla F_i(x) - \nabla F_i(y)\| \le L\|x - y\|$ for all $x, y \in \mathbb{R}^d$. We also have $f(y) \le f(x) + \langle \nabla f(x), y - x \rangle + \frac{L}{2}\|y - x\|^2$.

Assumption 2 (Bounded Variance). The function $F_i$ has $\sigma$-bounded (local) variance, i.e., $\mathbb{E}[\|[\nabla f_i(x, \omega)]_j - [\nabla F_i(x)]_j\|^2] \le \sigma^2$ for all $x \in \mathbb{R}^d$, $j \in [d]$ and $i \in [n]$. Moreover, we assume that the global variance is also bounded: $\frac{1}{n}\sum_{i=1}^{n} \|[\nabla F_i(x)]_j - [\nabla f(x)]_j\|^2 \le \sigma_{g,j}^2$ for all $x \in \mathbb{R}^d$, $j \in [d]$.

Assumption 3 (Bounded Gradients). The function $f_i(x, \omega)$ has $G$-bounded gradients, i.e., for any $i \in [n]$, $x \in \mathbb{R}^d$ and $\omega \in \mathcal{Z}$, $\|[\nabla f_i(x, \omega)]_j\| \le G$ for all $j \in [d]$.

Assumption 4 (Unbiased Gradients). The stochastic gradient of each component function $f_i(x; \omega)$ computed at each client node is unbiased for all $x_i \sim \mathcal{D}_i$, $i \in [N]$ and $x \in \mathbb{R}^d$.

Assumption 5 (Non-negativity). Each $f_i(\omega)$ is non-negative and hence $f_i^* = \min f_i(\omega) \ge 0$. In practice, most loss functions are non-negative. If a loss is negative, we add a constant to make it non-negative.

Definition I ($\epsilon$-stationary point). An $\epsilon$-stationary point $x$ satisfies $\|\nabla f(x)\| \le \epsilon$. Furthermore, a stochastic algorithm is considered to reach an $\epsilon$-stationary point in $t$ iterations if $\mathbb{E}[\|\nabla f(x)\|] \le \epsilon$, where the expectation is over the algorithm's randomness up to time $t$.

Our objective is to find a critical point of $F$, where $\nabla F(\omega) = 0$, i.e., to converge to an $\epsilon$-stationary point for smooth non-convex functions, $\mathbb{E}[\|\nabla f(x)\|^2] \le \epsilon$. We only assume access to stochastic gradients at arbitrary points. In this work, we consider stochastic gradient descent (SGD) as the standard algorithm. SGD generates a sequence of iterates $\omega_1, \omega_2, \dots, \omega_N$ using the recursion

$$\omega_{i+1} = \omega_i - \eta_i g_i \tag{2}$$

where $g_i = \nabla f(x_i, \omega_i)$ and $f(\cdot, x_1), \dots, f(\cdot, x_N)$ are either i.i.d. or non-i.i.d. samples from a distribution $\mathcal{D}$. The learning rates $\eta_1, \eta_2, \dots, \eta_N \in \mathbb{R}$ affect the performance if not tuned properly. The learning rate, i.e., the step size, plays a crucial role in the convergence of SGD.
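As a small illustration of the SGD recursion in Equation 2, the sketch below runs it on a one-dimensional toy objective. The objective, the noise model, and the decaying step-size schedule are all assumptions chosen for illustration; none of them comes from the paper.

```python
import math
import random


def loss(omega: float) -> float:
    """Toy smooth non-convex objective (hypothetical, not from the paper)."""
    return omega ** 2 + 3.0 * math.sin(omega) ** 2


def grad(omega: float) -> float:
    """Exact gradient of the toy objective."""
    return 2.0 * omega + 3.0 * math.sin(2.0 * omega)


def sgd(omega0, step_sizes, noise_std=0.1, seed=0):
    """Run omega_{i+1} = omega_i - eta_i * g_i with a noisy gradient g_i."""
    rng = random.Random(seed)
    omega = omega0
    iterates = [omega]
    for eta in step_sizes:
        # Zero-mean noise keeps the stochastic gradient unbiased.
        g = grad(omega) + rng.gauss(0.0, noise_std)
        omega = omega - eta * g
        iterates.append(omega)
    return iterates


# Hypothetical decaying schedule; the step sizes must be tuned by hand here.
steps = [0.05 / math.sqrt(i + 1) for i in range(200)]
trajectory = sgd(omega0=2.5, step_sizes=steps)
print(f"final iterate {trajectory[-1]:.3f}, |grad| = {abs(grad(trajectory[-1])):.3f}")
```

The gradient norm printed at the end is the quantity that the $\epsilon$-stationarity definition bounds. Changing the constant in `steps` shows how sensitive the plain recursion is to the step size.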
To address this sensitivity and make SGD robust to parameter selection, "adaptive" SGD methods are frequently employed, in which the step size is determined dynamically from stochastic gradient information of both the current and previous samples (Cutkosky and Orabona 2019; Khanduri et al. 2020). In this work, we present one such "adaptive" technique for distributed non-convex stochastic optimization that designs the step sizes based on the available stochastic gradient information. Proper selection of the learning rate $\eta_i$ in SGD ensures that a randomly chosen iterate $\omega_i$ satisfies $\mathbb{E}[\|\nabla F(\omega_i)\|] \le \mathcal{O}(1/N^{1/4})$.

Variance reduction

Momentum-based variance reduction performs well in prior works and reduces the number of hyperparameters that must be tuned, such as the batch size. It is also used to reduce the variance of the gradient in non-convex updates. In this work, the variance-reduced update is written as follows:

$$m_i = (1-\beta)\, m_{i-1} + \beta \nabla f(x_i, \omega_i) + (1-\beta)\big(\nabla f(x_i, \omega_i) - \nabla f(x_{i-1}, \omega_i)\big), \qquad \omega_{i+1} = \omega_i - \eta_i m_i \tag{3}$$

The additional term $(1-\beta)(\nabla f(x_i, \omega_i) - \nabla f(x_{i-1}, \omega_i))$ is added to the update with the adaptive learning rate. The concept is similar to conventional variance reduction, where two different gradients are used in each step.

Proposed Algorithm

To improve FL's convergence rate and communication overhead, the following problems must be resolved: (i) the high variance of the simple averaging used in the global server aggregation step when there are multiple local updates, which is made worse by client heterogeneity; (ii) the high variance linked to the noise of local client-level stochastic gradients; and (iii) the heterogeneity among the local functions. The key idea of the proposed FL algorithm is to apply momentum-based variance reduction for the global and local updates with the adaptive learning rate for clients.
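To make the momentum-based variance reduction of Equation 3 concrete, here is a minimal single-machine sketch. It follows the usual STORM reading, in which the correction term evaluates the same fresh sample at the current and the previous iterate; that reading of the indices is an assumption. The step size $k / (w + \sum_i g_i^2)^{1/3}$ is the STORM-style adaptive schedule, the toy objective is hypothetical, and `k`, `w`, and `beta` are illustrative values rather than the paper's settings.

```python
import math
import random


def loss(omega: float) -> float:
    """Toy smooth non-convex objective (hypothetical, not from the paper)."""
    return omega ** 2 + 3.0 * math.sin(omega) ** 2


def grad(omega: float, sample: float) -> float:
    """Stochastic gradient; `sample` acts as additive zero-mean noise."""
    return 2.0 * omega + 3.0 * math.sin(2.0 * omega) + sample


def storm(omega0, num_steps, beta=0.1, k=0.1, w=1.0, noise_std=0.1, seed=0):
    """Momentum-based variance reduction with an adaptive step size."""
    rng = random.Random(seed)
    omega_prev = omega0
    m = grad(omega0, rng.gauss(0.0, noise_std))  # first momentum: plain gradient
    history = [(omega0, m)]
    grad_sq_sum = m * m
    eta = k / (w + grad_sq_sum) ** (1.0 / 3.0)
    omega = omega_prev - eta * m
    for _ in range(num_steps):
        sample = rng.gauss(0.0, noise_std)  # one fresh sample per step
        g_now = grad(omega, sample)         # gradient at the current iterate
        g_old = grad(omega_prev, sample)    # same sample, previous iterate
        m = (1.0 - beta) * m + beta * g_now + (1.0 - beta) * (g_now - g_old)
        history.append((omega, m))
        grad_sq_sum += g_now * g_now
        eta = k / (w + grad_sq_sum) ** (1.0 / 3.0)  # step shrinks as gradients accumulate
        omega_prev, omega = omega, omega - eta * m
    return omega, history


final_omega, history = storm(omega0=2.5, num_steps=200)
print(f"final iterate {final_omega:.3f}, loss {loss(final_omega):.3f}")
```

With noise switched off (`noise_std=0.0`), the correction term cancels the stale part of the momentum exactly, so `m` coincides with the true gradient at every step. This is a quick way to sanity-check an implementation of the update.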
Hence, in the optimized federated non-convex heterogeneous setting, a total of $N$ clients jointly solve the following optimization problem:

$$\min_{\omega \in \mathbb{R}^d} \Big[ F(\omega) := \sum_{i=1}^{N} S_i F_i(\omega) \Big] \tag{4}$$

where $S_i = \frac{N_i}{N}$ is the relative sample size and $F_i(\omega) = \frac{1}{N_i} \sum_{x \in \mathcal{D}_i} f_i(x; \omega)$ is the $i$-th client's local objective function. In our case, the learning model defines $f_i$ as a non-convex loss function, and $x \in \mathcal{D}_i$ represents a data sample from the local data $\mathcal{D}_i$. After receiving the current global model $\omega^{(t,0)}$ at the $t$-th communication round, each client independently executes $\kappa_i$ iterations of the local solver to optimize its local objective. In our work, the local solver is momentum-based variance reduction with an adaptive learning rate.

Algorithm 1: Algorithm for Client Update
Input: $c$, $k$, $w$, Initial point: $\omega_1$
 1: $\eta_0 \leftarrow \frac{k}{w^{1/3}}$
 2: for all $t = 1, 2, \dots, T$ do
 3:   for all clients $i = 1, 2, \dots, N$ in parallel do
 4:     for all local epochs $j = 1, 2, \dots, E$ do
 5:       if $j = 1$ then
 6:         Set $m_i^{(t,j)} = \nabla f_i(\omega_i^{(t,j)})$, $\hat{m}_i^{(t-1,j)} = \nabla f_i(\hat{\omega}_i^{(t-1,j)})$
 7:       else
 8:         Each client $i$ randomly selects a batch $B_i^{(t,j)}$.
 9:         Compute the stochastic gradient of the non-convex loss function $f_i$ over $B_i^{(t,j)}$ at $\omega_i^{(t,j)}$, $\hat{\omega}_i^{(t-1,j)}$, $\omega_i^{(t,j-1)}$, and $\hat{\omega}_i^{(t-1,j-1)}$, i.e., $\tilde{\nabla} f_i(\omega_i^{(t,j)}; B_i^{(t,j)})$, $\tilde{\nabla} f_i(\omega_i^{(t-1,j)}; B_i^{(t,j)})$, $\tilde{\nabla} f_i(\omega_i^{(t,j-1)}; B_i^{(t,j)})$ and $\tilde{\nabla} f_i(\omega_i^{(t-1,j-1)}; B_i^{(t,j)})$.
10:         $\eta_t \leftarrow \frac{k}{(w_t + \tilde{\nabla} f_i(\omega_i^{(t,j)}))^{1/3}}$
11:         Update $m_i^{(t,j)} = \tilde{\nabla} f_i(\omega_i^{(t,j)}; B_i^{(t,j)}) + m_i^{(t,j-1)} - \tilde{\nabla} f_i(\omega_i^{(t,j-1)}; B_i^{(t,j)})$ and $m_i^{(t-1,j)} = \tilde{\nabla} f_i(\omega_i^{(t-1,j)}; B_i^{(t,j)}) + m_i^{(t-1,j-1)} - \tilde{\nabla} f_i(\omega_i^{(t-1,j-1)}; B_i^{(t,j)})$
12:       end if
13:       Update $\omega_i^{(t,j+1)} = \omega_i^{(t,j)} - \eta_t m_i^{(t,j)}$ and $\hat{\omega}_i^{(t-1,j+1)} = \hat{\omega}_i^{(t-1,j)} - \eta_t \hat{\omega}_i^{(t-1,j)}$.
14:     end for
15:   end for
16: end for
17: Send $(\omega_t - \omega_i^{t,E})$ and $((\omega_t - \omega_i^{t,E}) - (\omega_{t-1} - \hat{\omega}_i^{t-1,E}))$ to the server.

The number of local updates $\kappa_i$ for each client $i$ can vary in our theoretical framework. If a client runs $E$ epochs with batch size $B$, then $\kappa_i = \lfloor \frac{E N_i}{B} \rfloor$, where $N_i$ is the number of data samples of the $i$-th client.
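The relative sample sizes $S_i = N_i / N$ from Equation 4 and the local step count $\kappa_i = \lfloor E N_i / B \rfloor$ are simple arithmetic, sketched below. The client sizes, epoch count, batch size, and per-client losses are made-up numbers for illustration only.

```python
def local_steps(epochs: int, num_samples: int, batch_size: int) -> int:
    """kappa_i = floor(E * N_i / B): local iterations run by one client."""
    return (epochs * num_samples) // batch_size


def relative_sizes(samples_per_client):
    """S_i = N_i / N, where N is the total number of samples over all clients."""
    total = sum(samples_per_client)
    return [n / total for n in samples_per_client]


def weighted_objective(client_losses, samples_per_client):
    """F(omega) = sum_i S_i * F_i(omega), given each client's local loss value."""
    weights = relative_sizes(samples_per_client)
    return sum(s * f for s, f in zip(weights, client_losses))


# Hypothetical setup: three clients with unequal amounts of data.
sizes = [1200, 600, 200]
kappas = [local_steps(epochs=5, num_samples=n, batch_size=50) for n in sizes]
print(kappas)                 # [120, 60, 20]
print(relative_sizes(sizes))  # [0.6, 0.3, 0.1]
print(weighted_objective([0.5, 1.0, 2.0], sizes))
```

Clients holding more data take more local steps and also carry more weight in the global objective. This is one reason plain averaging of local updates can be biased under heterogeneity.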
Recall that for the simplest and most popular FL algorithm, FedAvg (McMahan et al. 2017, Initial version posted on arXiv in February 2016), the update rule can be written as follows:

$$\omega^{(t+1,0)} - \omega^{(t,0)} = \sum_{i=1}^{N} S_i \Delta_i^t = -\sum_{i=1}^{N} S_i \cdot \eta \sum_{j=0}^{\kappa_i - 1} g_i(\omega_i^{(t,j)}) \tag{5}$$

where $\omega_i^{(t,j)}$ represents the model of the $i$-th client in the $t$-th communication round after the $j$-th local update, and $\Delta_i^t$ is the change in the local parameters of the $i$-th client at round $t$. In our work, we modify the update rule so that local updates are performed using momentum-based variance reduction with an adaptive learning rate, and the global update is performed using momentum-based variance reduction. The client and server update rules are given in Algorithms 1 and 2.

Algorithm 2: Algorithm for Server Update
Require: Initial parameter $\omega_1$, Number of communication rounds $T$, Number of participating clients $R \le N$, Epochs $E$
Ensure: $\omega_0 = \omega_1$
 1: for all rounds $t = 1, 2, \dots$ do
 2:   Server broadcasts $\omega_t$, $\omega_{t-1}$ to a set $\mathcal{S}_t$ of $R$ clients
 3:   for all clients $i \in \mathcal{S}_t$ do
 4:     Set $\omega_i^{(t,0)} = \omega_t$ and $\hat{\omega}_i^{(t-1,0)} = \omega_{t-1}$. Execute Algorithm 1.
 5:   end for
 6:   if $t = 0$ then
 7:     Set $\hat{m}^{(t)} = \frac{1}{R} \sum_{i \in \mathcal{S}^{(t)}} (\omega^{(t)} - \omega_i^{(t,E)})$ {$\hat{m}$ is the global momentum}
 8:   else
 9:     Set $\hat{m}^{(t)} = \frac{\beta^{(t)}}{R} \sum_{i \in \mathcal{S}^{(t)}} (\omega^{(t)} - \omega_i^{(t,E)}) + (1 - \beta^{(t)})\, \hat{m}^{(t-1)} + \frac{1 - \beta^{(t)}}{R} \sum_{i \in \mathcal{S}^{(t)}} \big((\omega^{(t)} - \omega_i^{(t,E)}) - (\omega^{(t-1)} - \hat{\omega}_i^{(t-1,E)})\big)$
10:   end if
11:   Update $\omega^{(t+1)} = \omega_t - \hat{m}^{(t)}$
12: end for

Cutkosky and Orabona (Cutkosky and Orabona 2019) first proposed the variance reduction method in non-convex settings, with an algorithm named STORM. In this paper, the idea of STORM is used as an update rule in a slightly different way. Equation 3 represents the update rule in STORM. In Equation 3, $\beta \in [0, 1)$ is the momentum parameter per iteration and $x_i$ is the randomly selected sample in the $i$-th iteration. The stochastic gradient at $\omega_{i+1}$ is calculated on $x_i$. In Algorithm 1, lines 6, 11, and 13 give the client updates using momentum-based variance reduction with an adaptive learning rate after running $E$ epochs.
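The server step of Algorithm 2 (lines 6 to 11) can be sketched as follows, treating each client's model difference $\omega^{(t)} - \omega_i^{(t,E)}$ as its reported "gradient". The function name, the plain-list vector representation, and the `first_round` flag are illustrative assumptions; the communication and the client-side work of Algorithm 1 are not modelled here.

```python
def server_update(omega, momentum_prev, deltas, deltas_prev, beta, first_round=False):
    """One global step with momentum-based variance reduction.

    omega         -- current global model omega^(t), as a list of floats
    momentum_prev -- previous global momentum m_hat^(t-1)
    deltas        -- per-client omega^(t) - omega_i^(t,E)
    deltas_prev   -- per-client omega^(t-1) - omega_hat_i^(t-1,E)
    beta          -- momentum parameter beta^(t)
    """
    num_clients = len(deltas)
    dim = len(omega)
    mean_delta = [sum(d[j] for d in deltas) / num_clients for j in range(dim)]
    if first_round:
        # No history yet: the momentum is the plain average of client deltas.
        momentum = mean_delta
    else:
        # Average correction: difference between current and previous-round deltas.
        mean_corr = [
            sum(d[j] - p[j] for d, p in zip(deltas, deltas_prev)) / num_clients
            for j in range(dim)
        ]
        momentum = [
            beta * mean_delta[j]
            + (1.0 - beta) * momentum_prev[j]
            + (1.0 - beta) * mean_corr[j]
            for j in range(dim)
        ]
    new_omega = [omega[j] - momentum[j] for j in range(dim)]
    return new_omega, momentum


# Hypothetical round with two clients and a two-parameter model.
omega, m_hat = server_update(
    omega=[1.0, 2.0], momentum_prev=None,
    deltas=[[0.2, 0.4], [0.4, 0.0]], deltas_prev=None,
    beta=0.5, first_round=True,
)
print(omega, m_hat)
```

Setting `beta=1.0` removes both the momentum and the correction term, so the step reduces to subtracting the average client delta. This makes the relation to simple averaging easy to check.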
It is pertinent to mention that the concept is similar to Equation 3 with $\beta = 0$, where the full gradient is used to compute the initial gradient. Similarly, in Algorithm 2, global momentum is used for server aggregation. To overcome the pitfall of the server aggregation scheme used in conventional federated averaging such as FedAvg (McMahan et al. 2017, Initial version posted on arXiv in February 2016), in this paper we take the gradient to be $(\omega^{(t)} - \omega_i^{(t,E)})$.

Empirical Analysis

This section empirically studies the proposed method and compares it with different baseline methods to demonstrate its efficacy on federated learning problems and image classification tasks. Experiments are implemented using PyTorch, and we run all experiments on machines with a 3.50 GHz Intel Xeon CPU and an NVIDIA GeForce RTX 2080 Ti GPU.

Implementation Details

Dataset. Our image classification tasks use the MNIST and CIFAR-10 datasets with 50 clients in the network. The MNIST dataset comprises 60,000 training images and 10,000 testing images, categorized into 10 classes. Each image is a 28×28 array of grayscale pixels. The CIFAR-10 dataset comprises 60,000 32×32 color images in 10 categories, split into 50,000 training images and 10,000 testing images. Each client is equipped with an identical Convolutional Neural Network (CNN) model serving as the classifier, and the loss function is cross-entropy. We use a Dirichlet distribution with $\alpha = 0.5$ to generate the non-IID dataset.

Baselines. We compare our approach with SOTA methods: FedAvg (McMahan et al. 2017, Initial version posted on arXiv in February 2016), FedProx (Li et al. 2020), FedNova (Wang et al. 2021) and FedGLOMO (Das et al. 2022). FedAvg is one of the most frequently used federated optimization approaches.
FedProx can be seen as a re-parametrization and generalization of FedAvg. FedNova is a normalized averaging method that removes objective inconsistency while keeping fast error convergence. FedGLOMO applies local and global updates using momentum-based variance reduction with quantized messages to reduce communication complexity. The major difference between our proposed method and FedGLOMO is that we use no message quantization. We also use an adaptive learning rate strategy for local updates for early convergence.

Network Architecture. In our experiment, we use a neural network with two layers and the ReLU activation function. The size of the hidden layers is 600. We train the models using the categorical cross-entropy loss with $\ell_2$-regularization. PyTorch's weight decay value is set to 1e-4 to apply $\ell_2$-regularization.

Validation

Figures 1 and 2 show the cross-entropy training loss and accuracy on the CIFAR-10 dataset against the number of epochs on the federated clients. Similarly, Figures 4 and 3 show the training loss and accuracy on the CIFAR-10 dataset against the number of communication rounds. Moreover, Figures 5 and 6 show the test accuracy and loss in each communication round. The test loss and accuracy clearly depict the early convergence of our proposed method.

Figure 1: Train Loss vs Epochs with CIFAR-10
Figure 2: Train Accuracy vs Epochs with CIFAR-10
Figure 3: Train Accuracy vs Communications with CIFAR-10
Figure 4: Train Loss vs Communications with CIFAR-10
Figure 5: Test Accuracy with CIFAR-10
Figure 6: Test Loss with CIFAR-10

Similarly, Figures 7 and 8 show the cross-entropy training loss and accuracy on the MNIST dataset against the number of epochs on the federated clients. Similarly, Figures 10 and 9 show the training loss and accuracy on the MNIST dataset against the number of communication rounds.
Figures 11 and 12 show the test accuracy and loss on the MNIST dataset and likewise indicate early convergence. On both datasets, the experimental results show that, while our proposed FL algorithm is only marginally better than AdaGLOMO on test accuracy, it appears to be somewhat faster on both training loss and training accuracy in terms of the number of communication rounds. Moreover, the test accuracy is more stable than that of the SOTA methods, since it varies less under our proposed method.

Page 7:

| Adaptive Learning | Momentum | Variance Reduction (Local) | Variance Reduction (Global) | MNIST Accuracy (%) | MNIST Loss | CIFAR-10 Accuracy (%) | CIFAR-10 Loss | Sample complexity | Communication complexity |
|---|---|---|---|---|---|---|---|---|---|
| X | X | X | X | 75.24±0.45 | 0.67 | 72.45±0.34 | 0.86 | $\mathcal{O}(\epsilon^{-4})$ | $\mathcal{O}(\epsilon^{-4})$ |
| X | ✓ | X | X | 90.19±0.40 | 0.41 | 86.92±0.42 | 0.73 | $\mathcal{O}(\epsilon^{-4})$ | $\mathcal{O}(\epsilon^{-4})$ |
| ✓ | X | X | X | 90.78±0.45 | 0.20 | 87.05±0.45 | 0.46 | $\mathcal{O}(\epsilon^{-4})$ | $\mathcal{O}(\epsilon^{-4})$ |
| ✓ | ✓ | X | X | 91.24±0.45 | 0.0024 | 86.45±0.38 | 0.0078 | $\mathcal{O}(\epsilon^{-3})$ | $\mathcal{O}(\epsilon^{-3})$ |
| X | ✓ | ✓ | X | 89.24±0.45 | 0.0026 | 85.32± | 0.0032 | $\mathcal{O}(\epsilon^{-4})$ | $\mathcal{O}(\epsilon^{-4})$ |
| X | ✓ | ✓ | ✓ | 96.23±0.45 | 0.0024 | 90.45±0.54 | 0.0120 | $\mathcal{O}(\epsilon^{-3})$ | $\mathcal{O}(\epsilon^{-1.5})$ |
| ✓ | ✓ | ✓ | ✓ | 98.04±0.40 | 0.0017 | 92.12±0.35 | 0.0020 | $\mathcal{O}(\epsilon^{-3/2})$ | $\mathcal{O}(\epsilon^{-1})$ |

Table 2: Ablation Studies with different components of the proposed method. The first four columns are the mechanisms; accuracy and loss are reported on non-IID data.

Figure 7: Train Loss vs Epochs with MNIST
Figure 8: Train Accuracy vs Epochs with MNIST
Figure 9: Train Accuracy vs Communications with MNIST
Figure 10: Train Loss vs Communications with MNIST
Figure 11: Test Accuracy with MNIST
Figure 12: Test Loss with MNIST

Ablation Studies

In the proposed FL algorithm, we use local updates with adaptive learning and global updates. Both the local and global updates use momentum-based variance reduction. Table 2 shows the performance with different components such as SGD with adaptive learning, momentum-based variance reduction for local updates, and global updates. For this experiment, we used 80 epochs with 400 rounds and a batch size of 50.
Out of 100 clients, a random number of clients participated in each round. Note that the batch size used in our experiment is small. This confirms the effectiveness of the proposed method in overcoming the need for large batch sizes for convergence.

Conclusion

In this study, we propose a communication-efficient approach to federated learning that leverages both momentum-based variance reduction and adaptive learning rates. Both local and global updates integrate momentum-based variance reduction, and the local models additionally use adaptive learning rates, thereby reducing the number of communication rounds needed for early convergence. The proposed method exhibits robust performance in non-convex heterogeneous settings with non-IID data, surpassing several state-of-the-art methods. The principal aim of this paper is to minimize loss, which we demonstrate through thorough empirical and theoretical analyses. The proposed method is evaluated on a cross-device FL framework. The theoretical analysis relies on several assumptions to prove convergence. In particular, the assumption that all clients participate, which is needed to show the reduced complexity, is a major limitation of this work. In the future, we plan to apply a similar strategy to cross-silo FL.

Page 8:

References

Acar, D. A. E.; Zhao, Y.; Navarro, R. M.; Mattina, M.; Whatmough, P. N.; and Saligrama, V. 2021. Federated Learning Based on Dynamic Regularization. ArXiv, abs/2111.04263.

Chen, X.; Li, X.; and Li, P. 2020. Toward Communication Efficient Adaptive Gradient Method. Proceedings of the 2020 ACM-IMS on Foundations of Data Science Conference.

Cutkosky, A.; and Orabona, F. 2019. Momentum-based variance reduction in non-convex SGD. Red Hook, NY, USA: Curran Associates Inc.

Das, R.; Acharya, A.; Hashemi, A.; Sanghavi, S.; Dhillon, I. S.; and Topcu, U. 2022. Faster Non-Convex Federated Learning via Global and Local Momentum.
In The 38th Conference on Uncertainty in Artificial Intelligence.

Gao, L.; Fu, H.; Li, L.; Chen, Y.; Xu, M.; and Xu, C.-Z. 2022. FedDC: Federated Learning with Non-IID Data via Local Drift Decoupling and Correction. In 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 10102–10111.

Haddadpour, F.; Kamani, M. M.; Mokhtari, A.; and Mahdavi, M. 2020. Federated Learning with Compression: Unified Analysis and Sharp Guarantees. In International Conference on Artificial Intelligence and Statistics, volume 130.

Karargyris, A.; Umeton, R.; Sheller, M. J.; Aristizabal, A.; George, J.; et al. 2023. Federated benchmarking of medical artificial intelligence with MedPerf. Nature Machine Intelligence, 5(7): 799–810.

Karimireddy, S. P.; Jaggi, M.; Kale, S.; Mohri, M.; Reddi, S. J.; Stich, S. U.; and Suresh, A. T. 2020a. Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning. CoRR, abs/2008.03606.

Karimireddy, S. P.; Kale, S.; Mohri, M.; Reddi, S.; Stich, S.; and Suresh, A. T. 2020b. SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. In III, H. D.; and Singh, A., eds., Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, 5132–5143. PMLR.

Khanduri, P.; Sharma, P.; Kafle, S.; Bulusu, S.; Rajawat, K.; and Varshney, P. K. 2020. Distributed Stochastic Non-Convex Optimization: Momentum-Based Variance Reduction. arXiv:2005.00224.

Khanduri, P.; Sharma, P.; Yang, H.; Hong, M.-F.; Liu, J.; Rajawat, K.; and Varshney, P. K. 2021. STEM: A Stochastic Two-Sided Momentum Algorithm Achieving Near-Optimal Sample and Communication Complexities for Federated Learning. ArXiv, abs/2106.10435.

Koloskova, A.; Loizou, N.; Boreiri, S.; Jaggi, M.; and Stich, S. 2020. A Unified Theory of Decentralized SGD with Changing Topology and Local Updates. In III, H.
D.; and Singh, A., eds., Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, 5381–5393. PMLR.

Li, Q.; He, B.; and Song, D. 2021. Model-Contrastive Federated Learning. In 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 10708–10717.

Li, T.; Sahu, A. K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; and Smith, V. 2020. Federated Optimization in Heterogeneous Networks. arXiv:1812.06127.

McMahan, H. B.; Moore, E.; Ramage, D.; Hampson, S.; and y Arcas, B. A. 2017, Initial version posted on arXiv in February 2016. Communication-efficient learning of deep networks from decentralized data. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, 1273–1282.

McMahan, H. B.; Moore, E.; Ramage, D.; Hampson, S.; and y Arcas, B. A. 2023. Communication-Efficient Learning of Deep Networks from Decentralized Data. arXiv:1602.05629.

Mendieta, M.; Yang, T.; Wang, P.; Lee, M.; Ding, Z.; and Chen, C. 2022. Local Learning Matters: Rethinking Data Heterogeneity in Federated Learning. In 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 8387–8396.

Reddi, S.; Charles, Z. B.; Zaheer, M.; Garrett, Z.; Rush, K.; Konečný, J.; Kumar, S.; and McMahan, B., eds. 2021. Adaptive Federated Optimization.

Rostami, M.; and Kia, S. S. 2023. Federated Learning Using Variance Reduced Stochastic Gradient for Probabilistically Activated Agents. In 2023 American Control Conference (ACC), 861–866.

Wang, H.; Yurochkin, M.; Sun, Y.; Papailiopoulos, D.; and Khazaeni, Y. 2020. Federated Learning with Matched Averaging. In International Conference on Learning Representations.

Wang, J.; Liu, Q.; Liang, H.; Joshi, G.; and Poor, H. V. 2021. A Novel Framework for the Analysis and Design of Heterogeneous Federated Learning. IEEE Transactions on Signal Processing, 69: 5234–5249.

Wang, Y.; Lin, L.; and Chen, J. 2023.
Communication-Efficient Adaptive Federated Learning. arXiv:2205.02719.

Wu, X.; Huang, F.; Hu, Z.; and Huang, H. 2023. Faster adaptive federated learning. In Proceedings of the Thirty-Seventh AAAI Conference on Artificial Intelligence and Thirty-Fifth Conference on Innovative Applications of Artificial Intelligence and Thirteenth Symposium on Educational Advances in Artificial Intelligence, AAAI'23/IAAI'23/EAAI'23. AAAI Press. ISBN 978-1-57735-880-0.

Xu, J.; Du, W.; Jin, Y.; He, W.; and Cheng, R. 2022. Ternary Compression for Communication-Efficient Federated Learning. IEEE Transactions on Neural Networks and Learning Systems, 33(3): 1162–1176.

Page 9:

Appendix / Supplemental material

In this section, we present the convergence analysis of the proposed federated learning algorithm and its proof.

Convergence Analysis

The convergence is proved using the following theorems.

Theorem .1. Under Assumptions 1, 2, 3, and 4, and for initial batch size $B = bE$, we set $\eta_t = \frac{k}{(w_t + \sigma^2 t)^{1/3}}$ with $k = \frac{(bN)^{2/3}\sigma^{2/3}}{L}$. Also set

$$c = \frac{(8L)^2}{bN} + \frac{\sigma^2}{24LEk^3} = L^2\left(\frac{64}{bN} + \frac{1}{24(bN)^2E}\right)$$

and

$$w_t = \max\left\{(4LkE)^3 - \sigma^2 t,\ 2\sigma^2,\ \left(\frac{ck}{16LE}\right)^3\right\}.$$

Then we can set the number of local updates $E$ and the batch size $b$ as follows:

$$E = \mathcal{O}\left((T/N^2)^{v/3}\right), \qquad b = \mathcal{O}\left((T/N^2)^{1/2 - v/2}\right) \tag{6}$$

where $v \in [0,1]$. After applying the variance reduction with the adaptive learning rate for noise removal, according to Theorem 1, we have

$$\mathbb{E}\|\nabla f(\omega)\|^2 = \mathcal{O}\left(\frac{f(\omega_1) - f^*}{N^{2v/3}T^{1-v/3}}\right) + \tilde{\mathcal{O}}\left(\frac{\sigma^2}{N^{2v/3}T^{1-v/3}}\right) + \tilde{\mathcal{O}}\left(\frac{\nabla^2}{N^{2v/3}T^{1-v/3}}\right) \tag{7}$$

For any $v \in [0,1]$, the sample complexity of the proposed method is $\tilde{\mathcal{O}}(\epsilon^{-3/2})$. Hence, each client involved in a communication round requires at most $\tilde{\mathcal{O}}(N^{-1}\epsilon^{-3/2})$ gradient computations. Moreover, the communication complexity is $\tilde{\mathcal{O}}(\epsilon^{-1})$.

Before going to the proof of the above convergence analysis, we explain some of the tradeoffs as follows:

The tradeoff between sample size and communication complexity: From the above theorem, the sample and communication complexities are $\tilde{\mathcal{O}}(\epsilon^{-3/2})$ and $\tilde{\mathcal{O}}(\epsilon^{-1})$, respectively, when $E$ and $b$ are selected efficiently.
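In the FL literature, such complexities are reported up to logarithmic factors, irrespective of whether the sample or the batch Lipschitz smoothness assumption is used. Hence, our FL framework outperforms existing methods in terms of sample and communication complexities. Under our assumptions, the communication complexity $\tilde{\mathcal{O}}(\epsilon^{-1})$ is optimal in comparison with SOTA methods.

The tradeoff between the batch sizes and local updates: The number of local updates $E$ and the batch size $b$ are balanced using the parameter $v \in [0,1]$. The equations in (6) describe how $E$ and $b$ depend on $v$. As $v$ grows from 0 to 1, the batch size $b$ decreases and the number of local updates $E$ increases. At $v = 1$, $b$ is $\mathcal{O}(1)$, whereas $E = \mathcal{O}\left(\frac{T^{1/3}}{N^{2/3}}\right)$. Conversely, at $v = 0$, $b = \mathcal{O}\left(\frac{T^{1/2}}{N}\right)$, whereas $E$ is $\mathcal{O}(1)$. These choices generalize SGD and minibatch SGD with local and global updates using momentum-based variance reduction.

Before going to the details of the proof, it is necessary to discuss some lemmas in detail.

Preliminary Lemmas

Lemma .2. Define the error term as $\epsilon_t = m_t - \frac{1}{N}\sum_{n=1}^{N}\nabla f^{(n)}(\omega_t^{(n)})$. Then the iterates generated by Algorithm 1 satisfy

$$\mathbb{E}\left\langle (1-\beta_t)\epsilon_{t-1},\ \frac{1}{N}\sum_{n=1}^{N}\frac{1}{b}\sum_{x_t^{(n)}\in\mathcal{B}_t^{(n)}}\Big[\big(\nabla f^{(n)}(\omega_t^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_t^{(n)})\big)-(1-\beta_t)\big(\nabla f^{(n)}(\omega_{t-1}^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_{t-1}^{(n)})\big)\Big]\right\rangle = 0.$$

Proof. Let $\epsilon_{t-1}$ be the gradient error term at step $t-1$. For all $n \in [N]$, the only randomness in the left-hand side of the Lemma statement, given the past iterates, is with respect to $x_t^{(n)}$. Conditioning on the past iterates therefore gives

$$\mathbb{E}\left\langle (1-\beta_t)\epsilon_{t-1},\ \frac{1}{N}\sum_{n=1}^{N}\frac{1}{b}\sum_{x_t^{(n)}\in\mathcal{B}_t^{(n)}}\Big[\big(\nabla f^{(n)}(\omega_t^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_t^{(n)})\big)-(1-\beta_t)\big(\nabla f^{(n)}(\omega_{t-1}^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_{t-1}^{(n)})\big)\Big]\right\rangle$$

$$= \mathbb{E}\left\langle (1-\beta_t)\epsilon_{t-1},\ \frac{1}{N}\sum_{n=1}^{N}\mathbb{E}\left[\frac{1}{b}\sum_{x_t^{(n)}\in\mathcal{B}_t^{(n)}}\Big[\big(\nabla f^{(n)}(\omega_t^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_t^{(n)})\big)-(1-\beta_t)\big(\nabla f^{(n)}(\omega_{t-1}^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_{t-1}^{(n)})\big)\Big]\ \Big|\ \mathcal{F}_t\right]\right\rangle$$

where $\mathcal{F}_t = \sigma(\omega_1^{(n)}, \omega_2^{(n)}, \cdots, \omega_t^{(n)})$ for all $n \in [N]$. If $x_t^{(n)}$ is chosen uniformly at random at each client $n \in [N]$ and $\mathbb{E}[\nabla f^{(n)}(\omega_t^{(n)};x_t^{(n)})] = \nabla f^{(n)}(\omega_t^{(n)})$, then we have

$$\mathbb{E}\left[\frac{1}{b}\sum_{x_t^{(n)}\in\mathcal{B}_t^{(n)}}\Big[\big(\nabla f^{(n)}(\omega_t^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_t^{(n)})\big)-(1-\beta_t)\big(\nabla f^{(n)}(\omega_{t-1}^{(n)};x_t^{(n)})-\nabla f^{(n)}(\omega_{t-1}^{(n)})\big)\Big]\ \Big|\ \mathcal{F}_t\right] = 0$$

for all $n \in [N]$. This proves the lemma.

Theorem .3.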
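Choosing the parameters as
• $k = \frac{(bN)^{2/3}\sigma^{2/3}}{L}$
• $c = \frac{(8L)^2}{bN} + \frac{\sigma^2}{24LEk^3} = L^2\left(\frac{64}{bN} + \frac{1}{24(bN)^2E}\right)$
• $w_t = \max\left\{(4LkE)^3 - \sigma^2 t,\ 2\sigma^2,\ \left(\frac{ck}{16LE}\right)^3\right\}$

and for any $v \in [0,1]$, with the number of local updates at each client $E = \mathcal{O}\left((T/N^2)^{v/3}\right)$, batch size $b = \mathcal{O}\left((T/N^2)^{1/2 - v/2}\right)$, and initial batch size $B = bE$, the proposed FL algorithm satisfies the following:

(i) $\mathbb{E}\|\nabla f(\omega)\|^2 = \mathcal{O}\left(\frac{f(\omega_1) - f^*}{N^{2v/3}T^{1-v/3}}\right) + \tilde{\mathcal{O}}\left(\frac{\sigma^2}{N^{2v/3}T^{1-v/3}}\right) + \tilde{\mathcal{O}}\left(\frac{\nabla^2}{N^{2v/3}T^{1-v/3}}\right)$.

(ii) Sample complexity: To reach an $\epsilon$-stationary point, the proposed FL algorithm requires at most $\mathcal{O}(\epsilon^{-3/2})$ gradient computations. Hence, each client requires at most $\mathcal{O}(N^{-1}\epsilon^{-3/2})$ gradient computations.

(iii) Communication complexity: To reach an $\epsilon$-stationary point, the proposed FL algorithm requires at most $\mathcal{O}(\epsilon^{-1})$ communication rounds.

Page 10:

Proof.

• Proof of Statement (i): Substituting the values of $B$, $E$, and $b$ into the given expression, with $B = bE$, we get

$$\mathbb{E}\|\nabla f(\omega)\|^2 \le \left[\frac{32LE}{T} + \frac{2L}{(bN)^{2/3}T^{2/3}}\right](f(\omega_1) - f^*) + \left[\frac{8E}{T} + \frac{1}{2(bN)^{2/3}T^{2/3}}\right]\sigma^2 + \left[\frac{256^2E}{T} + \frac{64^2}{(bN)^{2/3}T^{2/3}}\right]\sigma^2\log(T+1) + \left[\frac{256^2E}{T} + \frac{64^2}{(bN)^{2/3}T^{2/3}}\right]\nabla^2\,\frac{E-1}{E}\log(T+1).$$

Using $E = \mathcal{O}\left((T/N^2)^{v/3}\right)$ and $b = \mathcal{O}\left((T/N^2)^{1/2 - v/2}\right)$, we obtain the expression in (i).

• Sample Complexity: From the above expression, the total number of iterations required to reach an $\epsilon$-stationary point satisfies

$$\mathcal{O}\left(\frac{1}{N^{2v/3}T^{1-v/3}}\right) = \epsilon \implies T = \mathcal{O}\left(\frac{1}{N^{2v/(3-v)}\epsilon^{3/(3-v)}}\right).$$

Each client computes $2b$ stochastic gradients in each iteration. Hence, each client computes $2bT$ stochastic gradients in total. Using $b = \mathcal{O}\left((T/N^2)^{1/2 - v/2}\right)$, the total number of gradient computations at each client is

$$bT = \mathcal{O}\left(\frac{T^{3/2 - v/2}}{N^{1-v}}\right) = \mathcal{O}\left(\frac{1}{N\epsilon^{3/2}}\right).$$

Therefore, the sample complexity is $\mathcal{O}(\epsilon^{-3/2})$.

• Communication Complexity: To reach an $\epsilon$-stationary point, the total number of communication rounds is $T/E$, since the clients communicate once every $E$ local updates. With $E = \mathcal{O}\left((T/N^2)^{v/3}\right)$ and $T = \mathcal{O}\left(\frac{1}{N^{2v/(3-v)}\epsilon^{3/(3-v)}}\right)$, we get the communication complexity

$$T/E = \mathcal{O}\left(T^{1-v/3}N^{2v/3}\right) = \mathcal{O}\left(\frac{1}{\epsilon}\right).$$

This proves the theorem.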
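The scaling argument in the proof of Theorem .3 can be checked numerically. The following is a minimal sketch, assuming every hidden constant in the $\mathcal{O}(\cdot)$ expressions equals 1 and ignoring logarithmic factors. The function name `schedule` and its return keys are hypothetical and not part of the paper. It evaluates $T$, $E$, and $b$ for a given $\epsilon$, $N$, and $v$, and confirms that the number of rounds $T/E$ and the per-client gradient count $bT$ do not depend on $v$.

```python
import math


def schedule(eps: float, n_clients: int, v: float) -> dict:
    """Order-of-magnitude quantities from Theorem .3.

    Assumption: all constants hidden in the O(.) notation are set to 1,
    so the returned numbers show scaling only, not real budgets.
    """
    if not 0.0 <= v <= 1.0:
        raise ValueError("v must lie in [0, 1]")
    # Total iterations: T = N^{-2v/(3-v)} * eps^{-3/(3-v)}
    total_iters = n_clients ** (-2.0 * v / (3.0 - v)) * eps ** (-3.0 / (3.0 - v))
    ratio = total_iters / n_clients**2
    local_updates = ratio ** (v / 3.0)      # E = (T / N^2)^{v/3}
    batch_size = ratio ** (0.5 - v / 2.0)   # b = (T / N^2)^{1/2 - v/2}
    return {
        "T": total_iters,
        "E": local_updates,
        "b": batch_size,
        "rounds": total_iters / local_updates,         # T / E
        "grads_per_client": batch_size * total_iters,  # b * T
    }


if __name__ == "__main__":
    for v in (0.0, 0.5, 1.0):
        s = schedule(eps=1e-3, n_clients=10, v=v)
        print(f"v={v:.1f}  E={s['E']:.3f}  b={s['b']:.3f}  "
              f"rounds={s['rounds']:.1f}  grads/client={s['grads_per_client']:.1f}")
```

Under these assumptions, changing $v$ only shifts work between local updates $E$ and batch size $b$; the rounds stay at $\epsilon^{-1}$ and the per-client gradient count stays at $N^{-1}\epsilon^{-3/2}$. The sketch does not enforce $T \ge N^2$, so very small values of $E$ or $b$ should be read as scaling indicators rather than usable settings.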

---