-
Notifications
You must be signed in to change notification settings - Fork 0
/
simulation.py
117 lines (103 loc) · 4.35 KB
/
simulation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import flwr as fl
from omegaconf import DictConfig
from typing import Dict
import hydra
from flwr.common.typing import Scalar
from src.simulation.client import get_client_fn
from src.strategies import FedMAP
from datasets import synthetic_data_gen
from src.simulation import individual
def fit_config(
    server_round: int,
    epochs: int,
    fit_strategy: str,
    weights_contribution: str = 'sample_size',
    is_multi_class: bool = False,
    *,
    batch_size: int = 64,
) -> Dict[str, Scalar]:
    """Build the per-round ``fit`` configuration sent to every client.

    Args:
        server_round: Current federated round number.
        epochs: Number of local training epochs to run in this round.
        fit_strategy: Name of the strategy/environment in use
            (e.g. ``"fedavg"``, ``"fedmap"``).
        weights_contribution: Aggregation-weighting mode forwarded to the
            clients, e.g. ``"sample_size"`` (default) or ``"contribution"``.
        is_multi_class: Whether the task is treated as multi-class.
        batch_size: Local mini-batch size. Keyword-only; defaults to 64,
            the value this function always used before it was configurable.

    Returns:
        Flat dict of config values suitable for Flower's ``on_fit_config_fn``.
    """
    return {
        "epochs": epochs,
        "batch_size": batch_size,
        "server_round": server_round,
        "fit_strategy": fit_strategy,
        "weights_contribution": weights_contribution,
        "is_multi_class": is_multi_class,
    }
def evaluate_config(fit_strategy: str, is_multi_class: bool = False) -> Dict[str, Scalar]:
    """Build the ``evaluate`` configuration sent to every client.

    The result does not depend on the round number: every round receives the
    same strategy name and multi-class flag. Callers that must match Flower's
    ``on_evaluate_config_fn(server_round)`` signature wrap this in a lambda.

    Args:
        fit_strategy: Name of the strategy/environment in use
            (e.g. ``"fedavg"``, ``"fedprox"``).
        is_multi_class: Whether the task is treated as multi-class.

    Returns:
        Dict with the keys ``"fit_strategy"`` and ``"is_multi_class"``.
    """
    return {
        "fit_strategy": fit_strategy,
        "is_multi_class": is_multi_class,
    }
def _remove_stale_models(num_clients: int) -> None:
    """Delete ``./models/model_<i>.pth`` for each client index, if present."""
    for client_idx in range(num_clients):
        try:
            os.remove(f"./models/model_{client_idx}.pth")
        except FileNotFoundError:
            # A missing checkpoint (e.g. on a first run) is fine.
            pass


def _build_strategies(cfg: DictConfig, is_multi_class: bool) -> Dict[str, "fl.server.strategy.Strategy"]:
    """Create every federated strategy that ``cfg.envs.name`` may select.

    All strategies sample every client for both fit and evaluate
    (fraction 1, minimum ``cfg.number_clients``). "fedbn" uses a plain
    server-side FedAvg; presumably the batch-norm-specific behaviour lives in
    the client code (NOTE(review): confirm in ``get_client_fn``).
    """
    num_clients = cfg.number_clients

    def make_fit_config_fn(weights_contribution: str):
        # Pass weights_contribution by keyword: it is fit_config's 4th
        # positional parameter (a mode string), so a positional bool in that
        # slot would silently be forwarded to the clients as the mode.
        return lambda server_round: fit_config(
            server_round,
            cfg.epochs,
            cfg.envs.name,
            weights_contribution=weights_contribution,
            is_multi_class=is_multi_class,
        )

    common_kwargs = dict(
        fraction_fit=1,
        fraction_evaluate=1,
        min_fit_clients=num_clients,
        min_evaluate_clients=num_clients,
        min_available_clients=num_clients,
        on_evaluate_config_fn=lambda server_round: evaluate_config(cfg.envs.name, is_multi_class),
    )

    # Constructed in the same order as before: fedavg, fedbn, fedprox, fedmap.
    return {
        "fedavg": fl.server.strategy.FedAvg(
            on_fit_config_fn=make_fit_config_fn('sample_size'),
            **common_kwargs,
        ),
        "fedbn": fl.server.strategy.FedAvg(
            on_fit_config_fn=make_fit_config_fn('sample_size'),
            **common_kwargs,
        ),
        "fedprox": fl.server.strategy.FedProx(
            on_fit_config_fn=make_fit_config_fn('sample_size'),
            proximal_mu=5.0,
            **common_kwargs,
        ),
        "fedmap": FedMAP(
            on_fit_config_fn=make_fit_config_fn('contribution'),
            **common_kwargs,
        ),
    }


@hydra.main(config_path='./config', config_name='config', version_base=None)
def main(cfg: DictConfig):
    """Hydra entry point: reset per-client models, prepare data, run the experiment.

    ``cfg.envs.name`` selects either ``"individual"`` (per-client training via
    ``individual.train_val``) or one of the federated strategies built by
    ``_build_strategies``. A federated strategy is then run with
    ``fl.simulation.start_simulation``.

    Raises:
        KeyError: if ``cfg.envs.name`` is neither ``"individual"`` nor a known
            strategy name.
    """
    _remove_stale_models(cfg.number_clients)

    # The synthetic dataset is generated on the fly and treated as
    # non-multi-class; every other dataset is treated as multi-class.
    is_multi_class = cfg.datasets.name != 'synthetic'
    if not is_multi_class:
        synthetic_data_gen.generate_data_from_config(cfg)

    strategies = _build_strategies(cfg, is_multi_class)

    # Resources to be assigned to each virtual client.
    client_resources = {
        "num_cpus": cfg.num_cpus,
        "num_gpus": cfg.num_gpus,
    }

    env_name = cfg.envs.name
    if env_name == "individual":
        individual.train_val(
            (cfg.datasets.name == 'synthetic'),
            cfg.envs.epochs,
            num_classes=cfg.datasets.num_classes,
            num_clients=cfg.number_clients,
        )
        return

    if env_name not in strategies:
        raise KeyError(
            f"Unknown envs.name {env_name!r}; expected 'individual' or one of "
            f"{sorted(strategies)}"
        )

    # Start simulation.
    fl.simulation.start_simulation(
        client_fn=get_client_fn(cfg),
        num_clients=cfg.number_clients,
        client_resources=client_resources,
        config=fl.server.ServerConfig(num_rounds=cfg.num_rounds),
        strategy=strategies[env_name],
    )


if __name__ == "__main__":
    main()