|
- # Copyright 2020 Adap GmbH. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Minimal example on how to start a simple Flower server."""
-
-
- import argparse
- from typing import Callable, Dict, Optional, Tuple
-
- #import torch
- #import torchvision
-
- import flwr as fl
-
- from . import DEFAULT_SERVER_ADDRESS, cifar
- import os
- import argparse
- from mindspore import context
- from mindspore.common import set_seed
- from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
- from mindspore.train.model import Model
- from mindspore.train.serialization import load_checkpoint, load_param_into_net
- from mindspore import save_checkpoint
- from src.CrossEntropySmooth import CrossEntropySmooth
- from src.resnet import resnet50 as resnet
- from src.config import config1 as config
- from src.dataset import create_dataset1 as create_dataset
-
- # pylint: disable=no-member
- DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
- # pylint: enable=no-member
-
-
- def main() -> None:
- """Start server and train five rounds."""
- parser = argparse.ArgumentParser(description="Flower")
- parser.add_argument(
- "--server_address",
- type=str,
- default=DEFAULT_SERVER_ADDRESS,
- help=f"gRPC server address (default: {DEFAULT_SERVER_ADDRESS})",
- )
- parser.add_argument(
- "--rounds",
- type=int,
- default=1,
- help="Number of rounds of federated learning (default: 1)",
- )
- parser.add_argument(
- "--sample_fraction",
- type=float,
- default=1.0,
- help="Fraction of available clients used for fit/evaluate (default: 1.0)",
- )
- parser.add_argument(
- "--min_sample_size",
- type=int,
- default=2,
- help="Minimum number of clients used for fit/evaluate (default: 2)",
- )
- parser.add_argument(
- "--min_num_clients",
- type=int,
- default=2,
- help="Minimum number of available clients required for sampling (default: 2)",
- )
- parser.add_argument(
- "--log_host",
- type=str,
- help="Logserver address (no default)",
- )
- parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
- parser.add_argument('--dataset', type=str, default="/root/jointcloud/data/eval",
- help='Dataset, either cifar10 or imagenet2012')
- parser.add_argument('--val_dir', type=str, default="/root/jointcloud/data/eval",
- help='Dataset, either cifar10 or imagenet2012')
- parser.add_argument('--checkpoint_path', type=str, default='/root/jointcloud/models/avg',
- help='Checkpoint file path')
- parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path')
- parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
- parser.add_argument('--model_path', type=str, default='models/avg', help='model path')
- parser.add_argument('--uuid', type=str, default='test', help='uuid.')
- args = parser.parse_args()
- target = args.device_target
-
- # Load evaluation data
- #_, testset = cifar.load_data()
- testset = create_dataset(dataset_path=args.dataset_path, do_train=False, batch_size=config.batch_size,
- target=target)
-
-
- # Create strategy
- strategy = fl.server.strategy.FedAvg(
- fraction_fit=args.sample_fraction,
- min_fit_clients=args.min_sample_size,
- min_available_clients=args.min_num_clients,
- eval_fn=get_eval_fn(testset),
- on_fit_config_fn=fit_config,
- )
-
- # Configure logger and start server
- fl.common.logger.configure("server", host=args.log_host)
- fl.server.start_server(
- args.server_address,
- config={"num_rounds": args.rounds},
- strategy=strategy,
- )
-
-
- def fit_config(rnd: int) -> Dict[str, fl.common.Scalar]:
- """Return a configuration with static batch size and (local) epochs."""
- config: Dict[str, fl.common.Scalar] = {
- "epoch_global": str(rnd),
- "epochs": str(1),
- "batch_size": str(32),
- }
- return config
-
-
- def get_eval_fn(
- testset,
- ) -> Callable[[fl.common.Weights], Optional[Tuple[float, float]]]:
- """Return an evaluation function for centralized evaluation."""
-
- def evaluate(weights: fl.common.Weights) -> Optional[Tuple[float, float]]:
- """Use the entire CIFAR-10 test set for evaluation."""
- net = resnet(class_num=config.class_num)
- net.set_train(False)
- loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
- model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
- #model = cifar.load_model()
- model.set_weights(weights)
- model.to(DEVICE)
- # dataset = create_dataset(dataset_path=config.dataset_path, do_train=False, batch_size=config.batch_size,
- # target=target)
- # step_size = dataset.get_dataset_size()
- #testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
- accu = model.eval(testset, dataset_sink_mode=False)
- return accu
-
- return evaluate
-
-
- if __name__ == "__main__":
- main()
|