import logging
import math
import pickle
from random import Random
from typing import Dict, List
from fedscale.core.internal.client import Client
class clientManager(object):
    """Registry of clients plus the logic that samples participants per round.

    Clients are stored in ``self.Clients`` keyed by ``getUniqueId(...)``.
    Clients whose sample count lies within ``[filter_less, filter_more]`` are
    additionally tracked in ``self.feasibleClients``.  In ``'oort'`` mode,
    selection and feedback are delegated to the Oort training selector held in
    ``self.ucbSampler``; in every other mode clients are drawn with the seeded
    ``self.rng``.
    """

    def __init__(self, mode, args, sample_seed=233):
        """Create an empty manager.

        Args:
            mode: sampling mode; the value ``'oort'`` enables the Oort selector.
            args: configuration object. This class reads ``filter_less``,
                ``filter_more``, ``device_avail_file`` and, in oort mode,
                ``local_steps`` and ``batch_size``.
            sample_seed: seed for the random generator used for sampling.
        """
        self.Clients = {}
        self.clientOnHosts = {}
        self.mode = mode
        self.filter_less = args.filter_less
        self.filter_more = args.filter_more

        self.ucbSampler = None
        if self.mode == 'oort':
            import os
            import sys
            current = os.path.dirname(os.path.realpath(__file__))
            parent = os.path.dirname(current)
            # Make ``thirdparty`` importable; skip the append when the entry is
            # already there so repeated construction does not grow sys.path.
            if parent not in sys.path:
                sys.path.append(parent)
            from thirdparty.oort.oort import create_training_selector
            self.ucbSampler = create_training_selector(args=args)

        self.feasibleClients = []
        self.rng = Random()
        self.rng.seed(sample_seed)
        self.count = 0  # number of select_participants() calls so far
        self.feasible_samples = 0  # total samples over feasible clients
        self.user_trace = None
        self.args = args

        if args.device_avail_file is not None:
            # NOTE(security): unpickling can execute arbitrary code; only point
            # device_avail_file at trusted trace files.
            with open(args.device_avail_file, 'rb') as fin:
                self.user_trace = pickle.load(fin)
            self.user_trace_keys = list(self.user_trace.keys())

    def registerClient(self, hostId, clientId, size, speed, duration=1):
        """Legacy camelCase alias of :meth:`register_client`."""
        self.register_client(hostId, clientId, size, speed, duration)

    def register_client(self, hostId: int, clientId: int, size: int, speed: Dict[str, float], duration: float=1) -> None:
        """Register client information to the client manager.

        Clients whose ``size`` falls outside ``[filter_less, filter_more]`` are
        dropped again right after creation and never become feasible.

        Args:
            hostId (int): executor Id.
            clientId (int): client Id.
            size (int): number of samples on this client.
            speed (Dict[str, float]): device speed (e.g., computation and communication).
            duration (float): execution latency.
        """
        uniqueId = self.getUniqueId(hostId, clientId)

        # With a loaded trace, clients are mapped onto trace entries by
        # wrapping the client id around the number of entries.
        if self.user_trace is None:
            user_trace = None
        else:
            trace_key = self.user_trace_keys[int(clientId) % len(self.user_trace)]
            user_trace = self.user_trace[trace_key]

        self.Clients[uniqueId] = Client(hostId, clientId, speed, user_trace)

        if self.filter_less <= size <= self.filter_more:
            self.feasibleClients.append(clientId)
            self.feasible_samples += size

            if self.mode == "oort":
                feedbacks = {
                    'reward': min(size, self.args.local_steps * self.args.batch_size),
                    'duration': duration,
                }
                self.ucbSampler.register_client(clientId, feedbacks=feedbacks)
        else:
            # Infeasible client: remove it from the registry again.
            del self.Clients[uniqueId]

    def getAllClients(self):
        """Return the list of feasible client ids (the internal list itself)."""
        return self.feasibleClients

    def getAllClientsLength(self):
        """Return the number of feasible clients."""
        return len(self.feasibleClients)

    def getClient(self, clientId):
        """Return the registered client object; raises KeyError if unknown."""
        return self.Clients[self.getUniqueId(0, clientId)]

    def registerDuration(self, clientId, batch_size, upload_step, upload_size, download_size):
        """Report a client's estimated execution cost to the Oort sampler.

        Does nothing outside oort mode or for unregistered clients.  The cost
        is the sum of the ``'computation'`` and ``'communication'`` entries
        returned by the client's ``getCompletionTime``.
        """
        uniqueId = self.getUniqueId(0, clientId)
        if self.mode == "oort" and uniqueId in self.Clients:
            exe_cost = self.Clients[uniqueId].getCompletionTime(
                batch_size=batch_size, upload_step=upload_step,
                upload_size=upload_size, download_size=download_size
            )
            self.ucbSampler.update_duration(
                clientId, exe_cost['computation'] + exe_cost['communication'])

    def getCompletionTime(self, clientId, batch_size, upload_step, upload_size, download_size):
        """Forward to the client's ``getCompletionTime`` and return its result."""
        return self.Clients[self.getUniqueId(0, clientId)].getCompletionTime(
            batch_size=batch_size, upload_step=upload_step,
            upload_size=upload_size, download_size=download_size
        )

    def registerSpeed(self, hostId, clientId, speed):
        """Overwrite the ``speed`` attribute of a registered client."""
        uniqueId = self.getUniqueId(hostId, clientId)
        self.Clients[uniqueId].speed = speed

    def registerScore(self, clientId, reward, auxi=1.0, time_stamp=0, duration=1., success=True):
        """Legacy camelCase alias of :meth:`register_feedback`."""
        self.register_feedback(clientId, reward, auxi=auxi, time_stamp=time_stamp,
                               duration=duration, success=success)

    def register_feedback(self, clientId: int, reward: float, auxi: float=1.0, time_stamp: float=0, duration: float=1., success: bool=True) -> None:
        """Collect client execution feedbacks of last round.

        Only has an effect in oort mode, where the feedback is forwarded to the
        Oort sampler.

        Args:
            clientId (int): client Id.
            reward (float): execution utilities (processed feedbacks).
            auxi (float): unprocessed feedbacks (currently unused).
            time_stamp (float): current wall clock time.
            duration (float): system execution duration.
            success (bool): whether this client runs successfully.
        """
        # currently, we only use distance as reward
        if self.mode == "oort":
            feedbacks = {
                'reward': reward,
                'duration': duration,
                # Forward the caller's outcome; this used to be hard-coded to
                # True, which silently discarded reported failures.
                'status': success,
                'time_stamp': time_stamp
            }
            self.ucbSampler.update_client_util(clientId, feedbacks=feedbacks)

    def registerClientScore(self, clientId, reward):
        """Forward ``reward`` to the client's ``registerReward``."""
        self.Clients[self.getUniqueId(0, clientId)].registerReward(reward)

    def getScore(self, hostId, clientId):
        """Return the value of the client's ``getScore()``."""
        uniqueId = self.getUniqueId(hostId, clientId)
        return self.Clients[uniqueId].getScore()

    def getClientsInfo(self):
        """Return a dict mapping each client's ``clientId`` to its ``distance``."""
        # NOTE(review): relies on Client exposing ``distance`` - confirm
        # against the Client class, which is defined outside this file.
        return {client.clientId: client.distance for client in self.Clients.values()}

    def nextClientIdToRun(self, hostId):
        """Pick a feasible client for ``hostId`` whose size passes the filter.

        The first candidate is ``feasibleClients[hostId - 1]``; whenever a
        candidate's ``size`` is outside ``[filter_less, filter_more]`` a new
        index is drawn at random.  The loop only ends once a candidate
        qualifies, so it does not terminate if none ever does.
        """
        init_id = hostId - 1
        lenPossible = len(self.feasibleClients)

        while True:
            clientId = str(self.feasibleClients[init_id])
            # NOTE(review): assumes Client exposes ``size`` - confirm.
            csize = self.Clients[clientId].size
            if self.filter_less <= csize <= self.filter_more:
                return int(clientId)

            # Redraw an index, clamped into [0, lenPossible - 1].
            init_id = max(
                0, min(int(math.floor(self.rng.random() * lenPossible)), lenPossible - 1))

    def getUniqueId(self, hostId, clientId):
        """Return the registry key for a client; ``hostId`` is currently ignored."""
        return str(clientId)

    def clientSampler(self, clientId):
        """Return the ``size`` attribute of the given client."""
        return self.Clients[self.getUniqueId(0, clientId)].size

    def clientOnHost(self, clientIds, hostId):
        """Record which client ids are assigned to ``hostId``."""
        self.clientOnHosts[hostId] = clientIds

    def getCurrentClientIds(self, hostId):
        """Return the client ids recorded for ``hostId``."""
        return self.clientOnHosts[hostId]

    def getClientLenOnHost(self, hostId):
        """Return how many client ids are recorded for ``hostId``."""
        return len(self.clientOnHosts[hostId])

    def getClientSize(self, clientId):
        """Return the ``size`` attribute of the given client."""
        return self.Clients[self.getUniqueId(0, clientId)].size

    def getSampleRatio(self, clientId, hostId, even=False):
        """Return the weight of a client among all clients assigned to hosts.

        Args:
            clientId: client whose ratio is requested.
            hostId: host of that client (only used to build the registry key).
            even: when True every assigned client gets the same weight,
                ``1 / (number of assigned clients)``; otherwise the weight is
                the client's ``size`` divided by the summed ``size`` of all
                assigned clients.

        Raises:
            ZeroDivisionError: if the denominator is zero (e.g. no clients are
                assigned to any host).
        """
        totalSampleInTraining = 0.

        if even:
            for assigned in self.clientOnHosts.values():
                totalSampleInTraining += len(assigned)
            return 1. / totalSampleInTraining

        for key, assigned in self.clientOnHosts.items():
            for client in assigned:
                totalSampleInTraining += self.Clients[self.getUniqueId(key, client)].size

        own_size = self.Clients[self.getUniqueId(hostId, clientId)].size
        return float(own_size) / float(totalSampleInTraining)

    def getFeasibleClients(self, cur_time):
        """Return the feasible clients that are online at ``cur_time``.

        Without a loaded trace every feasible client counts as online and the
        internal ``feasibleClients`` list itself is returned.  With a trace, a
        new list of the clients whose ``isActive(cur_time)`` is truthy is
        built.  The online/offline counts are logged either way.
        """
        if self.user_trace is None:
            clients_online = self.feasibleClients
        else:
            clients_online = [
                clientId for clientId in self.feasibleClients
                if self.Clients[self.getUniqueId(0, clientId)].isActive(cur_time)
            ]

        logging.info(f"Wall clock time: {round(cur_time)}, {len(clients_online)} clients online, "
                     f"{len(self.feasibleClients)-len(clients_online)} clients offline")

        return clients_online

    def isClientActive(self, clientId, cur_time):
        """Return the client's ``isActive(cur_time)`` result."""
        return self.Clients[self.getUniqueId(0, clientId)].isActive(cur_time)

    def select_participants(self, num_of_clients: int, cur_time: float=0) -> List[int]:
        """Select participating clients for current execution task.

        If no more than ``num_of_clients`` clients are online, all of them are
        returned.  Otherwise the Oort sampler chooses (oort mode, from the
        second call onwards) or the online list is shuffled with the seeded
        generator and its first ``num_of_clients`` entries are returned.

        Args:
            num_of_clients (int): number of participants to select.
            cur_time (float): current wall clock time.

        Returns:
            List[int]: indices of selected clients.
        """
        self.count += 1

        clients_online = self.getFeasibleClients(cur_time)
        if len(clients_online) <= num_of_clients:
            return clients_online

        if self.mode == "oort" and self.count > 1:
            # The set is only needed by the Oort sampler, so build it here.
            return self.ucbSampler.select_participant(
                num_of_clients, feasible_clients=set(clients_online))

        # NOTE: without a trace, clients_online is self.feasibleClients itself,
        # so this shuffle reorders the manager's list in place.  That is kept
        # on purpose: it determines the seeded sampling sequence across rounds.
        self.rng.shuffle(clients_online)
        # The early return above guarantees more than num_of_clients entries.
        return clients_online[:num_of_clients]

    def resampleClients(self, numOfClients, cur_time=0):
        """Legacy camelCase alias of :meth:`select_participants`."""
        return self.select_participants(numOfClients, cur_time)

    def getAllMetrics(self):
        """Return the Oort sampler's metrics, or an empty dict outside oort mode."""
        if self.mode == "oort":
            return self.ucbSampler.getAllMetrics()
        return {}

    def getDataInfo(self):
        """Return the feasible-client count and their total number of samples."""
        return {'total_feasible_clients': len(self.feasibleClients), 'total_num_samples': self.feasible_samples}

    def getClientReward(self, clientId):
        """Return the sampler's reward for a client.

        Requires oort mode: ``self.ucbSampler`` is None otherwise, so the call
        fails with AttributeError.
        """
        return self.ucbSampler.get_client_reward(clientId)