# Source code for fedscale.cloud.execution.executor

# -*- coding: utf-8 -*-
import collections
import gc
import pickle
import random
import time
from argparse import Namespace

import numpy as np
import torch

import fedscale.cloud.channels.job_api_pb2 as job_api_pb2
import fedscale.cloud.logger.executor_logging as logger
from fedscale.cloud.channels.channel_context import ClientConnections
from fedscale.cloud.execution.tensorflow_client import TensorflowClient
from fedscale.cloud.execution.torch_client import TorchClient
from fedscale.cloud.execution.data_processor import collate, voice_collate_fn
from fedscale.cloud.execution.rl_client import RLClient
from fedscale.cloud.fllibs import *
from fedscale.dataloaders.divide_data import DataPartitioner, select_dataset


class Executor(object):
    """Abstract class for FedScale executor.

    Args:
        args (dictionary): Variable arguments for fedscale runtime config.
            Defaults to the setup in arg_parser.py.
    """

    def __init__(self, args):
        # initiate the executor log path, and executor ips
        logger.initiate_client_setting()

        self.model_adapter = self.get_client_trainer(args).get_model_adapter(init_model())

        self.args = args
        self.num_executors = args.num_executors
        # ======== env information ========
        self.this_rank = args.this_rank
        self.executor_id = str(self.this_rank)

        # ======== model and data ========
        # `testing_sets` is the attribute run() assigns and testing_handler()
        # reads; `test_dataset` is kept only for backward compatibility.
        self.training_sets = self.testing_sets = None
        self.test_dataset = None

        # ======== channels ========
        self.aggregator_communicator = ClientConnections(
            args.ps_ip, args.ps_port)

        # ======== runtime information ========
        self.collate_fn = None
        self.round = 0
        self.start_run_time = time.time()
        self.received_stop_request = False
        # Appended to from the main loop and from gRPC future callbacks.
        self.event_queue = collections.deque()

        super(Executor, self).__init__()

    def setup_env(self):
        """Set up experiments environment."""
        logging.info(f"(EXECUTOR:{self.this_rank}) is setting up environ ...")
        self.setup_seed(seed=1)

    def setup_communication(self):
        """Set up grpc connection (control channel, then data channel)."""
        self.init_control_communication()
        self.init_data_communication()

    def setup_seed(self, seed=1):
        """Set random seed for reproducibility.

        Seeds torch (CPU and all CUDA devices), numpy and the stdlib `random`
        module, and forces deterministic cuDNN kernels.

        Args:
            seed (int): random seed
        """
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

    def init_control_communication(self):
        """Create communication channel between coordinator and executor.

        This channel serves control messages.
        """
        self.aggregator_communicator.connect_to_server()

    def init_data_communication(self):
        """In charge of jumbo data traffics (e.g., fetch training result).

        Currently a no-op hook.
        """
        pass

    def init_data(self):
        """Return the training and testing dataset.

        For the "rl" task the raw datasets are returned unpartitioned.

        Returns:
            Tuple of DataPartitioner class: The partitioned dataset class for
            training and testing.
        """
        train_dataset, test_dataset = init_dataset()
        if self.args.task == "rl":
            return train_dataset, test_dataset
        if self.args.task == 'nlp':
            self.collate_fn = collate
        elif self.args.task == 'voice':
            self.collate_fn = voice_collate_fn
        # load data partitioner (entire_train_data)
        logging.info("Data partitioner starts ...")

        training_sets = DataPartitioner(
            data=train_dataset, args=self.args, numOfClass=self.args.num_class)
        training_sets.partition_data_helper(
            num_clients=self.args.num_participants, data_map_file=self.args.data_map_file)

        # Test data is split across executors, not across clients.
        testing_sets = DataPartitioner(
            data=test_dataset, args=self.args, numOfClass=self.args.num_class, isTest=True)
        testing_sets.partition_data_helper(num_clients=self.num_executors)

        logging.info("Data partitioner completes ...")

        return training_sets, testing_sets

    def run(self):
        """Start running the executor.

        Sets up the execution and communication environment, then monitors
        the grpc messages until a stop request arrives.
        """
        self.setup_env()
        self.training_sets, self.testing_sets = self.init_data()
        self.setup_communication()
        self.event_monitor()

    def dispatch_worker_events(self, request):
        """Add new events to worker queues.

        Args:
            request (string): Add grpc request from server (e.g. MODEL_TEST,
                MODEL_TRAIN) to event_queue.
        """
        self.event_queue.append(request)

    def deserialize_response(self, responses):
        """Deserialize the response from server.

        NOTE(security): pickle.loads can execute arbitrary code when given a
        crafted payload. This is only safe if the aggregator is trusted.

        Args:
            responses (byte stream): Serialized response from server.

        Returns:
            ServerResponse defined at job_api.proto: The deserialized response
            object from server.
        """
        return pickle.loads(responses)

    def serialize_response(self, responses):
        """Serialize the response to send to server upon assigned job completion.

        Args:
            responses (string, bool, or bytes): TorchClient responses after
                job completion.

        Returns:
            bytes stream: The serialized response object to server.
        """
        return pickle.dumps(responses)

    def UpdateModel(self, model_weights):
        """Receive the broadcasted global model for current round.

        Increments the local round counter and loads the weights into the
        model adapter.

        Args:
            model_weights: The broadcasted global model weights, in whatever
                format the model adapter's set_weights accepts.
        """
        self.round += 1
        self.model_adapter.set_weights(model_weights)

    def Train(self, config):
        """Load train config and data to start training on that client.

        Args:
            config (dictionary): The client training config. It must contain
                'client_id', 'task_config' and a non-empty 'model'.

        Returns:
            tuple (int, dictionary): The client id and train result.

        Raises:
            ValueError: If 'model' is missing or empty in config.
        """
        client_id, train_config = config['client_id'], config['task_config']

        # Raising a bare string is itself a TypeError in Python 3; use a real exception.
        if not config.get('model'):
            raise ValueError(
                "The 'model' object must be a non-null value in the training config.")
        client_conf = self.override_conf(train_config)
        train_res = self.training_handler(
            client_id=client_id, conf=client_conf, model=config['model'])

        # Report execution completion meta information
        response = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION(
            job_api_pb2.CompleteRequest(
                client_id=str(client_id), executor_id=self.executor_id,
                event=commons.CLIENT_TRAIN, status=True, msg=None,
                meta_result=None, data_result=None
            )
        )
        self.dispatch_worker_events(response)

        return client_id, train_res

    def Test(self, config):
        """Model Testing.

        By default, we test the accuracy on all data of clients in the test
        group.

        Args:
            config (dictionary): The client testing config (currently unused).
        """
        test_res = self.testing_handler()
        test_res = {'executorId': self.this_rank, 'results': test_res}

        # Report execution completion information
        response = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION(
            job_api_pb2.CompleteRequest(
                client_id=self.executor_id, executor_id=self.executor_id,
                event=commons.MODEL_TEST, status=True, msg=None,
                meta_result=None, data_result=self.serialize_response(test_res)
            )
        )
        self.dispatch_worker_events(response)

    def Stop(self):
        """Stop the current executor."""
        # `close_sever_connection` (sic) is the ClientConnections method name.
        self.aggregator_communicator.close_sever_connection()
        self.received_stop_request = True

    def report_executor_info_handler(self):
        """Return the statistics of training dataset.

        Returns:
            int: Return the statistics of training dataset, in simulation
            return the number of clients.
        """
        return self.training_sets.getSize()

    def override_conf(self, config):
        """Override the variable arguments for different client.

        The executor-wide args are copied, so self.args is never mutated.

        Args:
            config (dictionary): The client runtime config.

        Returns:
            Namespace: Variable arguments for client runtime config.
        """
        default_conf = vars(self.args).copy()
        default_conf.update(config)
        return Namespace(**default_conf)

    def get_client_trainer(self, conf):
        """Return a framework-specific client that handles training and evaluation.

        Args:
            conf: job config; reads `engine` and, for PyTorch, `task`.

        Returns:
            Framework-specific client instance.

        Raises:
            ValueError: If conf.engine is neither TensorFlow nor PyTorch.
        """
        if conf.engine == commons.TENSORFLOW:
            return TensorflowClient(conf)
        elif conf.engine == commons.PYTORCH:
            if conf.task == 'rl':
                return RLClient(conf)
            else:
                return TorchClient(conf)
        raise ValueError(
            f"Unsupported engine {conf.engine!r}. "
            "Currently, FedScale supports tensorflow and pytorch.")

    def training_handler(self, client_id, conf, model):
        """Train model given client id.

        Args:
            client_id (int): The client id.
            conf (Namespace): The client runtime config; client_id and
                tokenizer are set on it here.
            model: Model weights loaded into the model adapter before training.

        Returns:
            dictionary: The train result.
        """
        self.model_adapter.set_weights(model)
        conf.client_id = client_id
        conf.tokenizer = tokenizer
        # RL clients consume the raw training set; others get a per-client loader.
        client_data = self.training_sets if self.args.task == "rl" else \
            select_dataset(client_id, self.training_sets,
                           batch_size=conf.batch_size, args=self.args,
                           collate_fn=self.collate_fn
                           )
        client = self.get_client_trainer(self.args)
        train_res = client.train(
            client_data=client_data, model=self.model_adapter.get_model(), conf=conf)

        return train_res

    def testing_handler(self):
        """Test the current model on this executor's partition of the test set.

        Returns:
            dictionary: The test result.
        """
        test_config = self.override_conf({
            'rank': self.this_rank,
            'memory_capacity': self.args.memory_capacity,
            'tokenizer': tokenizer
        })
        client = self.get_client_trainer(test_config)
        data_loader = select_dataset(self.this_rank, self.testing_sets,
                                     batch_size=self.args.test_bsz, args=self.args,
                                     isTest=True, collate_fn=self.collate_fn)

        test_results = client.test(data_loader, self.model_adapter.get_model(), test_config)
        gc.collect()

        return test_results

    def client_register(self):
        """Register the executor information to the aggregator.

        Retries every 5 seconds for up to 180 seconds.
        """
        start_time = time.time()
        while time.time() - start_time < 180:
            try:
                response = self.aggregator_communicator.stub.CLIENT_REGISTER(
                    job_api_pb2.RegisterRequest(
                        client_id=self.executor_id,
                        executor_id=self.executor_id,
                        executor_info=self.serialize_response(
                            self.report_executor_info_handler())
                    )
                )
                self.dispatch_worker_events(response)
                break
            except Exception as e:
                logging.warning(f"Failed to connect to aggregator {e}. Will retry in 5 sec.")
                time.sleep(5)
        else:
            # Registration never succeeded; make the failure visible instead of
            # silently proceeding to the event loop.
            logging.error(
                f"(EXECUTOR:{self.this_rank}) failed to register with the aggregator within 180 sec.")

    def client_ping(self):
        """Ping the aggregator for new task."""
        response = self.aggregator_communicator.stub.CLIENT_PING(job_api_pb2.PingRequest(
            client_id=self.executor_id,
            executor_id=self.executor_id
        ))
        self.dispatch_worker_events(response)

    def event_monitor(self):
        """Activate event handler once receiving new message.

        Registers with the aggregator, then loops until a stop request.
        When there are queued events, it handles them. When the queue is
        empty, it sleeps for 1 second and pings the aggregator.
        """
        logging.info("Start monitoring events ...")
        self.client_register()

        while not self.received_stop_request:
            if self.event_queue:
                request = self.event_queue.popleft()
                current_event = request.event

                if current_event == commons.CLIENT_TRAIN:
                    train_config = self.deserialize_response(request.meta)
                    train_model = self.deserialize_response(request.data)
                    train_config['model'] = train_model
                    train_config['client_id'] = int(train_config['client_id'])
                    client_id, train_res = self.Train(train_config)

                    # Upload model updates asynchronously; the callback may run
                    # on a gRPC-owned thread (deque.append is thread-safe).
                    future_call = self.aggregator_communicator.stub.CLIENT_EXECUTE_COMPLETION.future(
                        job_api_pb2.CompleteRequest(client_id=str(client_id), executor_id=self.executor_id,
                                                    event=commons.UPLOAD_MODEL, status=True, msg=None,
                                                    meta_result=None, data_result=self.serialize_response(train_res)
                                                    ))
                    future_call.add_done_callback(lambda _response: self.dispatch_worker_events(_response.result()))

                elif current_event == commons.MODEL_TEST:
                    self.Test(self.deserialize_response(request.meta))

                elif current_event == commons.UPDATE_MODEL:
                    model_weights = self.deserialize_response(request.data)
                    self.UpdateModel(model_weights)

                elif current_event == commons.SHUT_DOWN:
                    self.Stop()

                elif current_event == commons.DUMMY_EVENT:
                    pass
            else:
                time.sleep(1)
                try:
                    self.client_ping()
                except Exception as e:
                    logging.info(f"Caught exception {e} from aggregator, terminating executor {self.this_rank} ...")
                    break
if __name__ == "__main__":
    # Build an executor from the parsed runtime arguments and start its
    # setup + event-monitoring loop.
    runtime_args = parser.args
    executor = Executor(runtime_args)
    executor.run()