Introduction
- We propose an efficient federated training framework based on structured model pruning, which greatly reduces client computation and memory demands by dynamically selecting an optimal sub-model of the current global model for delivery to clients.
- We develop a new parameter aggregation and update scheme that gives every global model parameter an opportunity to be trained and preserves the complete model structure through model reconstruction and parameter reuse, reducing the error introduced by pruning (a simplified sketch of one training round is given after this list).
- We conduct extensive experiments on different datasets and data distributions to verify the effectiveness of the proposed framework, which reduces both upstream and downstream communication while maintaining global model accuracy and lowering client computation costs.
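To make the round-level mechanics concrete, the following is a minimal sketch of one server round, assuming a dict-of-tensors model representation, a fixed boolean filter mask per layer, and plain averaging of client updates; `select_submodel` and `aggregate` are illustrative helper names, not the paper's implementation.

```python
# Minimal sketch of one communication round: the server sends only the
# retained filters, and after aggregation the pruned positions keep their
# previous values (parameter reuse), reconstructing the full model.
import copy
import torch

def select_submodel(global_weights, masks):
    """Extract the retained filters of each layer (server-to-client payload)."""
    return {name: w[masks[name]] for name, w in global_weights.items()}

def aggregate(global_weights, client_updates, masks):
    """Average client sub-model updates and write them back into the full model."""
    new_global = copy.deepcopy(global_weights)
    for name, w in new_global.items():
        stacked = torch.stack([u[name] for u in client_updates], dim=0)
        w[masks[name]] = stacked.mean(dim=0)  # update the retained filters
        # entries where masks[name] is False are left untouched and reused
    return new_global

# Toy usage with random tensors standing in for one conv layer.
global_w = {"conv1.weight": torch.randn(8, 3, 3, 3)}
masks = {"conv1.weight": torch.tensor([True] * 4 + [False] * 4)}
sub_w = select_submodel(global_w, masks)  # clients receive 4 of 8 filters
client_updates = [
    {"conv1.weight": sub_w["conv1.weight"] + 0.01 * torch.randn_like(sub_w["conv1.weight"])}
    for _ in range(3)  # pretend three clients performed local training
]
global_w = aggregate(global_w, client_updates, masks)
```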
Related work
Edge inference/training based on model pruning
Efficient federated learning
Approach
Where to prune?
Model pruning and mask
Model aggregation and updating
Convergence analysis
Experiments
Performance indicators
Models and datasets
Different pruning rates
Different parameter selection criteria
- Random. Parameters are discarded at random.
- L1 [15]. The sum of the absolute values of a filter's weights is used as the criterion: \(\mathcal {L}\left( W_{i}^{l}\right) =\sum \vert \mathcal {W}(i,:,:,:)\vert \).
- L2 [15]. The same criterion computed with the \(\ell _{2}\) norm: \(\mathcal {L}\left( W_{i}^{l}\right) =\sum \Vert \mathcal {W}(i,:,:,:)\Vert _{2}\).
- BN mask [45]. The scaling factor \(\gamma \) of the BN transform \(\hat{z}=\frac{z_\textrm{in}-\mu _{\mathcal {B}}}{\sqrt{\sigma _{\mathcal {B}}^{2}+\epsilon }},\ z_\textrm{out}=\gamma \hat{z}+\beta \) is used as the corresponding filter's importance score, where \(z_\textrm{in}\) and \(z_\textrm{out}\) are the input and output of the BN layer, and \(\mu _{\mathcal {B}}\) and \(\sigma _{\mathcal {B}}\) are the mean and standard deviation of the input activations over the current minibatch \(\mathcal {B}\).
- Similarity. The pairwise distance between filters is computed and one filter of each highly similar pair is removed: \(D^{(l)}=\textrm{dist}\left( W_{j}^{l}, W_{k}^{l}\right), \ 0 \le j \le N_{l},\ j \le k \le N_{l}\). A short code sketch of these criteria follows this list.
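The following is a minimal sketch of the five criteria above for a single convolutional layer with weight tensor of shape [out_channels, in_channels, k, k]; the function names and the top-k mask construction at the end are illustrative assumptions, not the authors' code.

```python
# Illustrative per-filter importance scores for the criteria compared above.
import torch

def random_score(W):
    # Random: filters are ranked by a uniformly random score.
    return torch.rand(W.shape[0])

def l1_score(W):
    # L1 [15]: sum of absolute values of each filter W(i, :, :, :).
    return W.abs().sum(dim=(1, 2, 3))

def l2_score(W):
    # L2 [15]: Euclidean norm of each filter.
    return W.flatten(1).norm(p=2, dim=1)

def bn_score(gamma):
    # BN mask [45]: the BN scaling factor gamma of each channel.
    return gamma.abs()

def similarity_matrix(W):
    # Similarity: pairwise distances D^(l); one filter of a pair with a
    # small distance (i.e., two highly similar filters) can be removed.
    flat = W.flatten(1)
    return torch.cdist(flat, flat, p=2)

# Example: build a keep-mask that retains the top-k filters under the L1 criterion.
W = torch.randn(16, 8, 3, 3)
k = int((1 - 0.5) * W.shape[0])                # pruning rate 0.5
keep = torch.zeros(W.shape[0], dtype=torch.bool)
keep[l1_score(W).topk(k).indices] = True       # True marks retained filters
```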
| FC/MNIST (Acc = 0.90) | | | VGG/CIFAR-10 (Acc = 0.70) | | |
|---|---|---|---|---|---|
| Pruning rate | Communication | Speedup | Pruning rate | Communication | Speedup |
| Baseline | 292.17 M | 1\(\times \) | Baseline | 1562.6 M | 1\(\times \) |
| 0.41 | 17.09 M | 17.10\(\times \) | 0.35 | 623.2 M | 2.51\(\times \) |
| 0.52 | 18.41 M | 15.87\(\times \) | 0.50 | 843.64 M | 1.85\(\times \) |
| 0.63 | 16.76 M | 17.43\(\times \) | 0.63 | 1369.45 M | 1.14\(\times \) |
| 0.72 | 31.48 M | 9.28\(\times \) | | | |
| 0.81 | 170.60 M | 1.71\(\times \) | | | |
The efficiency of computation and communication
| Method | VGG/CIFAR-10 (Acc = 70%, Pruned = 50%) | | FC/MNIST (Acc = 90%, Pruned = 70%) | |
|---|---|---|---|---|
| | IID | Non-IID | IID | Non-IID |
| FedAvg [6] | 425 | 438 | 108 | 144 |
| Fed Dropout [46] | 569 | 432 | 280 | 267 |
| PruneFL [11] | 378 | 396 | 176 | 127 |
| Fed Pruning [12] | 386 | 419 | 138 | 156 |
| AdaptCL [10] | 325 | 317 | 84 | 76 |
| Ours | 242 | 210 | 46 | 51 |