Authors: Dipanwita Thakur, Antonella Guzzo, Giancarlo Fortino, Sajal K. Das
Non-Convex Optimization in Federated Learning via Variance Reduction
and Adaptive Learning
Dipanwita Thakur1, Antonella Guzzo1, Giancarlo Fortino1, Sajal K. Das2
1University of Calabria, Italy
2Missouri University of Science & Technology, USA
dipanwita.thakur@unical.it, antonella.guzzo@unical.it, giancarlo.fortino@unical.it, sdas@mst.edu
Abstract
This paper proposes a novel federated algorithm that leverages momentum-based variance reduction with adaptive learning to address non-convex settings across heterogeneous data. We aim to minimize communication and computation overhead, thereby fostering a sustainable federated learning system. Specifically, we address two challenges: gradient variance, which hinders the model's efficiency, and the slow convergence that results from learning-rate adjustment under heterogeneous data. Our algorithms attain an improved communication complexity of $O(\epsilon^{-1})$ for converging to an $\epsilon$-stationary point in non-convex settings, compared with the $O(\epsilon^{-2})$ communication complexity of most prior works. The proposed federated version maintains the trade-off between convergence rate, number of communication rounds, and test accuracy while mitigating client drift in heterogeneous settings. Experimental results on image classification tasks (MNIST, CIFAR-10) with heterogeneous data demonstrate the efficiency of our algorithms.
Introduction
Federated learning (FL) is a distributed machine learning technique that allows multiple devices to train a model collaboratively without sharing their data with a central server (McMahan et al. 2017, Initial version posted on arXiv in February 2016). A traditional FL architecture consists of a global server and several local devices. The global server, also known as the federated server, generates a global model, which is sent to the local clients to process their locally generated data. The parameters of the local models are sent back to the global server, which aggregates them and returns the updated global model to the local devices, improving the performance of the local models on their locally generated data. The devices are usually heterogeneous, and traditional federated learning faces several issues.
First, a traditional FL framework needs many global rounds to converge with non-independent and identically distributed (non-IID) client datasets, and each round has a high communication cost; together, these increase the communication complexity. Second, the number of clients participating in each iteration also affects the performance of traditional FL. Not every client can participate in each round, and tuning the learning rate slows convergence and increases the computational complexity. Fewer participating clients increase the number of global rounds under non-IID data, because each local device processes its own personalized data. Third, the training data are widely dispersed among many devices, and the connection between each device and the central server can be slow and irregular. The resulting slow communication has motivated communication-efficient FL algorithms. Fourth, non-convex optimization is central to contemporary machine learning (e.g., neural network training), and convergence in non-convex settings remains challenging in FL with non-IID data.

Copyright © 2025, Accepted at the FLUID workshop@Association for the Advancement of Artificial Intelligence (www.aaai.org). All rights reserved.
Cross-silo and cross-device are two different FL settings. The cross-silo setting involves a limited number of reliable clients, usually organizations such as banks or medical facilities. In contrast, the cross-device setting may involve a very large number of clients, for instance, all 6.84 billion active smartphones. Consequently, training in this scenario may never go through all of the clients' data. The cross-device setting is further characterized by resource-poor clients communicating over a highly unstable network. Together, these characteristics create difficulties that do not exist in the cross-silo setting. In this paper, we consider the more difficult cross-device scenario. Notably, recent developments in FL optimization, including FedDyn (Acar et al. 2021), SCAFFOLD (Karimireddy et al. 2020b), and (Wu et al. 2023), are not directly applicable here because they were designed for cross-silo environments.
The computation involved in federated learning, regardless of the model used for training, is directly related to energy consumption. Hence, we need an optimized FL algorithm with convergence guarantees. Because of the issues mentioned above, designing an optimized FL algorithm that converges is challenging. Most FL algorithms proposed so far focus on non-IID data and communication efficiency. To establish convergence guarantees, researchers have assumed either IID data or the participation of all the devices¹. However, these assumptions do not hold in real scenarios. In practice, FL must handle heterogeneous (non-IID) data and mitigate the resulting “client-drift” problem. Several optimization algorithms have been proposed in the FL literature to mitigate client drift; however, very few consider both the non-convex and non-IID settings and provide a convergence analysis for them.

arXiv:2412.11660v1 [cs.LG] 16 Dec 2024
This paper aims to decrease the number of communication rounds, which ultimately reduces energy consumption. According to (McMahan et al. 2017, Initial version posted on arXiv in February 2016), FL settings are more constrained by communication overhead than by computation overhead. The objective is therefore to spend more computation in order to reduce the number of communication rounds needed to train a model. This can be achieved either by increasing parallelism or by increasing the computation on each device: the participation of more clients, each performing more computation, can reduce the number of communication rounds (McMahan et al. 2023). Most FL algorithms use a fixed learning rate for all clients in each iteration. However, not all parameters benefit from a uniform learning rate, which can slow convergence and requires computationally expensive learning-rate tuning. While some parameters may require smaller adjustments to avoid overshooting the optimum, others may require more frequent or larger updates to speed up convergence.
To tackle this, adaptive learning-rate algorithms were developed. These methods adjust the learning rate of each parameter according to its past gradients, allowing the algorithm to traverse the optimization landscape more quickly. Although several FL techniques have been proposed, their convergence behavior prevents them from significantly lowering communication costs, i.e., from avoiding frequent transmission between the central server and the local worker nodes. Several factors contribute to this, including (1) client drift (Karimireddy et al. 2020b), in which local client models approach local rather than global optima; (2) the lack of adaptivity of stochastic gradient descent (SGD)-based updates (Reddi et al. 2021); and (3) the large number of training iterations required to converge as model and training dataset sizes keep growing. Despite these developments, most research is devoted to fixing client drift (Karimireddy et al. 2020b; Khanduri et al. 2021; Xu et al. 2022), and existing FL systems do not resolve all of these problems.
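To make the per-parameter idea above concrete, the following minimal Python sketch applies AdaGrad-style scaling, one of the adaptive optimizers discussed in this paper, to a sequence of gradient vectors. The function name `adagrad_steps` and its default values are illustrative assumptions, not part of the proposed algorithm. Coordinates with large accumulated gradients receive smaller effective learning rates, so one badly scaled parameter no longer dictates a single global step size.

```python
import math


def adagrad_steps(grads, lr=0.1, eps=1e-8):
    """Return AdaGrad-style updates for a sequence of gradient vectors.

    Each coordinate j of an update is -lr * g[j] / (sqrt(G[j]) + eps),
    where G[j] is the running sum of squared gradients in coordinate j.
    """
    dim = len(grads[0])
    accum = [0.0] * dim  # per-coordinate sum of squared gradients
    updates = []
    for g in grads:
        for j in range(dim):
            accum[j] += g[j] ** 2
        updates.append([-lr * g[j] / (math.sqrt(accum[j]) + eps) for j in range(dim)])
    return updates


# A badly scaled gradient: coordinate 0 is 10,000 times larger than coordinate 1.
steps = adagrad_steps([[100.0, 0.01], [100.0, 0.01]])
print(steps[0])  # both coordinates move by roughly -0.1 in the first step
```

For the gradient [100.0, 0.01], a uniform learning rate of 0.1 would move the two coordinates by 10 and 0.001, whereas the sketch moves both by roughly 0.1, and later steps shrink as squared gradients accumulate.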
Several optimization techniques are available for gradient descent to reduce the loss function and make the algorithm converge; momentum-based updates and adaptive learning rates are the two most popular. The deep learning models typically trained in FL frameworks are non-convex. Momentum-based updates help escape local minima and accelerate convergence, while adaptive learning rates improve the convergence speed and accuracy of the algorithm. Traditional optimizers such as AdaGrad, RMSProp, and Adam are used in adaptive federated learning. However, designing adaptive FL systems is challenging because local devices move in diverse directions and the global server cannot be updated often in the FL scenario. Flawed adaptive FL designs may therefore suffer from convergence problems (Chen, Li, and Li 2020). FedAdagrad, FedYogi, and FedAdam, the federated versions of adaptive optimizers, were first proposed in (Reddi et al. 2021). However, their analysis is only valid when $\beta_1 = 0$, so it cannot exploit momentum. To overcome this issue, MimeAdam (Karimireddy et al. 2020a) applies server statistics locally. However, MimeAdam needs to compute the full local gradient, which may not be feasible in practice.

¹ Throughout the article, we refer to network entities such as nodes, clients, sensors, or organizations as “devices.”
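As an illustration of how adaptive optimizers are moved to the server side, the sketch below follows our reading of the FedAdam server step of Reddi et al. (2021): the averaged client change acts as a pseudo-gradient that feeds Adam-style first and second moment estimates. The function name, default hyperparameters, and use of plain Python lists are assumptions for illustration; setting `beta1=0` removes server momentum, which is the case the analysis discussed above is restricted to.

```python
import math


def fedadam_server_step(x, client_deltas, m, v, eta=0.01, beta1=0.9, beta2=0.99, tau=1e-3):
    """One FedAdam-style server round (sketch after Reddi et al. 2021).

    client_deltas[i] is x_i - x, the change produced by client i's local training.
    The averaged change is the pseudo-gradient; it already points downhill, so it is added.
    Returns the new global model and the updated first/second moment estimates.
    """
    n, dim = len(client_deltas), len(x)
    delta = [sum(d[j] for d in client_deltas) / n for j in range(dim)]
    m = [beta1 * m[j] + (1 - beta1) * delta[j] for j in range(dim)]
    v = [beta2 * v[j] + (1 - beta2) * delta[j] ** 2 for j in range(dim)]
    x = [x[j] + eta * m[j] / (math.sqrt(v[j]) + tau) for j in range(dim)]
    return x, m, v


x, m, v = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]
deltas = [[1.0, -1.0], [0.5, -0.5]]  # two clients' local changes (hypothetical)
x, m, v = fedadam_server_step(x, deltas, m, v)
print(x)  # first coordinate moves up, second moves down by the same amount
```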
FedAMS (Wang, Lin, and Chen 2023) is a more recent proposal that offers a complete proof but does not improve the convergence rate. FedAdagrad, FedYogi, FedAdam, and FedAMS have an overall sample complexity of $O(\epsilon^{-4})$, which is no better than FedAvg, and they require an additional global learning rate to be tuned simultaneously. In (Rostami and Kia 2023), the authors used Stochastic Variance Reduced Gradient (SVRG), a stochastic variance-reduction technique, for local updates. This method computes the full-batch gradient of each participating device at each iteration, which increases the computational complexity. In (Wu et al. 2023), the authors proposed FAFED, an efficient adaptive FL algorithm for cross-silo settings based on momentum-based variance reduction. FAFED converges quickly, with the best-known sample complexity $O(\epsilon^{-3})$ and communication complexity $O(\epsilon^{-2})$. However, it does not address the non-convex setting required by deep neural networks (DNNs). Moreover, DNNs require substantial GPU memory when processing images, so large batch sizes or gradient checkpointing increase the computational cost and slow convergence.
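The cost issue with SVRG can be seen directly in its gradient estimator. The sketch below assumes a toy one-dimensional least-squares problem with per-sample losses 0.5 * (a_i * x - b_i)^2 on hypothetical data; the helper names are illustrative. Each snapshot requires a full pass over all n samples, and every subsequent step combines two per-sample gradients with that stored full gradient.

```python
def full_gradient(grad_i, x, n):
    """Average of all n per-sample gradients at x: costs n gradient evaluations."""
    return sum(grad_i(i, x) for i in range(n)) / n


def svrg_estimate(grad_i, x, snapshot, mu, i):
    """SVRG estimator grad_i(x) - grad_i(snapshot) + mu, with mu the full gradient at snapshot."""
    return grad_i(i, x) - grad_i(i, snapshot) + mu


# Hypothetical toy data: per-sample loss 0.5 * (a_i * x - b_i) ** 2.
a = [1.0, 2.0, 3.0]
b = [1.0, 0.0, -1.0]


def grad_i(i, x):
    return a[i] * (a[i] * x - b[i])


snapshot = 0.5
mu = full_gradient(grad_i, snapshot, len(a))  # n evaluations per snapshot
x = 1.0
estimates = [svrg_estimate(grad_i, x, snapshot, mu, i) for i in range(len(a))]
# Averaged over the sample index, the estimator equals the full gradient at x.
print(sum(estimates) / len(estimates), full_gradient(grad_i, x, len(a)))
```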
Based on the issues mentioned above, we propose a new FL algorithm that combines an adaptive learning rate with momentum-based variance-reduced SGD in non-convex settings. Specifically, we propose a general, energy-efficient, optimized FL framework based on reducing communication rounds, in which (1) the devices train over several epochs using an adaptive learning rate optimizer to minimize the loss on their local data, and (2) the server updates its global model by applying a momentum-based gradient optimizer to the average of the clients' model updates. The proposed method combines momentum-based variance-reduced updates and adaptive learning rate optimization without increasing client storage or communication costs for local and global updates.
Contributions. Our major contributions are as follows.
• We design a federated learning algorithm in which momentum-based variance reduction is used for global updates in addition to momentum-based local updates with adaptive learning. The proposed FL algorithm is designed to address two main issues in order to speed up the convergence of FL, and a third issue, the high computation cost in non-IID settings for non-convex problems. The first two issues are the high variances associated with (i) the noise of local client-level stochastic gradients and (ii) the global server aggregation step, caused by varying client characteristics when multiple local updates are performed. The high computation cost is associated with tuning the learning rate. We tackle these issues separately by reducing the variance of both the global server update and the local client updates, using global and local momentum with adaptive learning.
• We provide a convergence analysis in generic non-convex settings and design new federated optimization methods, based on the proposed FL algorithm, that are compatible with multiple devices. The analysis shows that the proposed FL algorithm converges to an $\epsilon$-stationary point for smooth non-convex problems in the non-IID setting, with an improved communication complexity of $O(\epsilon^{-1})$ compared with state-of-the-art approaches.
• Experimental results with non-IID data in non-convex settings show the effectiveness of our algorithm, which improves communication efficiency and reduces client drift under heterogeneous data distributions.
Related Works
While several FL algorithms have been proposed, most of them address communication efficiency and data privacy. Our work focuses on communication-efficient FL algorithms that converge early. Federated Averaging (FedAvg) (McMahan et al. 2017, Initial version posted on arXiv in February 2016) is the first and most widely used FL algorithm. However, FedAvg lacks theoretical convergence guarantees for non-IID data in the convex optimization setting. Although FedAvg has proven empirically successful in diverse environments, it does not fully address the fundamental issues related to heterogeneity. Under systems heterogeneity, FedAvg typically drops devices that cannot complete $E$ epochs within a given time window, rather than allowing participating devices to perform varying amounts of local work based on their system constraints. To address this issue, several regularization approaches (Acar et al. 2021; Gao et al. 2022; Karargyris et al. 2023; Mendieta et al. 2022) are used to constrain local optimization. FedProx (Li et al. 2020) is another FL algorithm that assumes neither IID data nor the participation of all client devices.
Specifically, FedProx exhibits noticeably more accurate and stable convergence than FedAvg in highly heterogeneous settings. The FedProx methodology relies on a gradient similarity assumption and adds a proximal term to each local objective to prove convergence. However, several assumptions are essential to implement FedProx in realistic federated contexts. The SCAFFOLD approach (Karimireddy et al. 2020b) achieves convergence rates independent of the degree of heterogeneity by using control variates to reduce client drift. Although the approach works well for cross-silo FL, it cannot be used for cross-device FL because it requires clients to maintain their state across rounds. A similar approach, FedDyn (Acar et al. 2021), was proposed to reduce communication rounds using linear and quadratic penalty terms. Despite these efforts, FedAvg's performance is not entirely understood. Nevertheless, the works above establish that client drift is one of the major causes of performance degradation in FL, and SCAFFOLD introduced control variates to adjust local updates in order to address it. Furthermore, MOON (Li, He, and Song 2021) uses model-contrastive learning, exploiting the similarity between local and global model representations to correct local training. Various works attempt to improve FedAvg in other ways. For example, in (Wang et al. 2021), the authors explained that objective inconsistency causes slow convergence and provided a convergence analysis of the proposed and previous methods. FedMA (Wang et al. 2020), on the other hand, was proposed to improve the convergence rate by matching and averaging hidden elements, which also reduces the communication cost.
Compared with existing works on FL (McMahan et al. 2017, Initial version posted on arXiv in February 2016), (Li et al. 2020), (Karimireddy et al. 2020b), this paper combines two optimization methods to reduce the number of communication rounds and thereby make the proposed FL algorithm energy-efficient. In contrast to state-of-the-art FL algorithms, the client devices optimize their local models using an adaptive learning rate optimizer to minimize the loss on their local data, and the server updates its global model by applying a momentum-based gradient optimizer to the average of the clients' model updates. None of the existing works studies variance reduction with adaptive learning in the federated environment. Table 1 compares our proposed method with state-of-the-art methods that consider non-IID settings and non-convex functions.
Algorithms 1 and 2 present the federated client and server updates, respectively, using momentum-based variance reduction with an adaptive learning rate.
Preliminaries
Focusing on non-convex settings, our objective is to solve the optimization problem

$$\min_{\omega \in \mathbb{R}^d} f(\omega), \quad \text{where} \quad f(\omega) = \frac{1}{N}\sum_{i=1}^{N} F_i(\omega) \qquad (1)$$

Here $F_i(\omega) = \mathbb{E}_{x \sim \mathcal{D}_i} L_i(x; \omega)$ gives the loss of the prediction on a data sample $(x_i, y_i)$ with model parameters $\omega \in \mathcal{Z}$, and $\mathcal{D}_i$ denotes the data distribution of the $i$-th client. When the distributions $\mathcal{D}_i$ differ among clients, the setting is known as heterogeneous. Each $F_i(\omega)$ is differentiable, possibly non-convex, and has an $L$-Lipschitz continuous gradient (i.e., is $L$-smooth) for some $L > 0$. For each client $i$ and each $\omega$, we assume access to an unbiased stochastic gradient $g_i(\omega)$ of the true gradient $\nabla F_i(\omega)$. The loss on an example is denoted by $f(x, \omega)$, and the model's training loss is denoted by $F$; the sample functions $f(\cdot, x)$ satisfy $\mathbb{E}[f(\cdot, x)] = F(\cdot)$, where $x$ denotes a batch of data samples. Problem (1) encompasses a wide range of machine learning and image processing problems, such as neural network training and non-convex empirical risk minimization (ERM). Following the standard practice in non-convex optimization, we focus on algorithms that can efficiently find an approximate stationary point $x$ satisfying $\|\nabla f(x)\|^2 \le \epsilon$.

| Reference | Client Participation | Compression? | BCD? | ALR? | T |
|---|---|---|---|---|---|
| (Koloskova et al. 2020) | Full | No | Yes | No | $O(1/(N\epsilon^{2}))$ |
| (Haddadpour et al. 2020) | Full | Yes | Yes | No | $O(1/(N\epsilon^{2}))$ |
| (Khanduri et al. 2021) | Full | No | Yes | No | $O(1/(N\epsilon^{1.5}))$ |
| (Karimireddy et al. 2020a) | Partial | No | Yes | No | $O(1/(\sqrt{R}\,\epsilon^{1.5}))$ |
| (Karimireddy et al. 2020b) | Partial | No | Yes | No | $O(1/(R\,\epsilon^{1.5}))$ |
| (Das et al. 2022) | Partial | Yes | No | No | $O(\max(\sqrt{\alpha/N}, 1/\sqrt{R}) \cdot 1/\epsilon^{1.5})^{*1}$ |
| Ours | Partial | No | No | Yes | $O(1/\epsilon)$ |

Table 1: Parameter description of the SOTA and our proposed FL algorithms. Here, N is the total number of clients and P is the number of participating clients in each round, either N = P (Full) or P < N (Partial); Compression indicates whether communication is compressed; BCD denotes the bounded client dissimilarity assumption; ALR indicates whether an adaptive learning rate is used; and T denotes the number of gradient updates needed to achieve $\mathbb{E}[\|\nabla f(\omega)\|^2] \le \epsilon$ in smooth non-convex settings.
Furthermore, we make the following assumptions.
Assumption 1 (Smoothness). The function $F_i$ is $L$-smooth, i.e., its gradient is $L$-Lipschitz: $\|\nabla F_i(x) - \nabla F_i(y)\| \le L\|x - y\|$ for all $x, y \in \mathbb{R}^d$. Consequently, $f(y) \le f(x) + \langle \nabla f(x), y - x \rangle + \frac{L}{2}\|y - x\|^2$.
Assumption 2 (Bounded Variance). The functions $F_i$ have $\sigma$-bounded (local) variance, i.e., $\mathbb{E}\big[\|[\nabla f_i(x, \omega)]_j - [\nabla F_i(x)]_j\|^2\big] \le \sigma^2$ for all $x \in \mathbb{R}^d$, $j \in [d]$, and $i \in [n]$. Moreover, we assume that the global variance is also bounded: $\frac{1}{n}\sum_{i=1}^{n}\|[\nabla F_i(x)]_j - [\nabla f(x)]_j\|^2 \le \sigma_{g,j}^2$ for all $x \in \mathbb{R}^d$ and $j \in [d]$.
Assumption 3 (Bounded Gradients). The functions $f_i(x, \omega)$ have $G$-bounded gradients, i.e., for any $i \in [n]$, $x \in \mathbb{R}^d$, and $\omega \in \mathcal{Z}$, $\|[\nabla f_i(x, \omega)]_j\| \le G$ for all $j \in [d]$.
Assumption 4 (Unbiased Gradients). The stochastic gradient of each component function $f_i(x; \omega)$ computed at each client node is unbiased for all $x_i \sim \mathcal{D}_i$, $i \in [N]$, and $x \in \mathbb{R}^d$.
Assumption 5 (Non-negativity). Each $f_i(\omega)$ is non-negative, and hence $f_i^* = \min f_i(\omega) \ge 0$. In practice, most loss functions are non-negative; if a loss can be negative, a constant can be added to make it non-negative.
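The quadratic upper bound in Assumption 1 can be checked numerically for a simple smooth function. The sketch below assumes the one-dimensional softplus loss log(1 + e^x), whose second derivative sigmoid(x) * (1 - sigmoid(x)) never exceeds 1/4, so it satisfies the bound with L = 1/4; the function, sampling range, and helper names are illustrative choices, not part of the paper's setup.

```python
import math
import random


def softplus(x):
    """Numerically stable log(1 + e^x); its derivative is the sigmoid."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def descent_lemma_gap(f, grad, L, x, y):
    """RHS minus LHS of f(y) <= f(x) + grad(x) * (y - x) + L / 2 * (y - x) ** 2 (1-D case)."""
    return f(x) + grad(x) * (y - x) + 0.5 * L * (y - x) ** 2 - f(y)


# softplus'' = sigmoid * (1 - sigmoid) <= 1/4, so softplus is L-smooth with L = 1/4.
rng = random.Random(0)
pairs = [(rng.uniform(-10, 10), rng.uniform(-10, 10)) for _ in range(1000)]
worst = min(descent_lemma_gap(softplus, sigmoid, 0.25, x, y) for x, y in pairs)
print(worst >= -1e-12)  # True: the bound holds with L = 1/4
```

Choosing a constant below the true smoothness parameter, for example L = 0.1 at x = 0 and y = 1, makes the gap negative, which is why the analysis needs a valid L.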
Definition I ($\epsilon$-stationary Point). An $\epsilon$-stationary point $x$ satisfies $\|\nabla f(x)\| \le \epsilon$. Furthermore, a stochastic algorithm is said to reach an $\epsilon$-stationary point in $t$ iterations if $\mathbb{E}[\|\nabla f(x)\|] \le \epsilon$, where the expectation is over the algorithm's randomness up to time $t$.
Our objective is to find a critical point of $F$, where $\nabla F(\omega) = 0$; that is, to converge to an $\epsilon$-stationary point of a smooth non-convex function, $\mathbb{E}[\|\nabla f(x)\|^2] \le \epsilon$, using only stochastic gradients evaluated at arbitrary points. As a standard baseline, we consider stochastic gradient descent (SGD), which generates a sequence of iterates $\omega_1, \omega_2, \ldots, \omega_N$ via the recursion

$$\omega_{i+1} = \omega_i - \eta_i g_i \qquad (2)$$

where $g_i = \nabla f(x_i, \omega_i)$ and $f(\cdot, x_1), \ldots, f(\cdot, x_N)$ are either i.i.d. or non-i.i.d. samples from a distribution $\mathcal{D}$. The learning rates $\eta_1, \eta_2, \ldots, \eta_N \in \mathbb{R}$ affect performance if not tuned properly; the learning rate (step size) plays a crucial role in the convergence of SGD. To address this sensitivity and make SGD robust to parameter selection, “adaptive” SGD methods are frequently employed, in which the step size is determined dynamically from the stochastic gradient information of both the current and previous samples (Cutkosky and Orabona 2019; Khanduri et al. 2020). In this work, we present one such “adaptive” technique for distributed non-convex stochastic optimization that designs the step sizes based on the available stochastic gradient information. With a proper choice of the learning rate $\eta_i$, SGD ensures that a randomly chosen iterate $\omega_i$ satisfies $\mathbb{E}[\|\nabla F(\omega_i)\|] \le O(1/N^{1/4})$.
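The SGD recursion (2) with a gradient-driven step size can be sketched as follows. The step-size rule eta_i = k / (w + sum of squared stochastic gradients)^(1/3) is an assumption borrowed from STORM (Cutkosky and Orabona 2019) and matches the initial value k / w^(1/3) in Algorithm 1; the noisy quadratic test problem and all names are hypothetical.

```python
import random


def adaptive_sgd(grad_fn, omega0, steps, k=0.5, w=1.0, seed=0):
    """SGD recursion (2) with an adaptive step size.

    eta_i = k / (w + sum of squared stochastic gradients so far) ** (1/3),
    a STORM-style rule (assumed here), so eta_i never increases.
    """
    rng = random.Random(seed)
    omega, sq_sum, etas = omega0, 0.0, []
    for _ in range(steps):
        g = grad_fn(omega, rng)  # stochastic gradient g_i at omega_i
        sq_sum += g * g
        eta = k / (w + sq_sum) ** (1.0 / 3.0)
        etas.append(eta)
        omega -= eta * g  # omega_{i+1} = omega_i - eta_i * g_i
    return omega, etas


def noisy_quadratic_grad(omega, rng):
    """Gradient of 0.5 * omega ** 2 plus Gaussian noise (hypothetical test problem)."""
    return omega + rng.gauss(0.0, 0.1)


final, etas = adaptive_sgd(noisy_quadratic_grad, omega0=5.0, steps=200)
print(final, etas[0], etas[-1])
```

The step size shrinks quickly while gradients are large and then stays almost constant once the iterates settle, so no learning-rate schedule has to be tuned by hand.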
Variance reduction
Momentum-based variance reduction has performed well in prior works, reducing the number of hyperparameters to tune, such as the batch size. It is also used to reduce the variance of gradients in non-convex updates. In this work, the momentum-based variance-reduced update is

$$m_i = (1-\beta)\, m_{i-1} + \beta\, \nabla f(x_i, \omega_i) + (1-\beta)\big(\nabla f(x_i, \omega_i) - \nabla f(x_i, \omega_{i-1})\big), \qquad \omega_{i+1} = \omega_i - \eta_i m_i \qquad (3)$$

Compared with plain momentum, the additional term $(1-\beta)\big(\nabla f(x_i, \omega_i) - \nabla f(x_i, \omega_{i-1})\big)$ corrects the stale momentum $m_{i-1}$ for the change in the iterate, and the update is combined with an adaptive learning rate. The concept is similar to conventional variance reduction, in that two different gradients, evaluated on the same sample, are used in each step.
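Equation (3) can be sketched as a one-line estimator plus a small driver loop. Following STORM (Cutkosky and Orabona 2019), both gradients in the correction term are evaluated on the same fresh sample x_i. The test loss f(x, omega) = 0.5 * omega^2 + x * omega, whose sample enters the gradient only as additive noise, and all function names are hypothetical choices that make the variance-cancelling effect easy to see.

```python
import random


def storm_direction(m_prev, g_cur, g_prev_same_sample, beta):
    """Momentum-based variance-reduced direction of Equation (3).

    g_cur              = grad f(x_i, omega_i)
    g_prev_same_sample = grad f(x_i, omega_{i-1})  (same sample x_i, previous iterate)
    """
    return (1 - beta) * m_prev + beta * g_cur + (1 - beta) * (g_cur - g_prev_same_sample)


def run_storm(beta, steps=50, lr=0.1, seed=1):
    """Iterate (3) on the hypothetical loss f(x, omega) = 0.5 * omega**2 + x * omega."""
    rng = random.Random(seed)

    def grad(x, omega):
        return omega + x  # true gradient omega plus sample-dependent noise x

    omega_prev = omega = 2.0
    m = omega  # initialize with the noise-free gradient, as a full gradient would
    for _ in range(steps):
        x = rng.gauss(0.0, 1.0)  # fresh sample x_i
        m = storm_direction(m, grad(x, omega), grad(x, omega_prev), beta)
        omega_prev, omega = omega, omega - lr * m
    return omega_prev, m


# With beta = 0 the correction cancels the noise exactly on this loss,
# so m stays equal to the true gradient at the iterate where it was computed.
omega_last, m_last = run_storm(beta=0.0)
print(abs(m_last - omega_last))
```

With `beta=1` the estimator reduces to the plain stochastic gradient, and intermediate values trade bias from the stale momentum against variance from the fresh sample.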
Proposed Algorithm
To improve FL's convergence rate and reduce its communication overhead, the following problems must be resolved: (i) the high variance of the simple averaging used in the global server aggregation step when there are multiple local updates, which is made worse by client heterogeneity; (ii) the high variance linked to the noise of local client-level stochastic gradients; and (iii) the heterogeneity among the local functions. The key idea of the proposed FL algorithm is to apply momentum-based variance reduction to both the global and local updates, with an adaptive learning rate for the clients. Hence, in the optimized federated non-convex
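Algorithm 1: Algorithm for Client Update
Input: $c$, $k$, $w$; initial point $\omega_1$
1: $\eta_0 \leftarrow k / w^{1/3}$
2: for all $t = 1, 2, \ldots, T$ do
3:   for all clients $i = 1, 2, \ldots, N$ in parallel do
4:     for all local epochs $j = 1, 2, \ldots, E$ do
5:       if $j = 1$ then
6:         Set $m_i^{(t,j)} = \nabla f_i(\omega_i^{(t,j)})$ and $\hat{m}_i^{(t-1,j)} = \nabla f_i(\hat{\omega}_i^{(t-1,j)})$
7:       else
8:         Each client $i$ randomly selects a mini-batch $B_i^{(t,j)}$.
9:         Compute the stochastic gradients of the non-convex loss function $f_i$ over $B_i^{(t,j)}$ at $\omega_i^{(t,j)}$, $\hat{\omega}_i^{(t-1,j)}$, $\omega_i^{(t,j-1)}$, and $\hat{\omega}_i^{(t-1,j-1)}$, i.e., $\tilde{\nabla} f_i(\omega_i^{(t,j)}; B_i^{(t,j)})$, $\tilde{\nabla} f_i(\hat{\omega}_i^{(t-1,j)}; B_i^{(t,j)})$, $\tilde{\nabla} f_i(\omega_i^{(t,j-1)}; B_i^{(t,j)})$, and $\tilde{\nabla} f_i(\hat{\omega}_i^{(t-1,j-1)}; B_i^{(t,j)})$.
10:        $\eta_t \leftarrow k / \big(w_t + \tilde{\nabla} f_i(\omega_i^{(t,j)})\big)^{1/3}$
11:        Update $m_i^{(t,j)} = \tilde{\nabla} f_i(\omega_i^{(t,j)}; B_i^{(t,j)}) + m_i^{(t,j-1)} - \tilde{\nabla} f_i(\omega_i^{(t,j-1)}; B_i^{(t,j)})$ and $\hat{m}_i^{(t-1,j)} = \tilde{\nabla} f_i(\hat{\omega}_i^{(t-1,j)}; B_i^{(t,j)}) + \hat{m}_i^{(t-1,j-1)} - \tilde{\nabla} f_i(\hat{\omega}_i^{(t-1,j-1)}; B_i^{(t,j)})$
12:      end if
13:      Update $\omega_i^{(t,j+1)} = \omega_i^{(t,j)} - \eta_t m_i^{(t,j)}$ and $\hat{\omega}_i^{(t-1,j+1)} = \hat{\omega}_i^{(t-1,j)} - \eta_t \hat{m}_i^{(t-1,j)}$.
14:    end for
15:  end for
16: end for
17: Send $(\omega^{t} - \omega_i^{t,E})$ and $\big((\omega^{t} - \omega_i^{t,E}) - (\omega^{t-1} - \hat{\omega}_i^{t-1,E})\big)$ to the server.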
heterogeneous setting, the $N$ clients jointly solve the following optimization problem:

$$\min_{\omega \in \mathbb{R}^d} \Big[ F(\omega) := \sum_{i=1}^{N} S_i F_i(\omega) \Big] \qquad (4)$$

where $S_i = N_i/N$ is the relative sample size and $F_i(\omega) = \frac{1}{N_i}\sum_{x \in \mathcal{D}_i} f_i(x; \omega)$ is the $i$-th client's local objective function. In our case, $f_i$ is a non-convex loss function of the learning model, and $x \in \mathcal{D}_i$ is a data sample from the local data $\mathcal{D}_i$. After receiving the current global model $\omega^{(t,0)}$ at the $t$-th communication round, each client independently executes $\kappa_i$ iterations of the local solver to optimize its local objective. In our work, the local solver is momentum-based variance reduction with an adaptive learning rate. The number of local updates $\kappa_i$ can vary across clients in our theoretical framework: if a client runs $E$ epochs with batch size $B$, then $\kappa_i = \lfloor E N_i / B \rfloor$, where $N_i$ is the number of data samples of the $i$-th client. Recall that in the simplest and most popular FL algorithm, FedAvg (McMahan et al. 2017, Initial version posted on arXiv in February 2016), the update rule can be written as

$$\omega^{(t+1,0)} - \omega^{(t,0)} = \sum_{i=1}^{N} S_i \Delta_i^{t} = -\sum_{i=1}^{N} S_i \cdot \eta \sum_{j=0}^{\kappa_i - 1} g_i\big(\omega_i^{(t,j)}\big) \qquad (5)$$

where $\omega_i^{(t,j)}$ is the model of the $i$-th client in the $t$-th communication round after the $j$-th local update, and $\Delta_i^{t}$ is the local parameter change of the $i$-th client at round $t$. In our work, we modify this update rule: local updates are performed using momentum-based variance reduction with an adaptive learning rate, and the global update is performed using momentum-based variance reduction. The client and server update rules are given in the following algorithms.
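The local-step count and the FedAvg update (5) can be sketched as follows. This is a minimal illustration under stated assumptions: a scalar model, the hypothetical per-sample loss 0.5 * (w - x)^2, and, for brevity, a full local gradient at each of the kappa_i local steps instead of a mini-batch gradient. The function names are illustrative.

```python
def local_steps(epochs, n_samples, batch_size):
    """kappa_i = floor(E * N_i / B), the number of local updates of client i."""
    return (epochs * n_samples) // batch_size


def fedavg_round(global_w, client_data, lr=0.1, epochs=1, batch_size=2):
    """One FedAvg round, Equation (5), on the hypothetical scalar loss 0.5 * (w - x) ** 2.

    For brevity each local step uses the client's full local gradient instead of a mini-batch.
    """
    total = sum(len(d) for d in client_data)
    new_w = global_w
    for data in client_data:
        s_i = len(data) / total  # S_i = N_i / N
        w = global_w
        for _ in range(local_steps(epochs, len(data), batch_size)):
            g = sum(w - x for x in data) / len(data)  # local gradient g_i(w)
            w -= lr * g
        new_w += s_i * (w - global_w)  # accumulate S_i * Delta_i^t
    return new_w


print(local_steps(epochs=2, n_samples=1200, batch_size=50))  # 48
print(fedavg_round(0.0, [[1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]))
```

In the usage example, the larger client both takes more local steps (kappa = 2 versus 1) and receives twice the aggregation weight, so its local change dominates the new global model.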
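Algorithm 2: Algorithm for Server Update
Require: Initial parameter $\omega_1$, number of communication rounds $T$, number of participating clients $R \le N$, epochs $E$
Ensure: $\omega_0 = \omega_1$
1: for all rounds $t = 1, 2, \ldots, T$ do
2:   Server broadcasts $\omega^{(t)}$ and $\omega^{(t-1)}$ to a set $S_t$ of $R$ clients
3:   for all clients $i \in S_t$ do
4:     Set $\omega_i^{(t,0)} = \omega^{(t)}$ and $\hat{\omega}_i^{(t-1,0)} = \omega^{(t-1)}$. Execute Algorithm 1.
5:   end for
6:   if $t = 1$ then
7:     Set $\hat{m}^{(t)} = \frac{1}{R}\sum_{i \in S^{(t)}} \big(\omega^{(t)} - \omega_i^{(t,E)}\big)$   {$\hat{m}$ is the global momentum}
8:   else
9:     Set $\hat{m}^{(t)} = \frac{\beta^{(t)}}{R}\sum_{i \in S^{(t)}} \big(\omega^{(t)} - \omega_i^{(t,E)}\big) + (1 - \beta^{(t)})\,\hat{m}^{(t-1)} + \frac{1 - \beta^{(t)}}{R}\sum_{i \in S^{(t)}} \Big(\big(\omega^{(t)} - \omega_i^{(t,E)}\big) - \big(\omega^{(t-1)} - \hat{\omega}_i^{(t-1,E)}\big)\Big)$
10:  end if
11:  Update $\omega^{(t+1)} = \omega^{(t)} - \hat{m}^{(t)}$
12: end for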
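Lines 6 to 11 of Algorithm 2 can be sketched as a single server-side function that consumes the two quantities each client sends. This is a minimal sketch assuming plain Python lists as parameter vectors; the function and argument names are hypothetical, and `first_round` selects the initialization branch of line 7.

```python
def server_update(w_t, m_prev, deltas, corrections, beta, first_round):
    """Global momentum-based variance-reduced update of Algorithm 2 (lines 6-11).

    deltas[i]      = w_t - w_i^{(t,E)}                       (client pseudo-gradient)
    corrections[i] = deltas[i] - (w_{t-1} - w_hat_i^{(t-1,E)})
    All vectors are plain lists of floats.
    """
    r, dim = len(deltas), len(w_t)
    avg_delta = [sum(d[j] for d in deltas) / r for j in range(dim)]
    if first_round:
        m = avg_delta  # line 7
    else:
        avg_corr = [sum(c[j] for c in corrections) / r for j in range(dim)]
        m = [beta * avg_delta[j] + (1 - beta) * m_prev[j] + (1 - beta) * avg_corr[j]
             for j in range(dim)]  # line 9
    w_next = [w_t[j] - m[j] for j in range(dim)]  # line 11
    return w_next, m


w1, m1 = server_update([1.0, 1.0], None, [[0.2, 0.0], [0.4, 0.2]], None, beta=0.5, first_round=True)
w2, m2 = server_update(w1, m1, [[0.1, 0.1], [0.1, 0.1]], [[0.0, 0.0], [0.0, 0.0]],
                       beta=0.5, first_round=False)
print(w1, w2)
```

Setting beta = 1 turns the server step into plain averaging of client changes, while beta = 0 keeps only the previous momentum plus the averaged correction term.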
Cutkosky and Orabona (2019) proposed STORM, a momentum-based variance reduction method for non-convex settings. In this paper, the idea of STORM is adopted as an update rule in a slightly different way. Equation (3) represents the update rule of STORM, where $\beta \in [0, 1)$ is the per-iteration momentum parameter and $x_i$ is the sample randomly selected in the $i$-th iteration; both stochastic gradients in the correction term, at $\omega_i$ and $\omega_{i-1}$, are computed on the same sample $x_i$. In Algorithm 1, lines 6, 11, and 13 give the client updates using momentum-based variance reduction with an adaptive learning rate over $E$ epochs. Note that this corresponds to Equation (3) with $\beta = 0$, where the full gradient is used to compute the initial gradient. Similarly, in Algorithm 2, global momentum is used for server aggregation. To overcome the pitfall of the server aggregation scheme used in conventional federated averaging such as FedAvg (McMahan et al. 2017, Initial version posted on arXiv in February 2016), we treat $(\omega^{(t)} - \omega_i^{(t,E)})$ as the gradient in this paper.
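Lines 6, 11, and 13 of Algorithm 1 can be sketched for a single client as follows. This is a minimal illustration under stated assumptions: a scalar model, the hypothetical per-sample loss 0.5 * (w - x)^2, one mini-batch per local step j, and the accumulated squared-gradient step size of STORM in place of line 10, whose exact form we treat as an assumption. The last local iterate is used as the client's final model, and the helper names `batch_grad` and `client_update` are hypothetical. The function returns the two quantities sent in line 17.

```python
import random


def batch_grad(data, w, batch):
    """Mini-batch gradient of the hypothetical loss 0.5 * (w - x) ** 2 over indices in batch."""
    return sum(w - data[k] for k in batch) / len(batch)


def client_update(data, w_t, w_tm1, num_steps=5, batch_size=4, k=0.1, w_const=1.0, seed=0):
    """Sketch of Algorithm 1 for one client with a scalar model.

    Tracks two trajectories: w (started from w_t) and w_hat (started from w_{t-1}).
    Returns the two quantities of line 17.
    """
    rng = random.Random(seed)
    full = list(range(len(data)))
    w, w_hat = w_t, w_tm1
    m = batch_grad(data, w, full)  # line 6: full local gradient at j = 1
    m_hat = batch_grad(data, w_hat, full)
    eta = k / w_const ** (1.0 / 3.0)  # line 1: eta_0 = k / w^(1/3)
    sq_sum = 0.0
    w_old, w_hat_old = w, w_hat
    for j in range(1, num_steps + 1):
        if j > 1:
            batch = rng.sample(full, batch_size)  # line 8
            g = batch_grad(data, w, batch)
            sq_sum += g * g
            eta = k / (w_const + sq_sum) ** (1.0 / 3.0)  # assumed form of line 10
            m = g + m - batch_grad(data, w_old, batch)  # line 11 (beta = 0 recursion)
            m_hat = (batch_grad(data, w_hat, batch) + m_hat
                     - batch_grad(data, w_hat_old, batch))
        w_old, w_hat_old = w, w_hat
        w, w_hat = w - eta * m, w_hat - eta * m_hat  # line 13
    delta = w_t - w  # line 17
    correction = delta - (w_tm1 - w_hat)
    return delta, correction


data = [1.0] * 8 + [3.0] * 8  # hypothetical local data with mean 2.0
delta, correction = client_update(data, w_t=0.0, w_tm1=0.0)
print(delta, correction)
```

When the current and previous global models coincide, the two trajectories are identical and the correction sent to the server is zero, which is a quick sanity check of the bookkeeping.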
Empirical Analysis
This section empirically evaluates the proposed method on federated image classification tasks and compares it with several baseline methods to demonstrate its efficacy. Experiments are implemented in PyTorch, and all experiments are run on a machine with a 3.50 GHz Intel Xeon CPU and an NVIDIA GeForce RTX 2080 Ti GPU.
Implementation Details
Dataset. Our image classification experiments use the MNIST and CIFAR-10 datasets with 50 clients in the network. The MNIST dataset comprises 60,000 training images and 10,000 test images in 10 classes; each image is a 28 × 28 array of grayscale pixels. The CIFAR-10 dataset contains 60,000 32 × 32 color images in 10 categories, split into 50,000 training images and 10,000 test images. Each client uses an identical Convolutional Neural Network (CNN) model as the classifier, with the cross-entropy loss. We use a Dirichlet distribution with $\alpha = 0.5$ to generate non-IID client datasets.
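The Dirichlet split is not spelled out in detail, so the following sketch assumes a common per-class scheme: for each class, client proportions are drawn from a symmetric Dirichlet(alpha) and that class's sample indices are cut accordingly. The function names and the synthetic labels are hypothetical; smaller alpha values concentrate each class on fewer clients and therefore produce more skewed, more strongly non-IID partitions.

```python
import random


def dirichlet_proportions(alpha, k, rng):
    """Sample from a symmetric Dirichlet(alpha) by normalizing Gamma(alpha, 1) draws."""
    draws = [rng.gammavariate(alpha, 1.0) for _ in range(k)]
    total = sum(draws)
    return [d / total for d in draws]


def dirichlet_partition(labels, num_clients, alpha=0.5, seed=0):
    """Split sample indices across clients with per-class Dirichlet proportions (assumed scheme)."""
    rng = random.Random(seed)
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    clients = [[] for _ in range(num_clients)]
    for y in sorted(by_class):
        idxs = by_class[y]
        rng.shuffle(idxs)
        props = dirichlet_proportions(alpha, num_clients, rng)
        # Turn cumulative proportions into cut points so every index is used exactly once.
        cuts, acc = [], 0.0
        for p in props[:-1]:
            acc += p
            cuts.append(min(len(idxs), int(round(acc * len(idxs)))))
        start = 0
        for c, end in enumerate(cuts + [len(idxs)]):
            clients[c].extend(idxs[start:end])
            start = end
    return clients


labels = [i % 10 for i in range(1000)]  # hypothetical balanced 10-class labels
parts = dirichlet_partition(labels, num_clients=50, alpha=0.5)
sizes = sorted(len(p) for p in parts)
print(sizes[:5], sizes[-5:])  # smallest and largest client datasets
```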
Baselines. We compare our approach with SOTA methods: FedAvg (McMahan et al. 2017, Initial version posted on arXiv in February 2016), FedProx (Li et al. 2020), FedNova (Wang et al. 2021), and FedGLOMO (Das et al. 2022). Among them, FedAvg is the most frequently used federated optimization approach. FedProx can be viewed as a re-parametrization and generalization of FedAvg. FedNova is a normalized averaging method that eliminates objective inconsistency while preserving fast error convergence. FedGLOMO performs local and global updates using momentum-based variance reduction with quantized messages to reduce communication complexity. The major difference between our proposed method and FedGLOMO is that we use no message quantization; we also use an adaptive learning rate for local updates to achieve early convergence.
Network Architecture. In our experiment, the neural network consists of two layers with ReLU activation functions, and the hidden layer size is 600. We train the models using the categorical cross-entropy loss with $\ell_2$-regularization, implemented by setting PyTorch's weight decay to 1e-4.
Validation
Figures 1 and 2 show the cross-entropy training loss and the training accuracy on CIFAR-10 as a function of the number of epochs on the federated clients. Similarly, Figures 4 and 3 show the training loss and accuracy on CIFAR-10 as a function of the number of communication rounds. Figures 5 and 6 show the test accuracy and loss in each communication round; they indicate early convergence of our proposed method. Likewise, Figures 7 and 8 show the cross-entropy training loss and accuracy on MNIST as a function of the number of epochs, Figures 10 and 9 show the training loss and accuracy on MNIST as a function of the number of communication rounds, and Figures 11 and 12 show the test accuracy and loss on MNIST, again indicating early convergence. On both datasets, the experimental results show that our proposed FL algorithm is marginally better than AdaGLOMO in test accuracy and appears somewhat faster in terms of the number of communication rounds on both training loss and accuracy. Moreover, its test accuracy is more stable than that of the SOTA methods, since its accuracy fluctuates less.

Figure 1: Train Loss vs Epochs with CIFAR-10
Figure 2: Train Accuracy vs Epochs with CIFAR-10
Figure 3: Train Accuracy vs Communications with CIFAR-10
Figure 4: Train Loss vs Communications with CIFAR-10
Figure 5: Test Accuracy with CIFAR-10
Figure 6: Test Loss with CIFAR-10
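| Adaptive Learning | Momentum | Variance Reduction (Local) | Variance Reduction (Global) | MNIST Accuracy (%) (non-IID) | MNIST Loss (non-IID) | CIFAR-10 Accuracy (%) (non-IID) | CIFAR-10 Loss (non-IID) | Sample Complexity | Communication Complexity |
|---|---|---|---|---|---|---|---|---|---|
| X | X | X | X | 75.24±0.45 | 0.67 | 72.45±0.34 | 0.86 | $O(\epsilon^{-4})$ | $O(\epsilon^{-4})$ |
| X | ✓ | X | X | 90.19±0.40 | 0.41 | 86.92±0.42 | 0.73 | $O(\epsilon^{-4})$ | $O(\epsilon^{-4})$ |
| ✓ | X | X | X | 90.78±0.45 | 0.20 | 87.05±0.45 | 0.46 | $O(\epsilon^{-4})$ | $O(\epsilon^{-4})$ |
| ✓ | ✓ | X | X | 91.24±0.45 | 0.0024 | 86.45±0.38 | 0.0078 | $O(\epsilon^{-3})$ | $O(\epsilon^{-3})$ |
| X | ✓ | ✓ | X | 89.24±0.45 | 0.0026 | 85.32± | 0.0032 | $O(\epsilon^{-4})$ | $O(\epsilon^{-4})$ |
| X | ✓ | ✓ | ✓ | 96.23±0.45 | 0.0024 | 90.45±0.54 | 0.0120 | $O(\epsilon^{-3})$ | $O(\epsilon^{-1.5})$ |
| ✓ | ✓ | ✓ | ✓ | 98.04±0.40 | 0.0017 | 92.12±0.35 | 0.0020 | $O(\epsilon^{-3/2})$ | $O(\epsilon^{-1})$ |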
Table 2: Ablation Studies with different components of the proposed method
Figure 7: Train Loss vs Epochs with MNIST
Figure 8: Train Accuracy vs Epochs with MNIST
Figure 9: Train Accuracy vs Communications with MNIST
Figure 10: Train Loss vs Communications with MNIST
Figure 11: Test Accuracy with MNIST
Figure 12: Test Loss with MNIST
Ablation Studies
In the proposed FL algorithm, the local updates use adaptive learning rates, and both the local and global updates use momentum-based variance reduction. Table 2 reports the performance obtained with different combinations of components: SGD with adaptive learning, momentum-based variance reduction for the local updates, and momentum-based variance reduction for the global updates. For this experiment, we used 80 epochs, 400 communication rounds, and a batch size of 50. Out of 100 clients, a random number of clients participated in each round. Note that this batch size is small; the results therefore indicate that the proposed method does not need large batch sizes to converge.
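The client-participation setup of this experiment can be sketched as follows. The text only states that a random number of the 100 clients took part in each of the 400 rounds; drawing that number uniformly from 1 to 100 and then drawing a uniform subset of that size is our assumption, and `sample_round_clients` and `participation_schedule` are hypothetical helpers.

```python
import random
from typing import List


def sample_round_clients(num_clients: int, rng: random.Random) -> List[int]:
    """Hypothetical: draw how many clients join, then which ones (uniformly)."""
    k = rng.randint(1, num_clients)                   # participant count, 1..num_clients
    return sorted(rng.sample(range(num_clients), k))  # distinct client IDs


def participation_schedule(num_clients: int = 100, rounds: int = 400,
                           seed: int = 0) -> List[List[int]]:
    """Participating client IDs for every communication round."""
    rng = random.Random(seed)
    return [sample_round_clients(num_clients, rng) for _ in range(rounds)]


schedule = participation_schedule()
print(len(schedule), min(map(len, schedule)), max(map(len, schedule)))
```

Seeding a dedicated `random.Random` instance keeps the schedule reproducible without touching the global generator.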
Conclusion
This study proposes a communication-efficient approach to federated learning that leverages both momentum-based variance reduction and adaptive learning rates. Both the local and global updates integrate momentum-based variance reduction, and the local models additionally use adaptive learning rates, which reduces the number of communication rounds needed to converge. The proposed method performs robustly in non-convex heterogeneous settings with non-IID data, surpassing several state-of-the-art methods. The principal aim of this paper is to minimize the loss efficiently, which we demonstrate through empirical and theoretical analyses.
The proposed method is evaluated in a cross-device FL setting. The theoretical analysis relies on several assumptions to prove convergence; in particular, the reduced complexity is shown under the assumption that all clients participate, which is a major limitation of this work. In future work, we plan to develop a cross-silo FL variant based on a similar strategy.
References
Acar, D. A. E.; Zhao, Y.; Navarro, R. M.; Mattina, M.; Whatmough, P. N.; and Saligrama, V. 2021. Federated Learning Based on Dynamic Regularization. ArXiv, abs/2111.04263.
Chen, X.; Li, X.; and Li, P. 2020. Toward Communication Efficient Adaptive Gradient Method. Proceedings of the 2020 ACM-IMS on Foundations of Data Science Conference.
Cutkosky, A.; and Orabona, F. 2019. Momentum-based variance reduction in non-convex SGD. Red Hook, NY, USA: Curran Associates Inc.
Das, R.; Acharya, A.; Hashemi, A.; Sanghavi, S.; Dhillon, I. S.; and Topcu, U. 2022. Faster Non-Convex Federated Learning via Global and Local Momentum. In The 38th Conference on Uncertainty in Artificial Intelligence.
Gao, L.; Fu, H.; Li, L.; Chen, Y.; Xu, M.; and Xu, C.-Z. 2022. FedDC: Federated Learning with Non-IID Data via Local Drift Decoupling and Correction. In 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 10102–10111.
Haddadpour, F.; Kamani, M. M.; Mokhtari, A.; and Mahdavi, M. 2020. Federated Learning with Compression: Unified Analysis and Sharp Guarantees. In International Conference on Artificial Intelligence and Statistics, volume 130.
Karargyris, A.; Umeton, R.; Sheller, M. J.; Aristizabal, A.; George, J.; et al. 2023. Federated benchmarking of medical artificial intelligence with MedPerf. Nature Machine Intelligence, 5(7): 799–810.
Karimireddy, S. P.; Jaggi, M.; Kale, S.; Mohri, M.; Reddi, S. J.; Stich, S. U.; and Suresh, A. T. 2020a. Mime: Mimicking Centralized Stochastic Algorithms in Federated Learning. CoRR, abs/2008.03606.
Karimireddy, S. P.; Kale, S.; Mohri, M.; Reddi, S.; Stich, S.; and Suresh, A. T. 2020b. SCAFFOLD: Stochastic Controlled Averaging for Federated Learning. In Daumé III, H.; and Singh, A., eds., Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, 5132–5143. PMLR.
Khanduri, P.; Sharma, P.; Kafle, S.; Bulusu, S.; Rajawat, K.; and Varshney, P. K. 2020. Distributed Stochastic Non-Convex Optimization: Momentum-Based Variance Reduction. arXiv:2005.00224.
Khanduri, P.; Sharma, P.; Yang, H.; Hong, M.-F.; Liu, J.; Rajawat, K.; and Varshney, P. K. 2021. STEM: A Stochastic Two-Sided Momentum Algorithm Achieving Near-Optimal Sample and Communication Complexities for Federated Learning. ArXiv, abs/2106.10435.
Koloskova, A.; Loizou, N.; Boreiri, S.; Jaggi, M.; and Stich, S. 2020. A Unified Theory of Decentralized SGD with Changing Topology and Local Updates. In Daumé III, H.; and Singh, A., eds., Proceedings of the 37th International Conference on Machine Learning, volume 119 of Proceedings of Machine Learning Research, 5381–5393. PMLR.
Li, Q.; He, B.; and Song, D. 2021. Model-Contrastive Federated Learning. In 2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 10708–10717.
Li, T.; Sahu, A. K.; Zaheer, M.; Sanjabi, M.; Talwalkar, A.; and Smith, V. 2020. Federated Optimization in Heterogeneous Networks. arXiv:1812.06127.
McMahan, H. B.; Moore, E.; Ramage, D.; Hampson, S.; and y Arcas, B. A. 2017, Initial version posted on arXiv in February 2016. Communication-efficient learning of deep networks from decentralized data. In Proceedings of the 20th International Conference on Artificial Intelligence and Statistics, 1273–1282.
McMahan, H. B.; Moore, E.; Ramage, D.; Hampson, S.; and y Arcas, B. A. 2023. Communication-Efficient Learning of Deep Networks from Decentralized Data. arXiv:1602.05629.
Mendieta, M.; Yang, T.; Wang, P.; Lee, M.; Ding, Z.; and Chen, C. 2022. Local Learning Matters: Rethinking Data Heterogeneity in Federated Learning. In 2022 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 8387–8396.
Reddi, S.; Charles, Z. B.; Zaheer, M.; Garrett, Z.; Rush, K.; Konečný, J.; Kumar, S.; and McMahan, B. 2021. Adaptive Federated Optimization. In International Conference on Learning Representations.
Rostami, M.; and Kia, S. S. 2023. Federated Learning Using Variance Reduced Stochastic Gradient for Probabilistically Activated Agents. In 2023 American Control Conference (ACC), 861–866.
Wang, H.; Yurochkin, M.; Sun, Y.; Papailiopoulos, D.; and Khazaeni, Y. 2020. Federated Learning with Matched Averaging. In International Conference on Learning Representations.
Wang, J.; Liu, Q.; Liang, H.; Joshi, G.; and Poor, H. V. 2021. A Novel Framework for the Analysis and Design of Heterogeneous Federated Learning. IEEE Transactions on Signal Processing, 69: 5234–5249.
Wang, Y.; Lin, L.; and Chen, J. 2023. Communication-Efficient Adaptive Federated Learning. arXiv:2205.02719.
Wu, X.; Huang, F.; Hu, Z.; and Huang, H. 2023. Faster adaptive federated learning. In Proceedings of the Thirty-Seventh AAAI Conference on Artificial Intelligence and Thirty-Fifth Conference on Innovative Applications of Artificial Intelligence and Thirteenth Symposium on Educational Advances in Artificial Intelligence, AAAI'23/IAAI'23/EAAI'23. AAAI Press. ISBN 978-1-57735-880-0.
Xu, J.; Du, W.; Jin, Y.; He, W.; and Cheng, R. 2022. Ternary Compression for Communication-Efficient Federated Learning. IEEE Transactions on Neural Networks and Learning Systems, 33(3): 1162–1176.
Appendix / Supplemental material
In this section, we present the convergence analysis of the proposed federated learning algorithm together with its proof.
Convergence Analysis
Convergence is established through the following theorems.
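Theorem .1. Under Assumptions 1, 2, 3, and 4, and with initial batch size $B = bE$, set
$$\eta_t = \frac{k}{(w + \sigma^2 t)^{1/3}}, \qquad k = \frac{(bN)^{2/3}\sigma^{2/3}}{L}.$$
Also set
$$c = \frac{(8L)^2}{bN} + \frac{\sigma^2}{24LEk^3} = L^2\left(\frac{64}{bN} + \frac{1}{24(bN)^2E}\right)$$
and
$$w = \max\left\{(4LkE)^3 - \sigma^2 t,\; 2\sigma^2,\; \left(\frac{ck}{16LE}\right)^3\right\}.$$
Then the number of local updates $E$ and the batch size $b$ can be set as follows:
$$E = O\left((T/N^2)^{v/3}\right), \qquad b = O\left((T/N^2)^{1/2 - v/2}\right), \tag{6}$$
where $v \in [0,1]$.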
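To make these parameter choices concrete, the following Python sketch evaluates $k$, $c$, $w$, and $\eta_t$ from Theorem .1, together with $E$ and $b$ from Eq. (6). It assumes that the constants hidden in the $O(\cdot)$ of Eq. (6) equal 1, rounds $E$ and $b$ to positive integers, and reads the third entry of the maximum in $w$ as $(ck/(16LE))^3$; these are our assumptions. Because the first entry of $w$ guarantees $w + \sigma^2 t \ge (4LkE)^3$, the step size always satisfies $\eta_t \le 1/(4LE)$.

```python
from dataclasses import dataclass


@dataclass
class Schedule:
    E: int     # local updates, Eq. (6) with hidden constant 1 (assumption)
    b: int     # minibatch size, Eq. (6) with hidden constant 1 (assumption)
    B: int     # initial batch size B = bE
    k: float
    c: float


def make_schedule(T: int, N: int, v: float, L: float, sigma: float) -> Schedule:
    if not 0.0 <= v <= 1.0:
        raise ValueError("v must lie in [0, 1]")
    ratio = T / N**2
    E = max(1, round(ratio ** (v / 3)))
    b = max(1, round(ratio ** (0.5 - v / 2)))
    bN = b * N
    k = bN ** (2 / 3) * sigma ** (2 / 3) / L
    c = L**2 * (64 / bN + 1 / (24 * bN**2 * E))
    return Schedule(E=E, b=b, B=b * E, k=k, c=c)


def w_at(s: Schedule, t: int, L: float, sigma: float) -> float:
    """w = max{(4LkE)^3 - sigma^2 t, 2 sigma^2, (ck/(16LE))^3} (our reading)."""
    return max((4 * L * s.k * s.E) ** 3 - sigma**2 * t,
               2 * sigma**2,
               (s.c * s.k / (16 * L * s.E)) ** 3)


def eta(s: Schedule, t: int, L: float, sigma: float) -> float:
    """Step size eta_t = k / (w + sigma^2 t)^(1/3)."""
    return s.k / (w_at(s, t, L, sigma) + sigma**2 * t) ** (1 / 3)


s = make_schedule(T=10_000, N=10, v=0.5, L=1.0, sigma=1.0)
print(s, [eta(s, t, 1.0, 1.0) for t in (1, 100, 1000)])
```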
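After applying variance reduction with the adaptive learning rate to reduce gradient noise, Theorem 1 gives
$$\mathbb{E}\|\nabla f(\omega)\|^2 = O\left(\frac{f(\omega_1) - f^*}{N^{2v/3}T^{1-v/3}}\right) + \tilde{O}\left(\frac{\sigma^2}{N^{2v/3}T^{1-v/3}}\right) + \tilde{O}\left(\frac{\nabla^2}{N^{2v/3}T^{1-v/3}}\right). \tag{7}$$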
For any $v \in [0,1]$, the sample complexity of the proposed method is $\tilde{O}(\epsilon^{-3/2})$. Hence, each participating client requires at most $\tilde{O}(N^{-1}\epsilon^{-3/2})$ gradient computations. Moreover, the communication complexity is $\tilde{O}(\epsilon^{-1})$.
Before presenting the proof of the above convergence analysis, we explain some of the tradeoffs involved.
The tradeoff between sample size and communication complexity: From the above theorem, the sample and communication complexities are $\tilde{O}(\epsilon^{-3/2})$ and $\tilde{O}(\epsilon^{-1})$, respectively, when $E$ and $b$ are chosen appropriately. In the FL literature, such complexities are stated up to logarithmic factors under either a sample-wise or a batch-wise Lipschitz smoothness assumption. Hence, our FL framework improves on existing sample and communication complexities. Under our assumptions, $\tilde{O}(\epsilon^{-1})$ is the best communication complexity among the compared SOTA methods.
The tradeoff between the batch size and local updates: The number of local updates $E$ and the batch size $b$ are balanced through the parameter $v \in [0,1]$; Eq. (6) expresses both as functions of $v$. As $v$ grows from 0 to 1, the batch size $b$ decreases and the number of local updates $E$ increases. At $v = 1$, $b = O(1)$, whereas $E = O(T^{1/3}/N^{2/3})$. Conversely, at $v = 0$, $b = O(T^{1/2}/N)$, whereas $E = O(1)$. The parameter $v$ thus interpolates between SGD ($v = 1$: $O(1)$ batches and many local updates) and minibatch SGD ($v = 0$: large batches and $O(1)$ local updates), with both local and global updates using momentum-based variance reduction.
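The tradeoff can be tabulated directly from Eq. (6). This sketch sets the constants hidden in $O(\cdot)$ to 1, which is an assumption, and prints $E$ and $b$ for a few values of $v$.

```python
from typing import List, Tuple


def tradeoff(T: float, N: float, vs: List[float]) -> List[Tuple[float, float, float]]:
    """(v, E, b) with E = (T/N^2)^(v/3) and b = (T/N^2)^(1/2 - v/2), constants set to 1."""
    ratio = T / N**2
    return [(v, ratio ** (v / 3), ratio ** (0.5 - v / 2)) for v in vs]


for v, E, b in tradeoff(T=1e6, N=10, vs=[0.0, 0.5, 1.0]):
    print(f"v={v:.1f}  E={E:8.2f}  b={b:8.2f}")
```

For every $v$, $E^3 b^2 = T/N^2$, because $E^3 = (T/N^2)^{v}$ and $b^2 = (T/N^2)^{1-v}$; increasing one therefore forces the other down.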
Before presenting the details of the proof, we state some preliminary lemmas.
Preliminary Lemmas
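Lemma .2. Define the error term as $\epsilon_t = m_t - \frac{1}{N}\sum_{n=1}^{N}\nabla f^{(n)}(\omega_t^n)$. Then the iterates generated by Algorithm 1 satisfy
$$\mathbb{E}\left\langle (1-\beta_t)\epsilon_{t-1},\; \frac{1}{N}\sum_{n=1}^{N}\frac{1}{b}\sum_{x_t^{(n)} \in \mathcal{B}_t^{(n)}}\left[\left(\nabla f^{(n)}(\omega_t^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_t^n)\right) - (1-\beta_t)\left(\nabla f^{(n)}(\omega_{t-1}^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_{t-1}^n)\right)\right]\right\rangle = 0.$$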
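The error term of Lemma .2 evolves as $\epsilon_t = (1-\beta_t)\epsilon_{t-1} + D_t$, where $D_t$ is the second argument of the inner product, provided $m_t$ is the client average of STORM-style estimators (Cutkosky and Orabona 2019). Algorithm 1 is not reproduced in this section, so this form of $m_t$ is our assumption, as are the scalar toy losses $f^{(n)}(\omega) = \frac{a_n}{2}(\omega - c_n)^2$ and the noise model $x(1+\omega)$. The sketch below checks the recursion numerically.

```python
from typing import List, Sequence, Tuple

# Toy client n: f_n(w) = a_n / 2 * (w - c_n)^2 with a scalar parameter w (assumption).
Client = Tuple[float, float]


def mean(v: Sequence[float]) -> float:
    return sum(v) / len(v)


def true_grad(cl: Client, w: float) -> float:
    a, c = cl
    return a * (w - c)


def stoch_grad(cl: Client, w: float, x: float) -> float:
    # Unbiased when E[x] = 0, because the added noise x * (1 + w) has zero mean.
    return true_grad(cl, w) + x * (1.0 + w)


def storm_update(m_prev: Sequence[float], w_new: Sequence[float], w_old: Sequence[float],
                 batches: Sequence[Sequence[float]], clients: Sequence[Client],
                 beta: float) -> List[float]:
    """Per client: m_t = g(w_t; B_t) + (1 - beta) * (m_{t-1} - g(w_{t-1}; B_t))."""
    m_new = []
    for mp, wn, wo, xs, cl in zip(m_prev, w_new, w_old, batches, clients):
        g_new = mean([stoch_grad(cl, wn, x) for x in xs])
        g_old = mean([stoch_grad(cl, wo, x) for x in xs])  # same samples, previous iterate
        m_new.append(g_new + (1.0 - beta) * (mp - g_old))
    return m_new


def error(m: Sequence[float], w: Sequence[float], clients: Sequence[Client]) -> float:
    """eps_t = (1/N) sum_n m_t^n - (1/N) sum_n grad f_n(w_t^n)."""
    return mean(m) - mean([true_grad(cl, wi) for wi, cl in zip(w, clients)])


def noise_term(w_new, w_old, batches, clients, beta) -> float:
    """D_t: the second argument of the inner product in Lemma .2."""
    return mean([
        mean([(stoch_grad(cl, wn, x) - true_grad(cl, wn))
              - (1.0 - beta) * (stoch_grad(cl, wo, x) - true_grad(cl, wo)) for x in xs])
        for wn, wo, xs, cl in zip(w_new, w_old, batches, clients)])


clients = [(1.0, 0.5), (2.0, -1.0)]
m_prev, w_old, w_new = [0.3, -0.2], [0.1, 0.4], [0.05, 0.3]
batches = [[0.2, -0.1, 0.4], [-0.3, 0.1, 0.0]]
beta = 0.25
m_new = storm_update(m_prev, w_new, w_old, batches, clients, beta)
lhs = error(m_new, w_new, clients)
rhs = (1 - beta) * error(m_prev, w_old, clients) + noise_term(w_new, w_old, batches, clients, beta)
print(abs(lhs - rhs))  # differences come only from floating-point round-off
```

With $\beta_t = 1$ the estimator reduces to a plain minibatch gradient, and with $\beta_t < 1$ it reuses each fresh sample at the previous iterate to correct $m_{t-1}$.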
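Proof. Fix $t$ and consider the gradient error term $\epsilon_{t-1}$. For every $n \in [N]$, the only randomness on the left-hand side of the lemma statement, given the history, is with respect to the samples $x_t^{(n)}$. Hence
$$\mathbb{E}\left\langle (1-\beta_t)\epsilon_{t-1},\; \frac{1}{N}\sum_{n=1}^{N}\frac{1}{b}\sum_{x_t^{(n)} \in \mathcal{B}_t^{(n)}}\left[\left(\nabla f^{(n)}(\omega_t^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_t^n)\right) - (1-\beta_t)\left(\nabla f^{(n)}(\omega_{t-1}^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_{t-1}^n)\right)\right]\right\rangle$$
$$= \mathbb{E}\left[\mathbb{E}\left[\left\langle (1-\beta_t)\epsilon_{t-1},\; \frac{1}{N}\sum_{n=1}^{N}\frac{1}{b}\sum_{x_t^{(n)} \in \mathcal{B}_t^{(n)}}\left[\left(\nabla f^{(n)}(\omega_t^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_t^n)\right) - (1-\beta_t)\left(\nabla f^{(n)}(\omega_{t-1}^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_{t-1}^n)\right)\right]\right\rangle \,\middle|\, \mathcal{F}_t\right]\right],$$
where $\mathcal{F}_t = \sigma(\omega_1^n, \omega_2^n, \ldots, \omega_t^n)$ for all $n \in [N]$. If $x_t^{(n)}$ is chosen uniformly at random at each $k \in K$, so that the stochastic gradients are unbiased, $\mathbb{E}[\nabla f^{(n)}(\omega^{(n)}; x_t^{(n)})] = \nabla f^{(n)}(\omega^{(n)})$, then, since $\omega_t^n$ and $\omega_{t-1}^n$ are $\mathcal{F}_t$-measurable, we have
$$\mathbb{E}\left[\frac{1}{b}\sum_{x_t^{(n)} \in \mathcal{B}_t^{(n)}}\left[\left(\nabla f^{(n)}(\omega_t^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_t^n)\right) - (1-\beta_t)\left(\nabla f^{(n)}(\omega_{t-1}^n; x_t^{(n)}) - \nabla f^{(n)}(\omega_{t-1}^n)\right)\right] \,\middle|\, \mathcal{F}_t\right] = 0$$
for all $n \in [N]$. Since $(1-\beta_t)\epsilon_{t-1}$ is fixed given $\mathcal{F}_t$, the inner product has zero conditional expectation, and therefore zero expectation. This completes the proof.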
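Lemma .2 can also be checked by Monte Carlo on a toy problem. With the iterates held fixed (that is, conditioning on $\mathcal{F}_t$), fresh zero-mean samples make the inner product average to zero. The scalar quadratic losses (so the inner product is an ordinary product), the noise model $x(1+\omega)$ with $x \sim \mathcal{N}(0,1)$, and all numeric values are illustrative assumptions, not the paper's setup.

```python
import random
from statistics import fmean
from typing import Sequence, Tuple

Client = Tuple[float, float]  # (a_n, c_n) for the toy loss f_n(w) = a_n/2 * (w - c_n)^2


def true_grad(cl: Client, w: float) -> float:
    a, c = cl
    return a * (w - c)


def stoch_grad(cl: Client, w: float, x: float) -> float:
    return true_grad(cl, w) + x * (1.0 + w)  # zero-mean noise when E[x] = 0


def noise_term(clients: Sequence[Client], w_new: Sequence[float], w_old: Sequence[float],
               batches: Sequence[Sequence[float]], beta: float) -> float:
    """D_t: client average of the batch-averaged corrected gradient noise."""
    return fmean([
        fmean([(stoch_grad(cl, wn, x) - true_grad(cl, wn))
               - (1.0 - beta) * (stoch_grad(cl, wo, x) - true_grad(cl, wo)) for x in xs])
        for cl, wn, wo, xs in zip(clients, w_new, w_old, batches)])


def mc_inner_product(eps_prev: float, clients: Sequence[Client], w_new: Sequence[float],
                     w_old: Sequence[float], beta: float, b: int, trials: int,
                     seed: int = 0) -> float:
    """Average of (1 - beta) * eps_prev * D_t over fresh batches, iterates held fixed."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(trials):
        batches = [[rng.gauss(0.0, 1.0) for _ in range(b)] for _ in clients]
        total += (1.0 - beta) * eps_prev * noise_term(clients, w_new, w_old, batches, beta)
    return total / trials


est = mc_inner_product(eps_prev=0.3, clients=[(1.0, 0.5), (2.0, -1.0)],
                       w_new=[0.05, 0.3], w_old=[0.1, 0.4], beta=0.25, b=10, trials=20_000)
print(f"Monte Carlo estimate: {est:.2e}")
```

The estimate shrinks toward zero as `trials` grows, which is what the lemma predicts for any fixed $\epsilon_{t-1}$.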
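Theorem .3. Choose the parameters as
• $k = \frac{(bN)^{2/3}\sigma^{2/3}}{L}$
• $c = \frac{(8L)^2}{bN} + \frac{\sigma^2}{24LEk^3} = L^2\left(\frac{64}{bN} + \frac{1}{24(bN)^2E}\right)$
• $w = \max\left\{(4LkE)^3 - \sigma^2 t,\; 2\sigma^2,\; \left(\frac{ck}{16LE}\right)^3\right\}$
and, for any $v \in [0,1]$, set at each client the number of local updates $E = O\left((T/N^2)^{v/3}\right)$, the batch size $b = O\left((T/N^2)^{1/2 - v/2}\right)$, and the initial batch size $B = bE$. Then the proposed FL algorithm satisfies the following:
(i) $$\mathbb{E}\|\nabla f(\omega)\|^2 = O\left(\frac{f(\omega_1) - f^*}{N^{2v/3}T^{1-v/3}}\right) + \tilde{O}\left(\frac{\sigma^2}{N^{2v/3}T^{1-v/3}}\right) + \tilde{O}\left(\frac{\nabla^2}{N^{2v/3}T^{1-v/3}}\right).$$
(ii) Sample complexity: To reach an $\epsilon$-stationary point, the proposed FL algorithm requires at most $O(\epsilon^{-3/2})$ gradient computations. Hence, each client requires at most $O(N^{-1}\epsilon^{-3/2})$ gradient computations.
(iii) Communication complexity: To reach an $\epsilon$-stationary point, the proposed FL algorithm requires at most $O(\epsilon^{-1})$ communication rounds.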
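Proof. • Proof of Statement (i): Substituting the values of $B$, $E$, and $b$ into the given expression and using $B = bE$, we get
$$\mathbb{E}\|\nabla f(\omega)\|^2 \le \left[\frac{32LE}{T} + \frac{2L}{(bN)^{2/3}T^{2/3}}\right](f(\omega_1) - f^*) + \left[\frac{8E}{T} + \frac{1}{2(bN)^{2/3}T^{2/3}}\right]\sigma^2 + \left[\frac{256^2E}{T} + \frac{64^2}{(bN)^{2/3}T^{2/3}}\right]\sigma^2\log(T+1) + \left[\frac{256^2E}{T} + \frac{64^2}{(bN)^{2/3}T^{2/3}}\right]\nabla^2\,\frac{E-1}{E}\log(T+1).$$
Since the number of local updates is $E = O\left((T/N^2)^{v/3}\right)$ and the batch size is $b = O\left((T/N^2)^{1/2 - v/2}\right)$, we obtain the expression in (i).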
• Sample Complexity: From the above expression, the total number of iterations required to reach an $\epsilon$-stationary point satisfies
$$O\left(\frac{1}{N^{2v/3}T^{1-v/3}}\right) = \epsilon \implies T = O\left(\frac{1}{N^{2v/(3-v)}\epsilon^{3/(3-v)}}\right).$$
Each client computes $2b$ stochastic gradients in each iteration, and hence $2bT$ stochastic gradients in total. Using $b = O\left((T/N^2)^{1/2 - v/2}\right)$, the total number of gradient computations at each client is
$$bT = O\left(\frac{T^{3/2 - v/2}}{N^{1-v}}\right) = O\left(\frac{1}{N\epsilon^{3/2}}\right).$$
Summing over the $N$ clients, the sample complexity is $O(\epsilon^{-3/2})$.
• Communication Complexity: To reach an $\epsilon$-stationary point, the total number of communication rounds is $T/E$. With $E = O\left((T/N^2)^{v/3}\right)$ and $T = O\left(\frac{1}{N^{2v/(3-v)}\epsilon^{3/(3-v)}}\right)$, the communication complexity is
$$T/E = O\left(T^{1-v/3}N^{2v/3}\right) = O\left(\frac{1}{\epsilon}\right).$$
This completes the proof of the theorem.
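The exponent bookkeeping in this proof can be verified exactly. Writing every quantity as $N^{p}\epsilon^{q}$ and tracking $(p, q)$ with rational arithmetic, the sketch below substitutes $T = N^{-2v/(3-v)}\epsilon^{-3/(3-v)}$ and confirms, for several values of $v$, that the rate $1/(N^{2v/3}T^{1-v/3})$ equals $\epsilon$, that $bT$ scales as $N^{-1}\epsilon^{-3/2}$, and that $T/E$ scales as $\epsilon^{-1}$. Constants hidden by $O(\cdot)$ are ignored, the number of rounds is taken to be $T/E$, and the helper names are ours.

```python
from fractions import Fraction as F
from typing import Tuple

Exp = Tuple[F, F]  # (exponent of N, exponent of eps)


def T_exp(v: F) -> Exp:
    """T = N^(-2v/(3-v)) * eps^(-3/(3-v)), from setting the rate bound equal to eps."""
    return (F(-2) * v / (3 - v), F(-3) / (3 - v))


def power(e: Exp, p: F) -> Exp:
    return (e[0] * p, e[1] * p)


def times(*es: Exp) -> Exp:
    return (sum((e[0] for e in es), F(0)), sum((e[1] for e in es), F(0)))


def N_pow(p: F) -> Exp:
    return (p, F(0))


def check(v: F) -> Tuple[Exp, Exp, Exp]:
    T = T_exp(v)
    rate = times(N_pow(-F(2) * v / 3), power(T, -(1 - v / 3)))  # 1/(N^{2v/3} T^{1-v/3})
    per_client = times(power(T, (3 - v) / 2), N_pow(-(1 - v)))   # bT = T^{(3-v)/2} / N^{1-v}
    rounds = times(power(T, 1 - v / 3), N_pow(F(2) * v / 3))      # T/E = T^{1-v/3} N^{2v/3}
    return rate, per_client, rounds


for v in (F(0), F(1, 3), F(1, 2), F(1)):
    print(v, check(v))
```

Exact fractions avoid the rounding noise that floating-point exponents would introduce, so each identity is checked as an equality rather than within a tolerance.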