|
- import utils
- from typing import Dict, Optional, Tuple
- import argparse
- import AISyncore as syncore
-
# Seed first, before any model or data is created below. setup_seed is
# presumably the project's RNG-seeding helper -- confirm in utils.
utils.setup_seed(1)

# Module-wide alias for the project configuration object; this module reads
# BATCH_SIZE, DEVICE, TOTAL_CLIENTS, TASK_SERVER_IP, TASK_SERVER_PORT and
# EPOCHS from it.
_config = utils._config
-
-
-
def get_eval_fn(model):
    """Build the server-side evaluation callback for the strategy.

    The test data loader is created once, here, so the returned callback
    does not rebuild it on every invocation.

    Args:
        model: Model object that the callback overwrites with the supplied
            weights before evaluating it (it is mutated in place).

    Returns:
        ``evaluate(weights) -> (loss, {"accuracy": accuracy})``.
    """
    # Partition id -1 is what the server passes for its own evaluation data;
    # presumably it selects a centralized split -- confirm in utils.
    loaders = utils.load_partition_dataloader(-1, batch_size=_config.BATCH_SIZE)
    _train_loader, test_loader, _extra = loaders

    def evaluate(
        weights: syncore.common.Weights,
    ) -> Optional[Tuple[float, Dict[str, syncore.common.Scalar]]]:
        """Load ``weights`` into the model and score it on the test loader.

        Presumably invoked by the strategy after each round -- confirm
        against the strategy's documentation.
        """
        utils.set_model_params(model, weights)
        eval_settings = utils.fit_server_config(None, 'server')
        loss, accuracy = utils.test(model, test_loader, _config.DEVICE, eval_settings)
        metrics = {"accuracy": accuracy}
        return loss, metrics

    return evaluate
-
-
def main():
    """Configure a FedAvg strategy and run the federated-learning server.

    The model loaded here serves two purposes:

    1. server-side parameter initialization -- its current parameters are
       handed to the strategy as ``initial_parameters``;
    2. server-side parameter evaluation -- it is the model that the
       callback built by ``get_eval_fn`` loads weights into.

    Command-line options:
        --pretrained: boolean flag. It is accepted and parsed, but this
            function does not currently act on it.
    """
    parser = argparse.ArgumentParser(description="Flower")
    parser.add_argument(
        "--pretrained",
        action="store_true",
    )
    # Parsing still validates argv and handles --help; the result is not
    # bound because nothing below reads it.
    # NOTE(review): `--pretrained` has no effect -- utils.load_model() is
    # called without it. Pass it through once load_model's signature is
    # confirmed, or remove the flag.
    parser.parse_args()

    model = utils.load_model()
    # NOTE(review): assumes get_model_params() returns the type the strategy
    # expects for `initial_parameters` (Flower >= 0.17 expects `Parameters`,
    # not raw weights) -- TODO confirm for AISyncore.
    model_weights = utils.get_model_params(model)

    # Assuming Flower-style FedAvg semantics: every round samples all
    # available clients for training and evaluation, and waits until
    # TOTAL_CLIENTS clients are connected.
    strategy = syncore.server.strategy.FedAvg(
        fraction_fit=1,
        fraction_eval=1,
        min_fit_clients=_config.TOTAL_CLIENTS,
        min_eval_clients=_config.TOTAL_CLIENTS,
        min_available_clients=_config.TOTAL_CLIENTS,
        eval_fn=get_eval_fn(model),
        on_fit_config_fn=utils.fit_server_config,
        initial_parameters=model_weights,
    )

    # Format instead of concatenating: `+` raises TypeError when the
    # configured port is an int, whereas the f-string accepts str or int
    # (and yields the identical "ip:port" string for str values).
    server_address = f"{_config.TASK_SERVER_IP}:{_config.TASK_SERVER_PORT}"

    # Run _config.EPOCHS federated rounds.
    syncore.server.run_server(
        server_address,
        config={"num_rounds": _config.EPOCHS},
        strategy=strategy,
    )


if __name__ == "__main__":
    main()
|