Source code for fedscale.cloud.execution.rl_client

import logging
import math

from fedscale.cloud.execution.torch_client import TorchClient
from fedscale.cloud.execution.optimizers import ClientOptimizer

import fedscale.cloud.config_parser as parser
if parser.args.task == 'rl':
    from fedscale.dataloaders.dqn import *


class RLClient(TorchClient):
    """Basic client component in Federated Learning"""

    def __init__(self, conf):
        self.optimizer = ClientOptimizer()
        self.dqn = DQN(conf)
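    # Note: DQN comes from fedscale.dataloaders.dqn via the wildcard import
    # above; its definition is not shown here. The interface assumed by this
    # client, inferred only from the calls made in train() and test(), is:
    #   dqn.choose_action(s)              -> action for state s
    #   dqn.store_transition(s, a, r, s_) -> append a transition to replay memory
    #   dqn.learn()                       -> one optimization step, returns the loss
    #   dqn.set_eval_mode()               -> switch networks to evaluation mode
    #   dqn.memory_counter, dqn.target_net (and a commented-out eval_net)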
    def train(self, client_data, model, conf):
        client_id = conf.client_id
        logging.info(f"Start to train (CLIENT: {client_id}) ...")
        device = self.device

        model = model.to(device=device)
        # self.dqn.eval_net = self.dqn.eval_net.to(device=device)
        # self.dqn.target_net = self.dqn.target_net.to(device=device)

        global_model = None
        if conf.gradient_policy == 'prox':
            # could be moved to the optimizer
            global_model = [param.data.clone() for param in model.parameters()]

        trained_unique_samples = conf.local_steps * conf.batch_size
        self.dqn.target_net.load_state_dict(model.state_dict())

        completed_steps = 0
        epoch_train_loss = 1e-4
        error_type = None

        while completed_steps < conf.local_steps:
            try:
                s = client_data.env.reset()
                episode_reward_sum = 0
                while True:
                    a = self.dqn.choose_action(s)
                    s_, r, done, info = client_data.env.step(a)

                    # Reward shaping for CartPole: favor keeping the cart near
                    # the center of the track and the pole near upright.
                    x, x_dot, theta, theta_dot = s_
                    r1 = (client_data.env.x_threshold - abs(x)) / \
                        client_data.env.x_threshold - 0.8
                    r2 = (client_data.env.theta_threshold_radians - abs(theta)) / \
                        client_data.env.theta_threshold_radians - 0.5
                    new_r = r1 + r2

                    self.dqn.store_transition(s, a, new_r, s_)
                    episode_reward_sum += new_r
                    s = s_

                    # Only start learning once the replay memory is full.
                    if self.dqn.memory_counter > conf.memory_capacity:
                        loss = self.dqn.learn()
                        loss_list = [loss.tolist()]
                        loss = loss.mean()
                        temp_loss = sum([l ** 2 for l in loss_list]) / float(len(loss_list))

                        # Exponential moving average of the squared training loss.
                        if epoch_train_loss == 1e-4:
                            epoch_train_loss = temp_loss
                        else:
                            epoch_train_loss = (1. - conf.loss_decay) * epoch_train_loss \
                                + conf.loss_decay * temp_loss

                        completed_steps += 1

                    if done:
                        # print('episode%s---reward_sum: %s' % (i, round(episode_reward_sum, 2)))
                        break
            except Exception as ex:
                error_type = ex
                break

        model.load_state_dict(self.dqn.target_net.state_dict())
        model_param = [param.data.cpu().numpy() for param in model.parameters()]

        results = {'client_id': client_id, 'moving_loss': epoch_train_loss,
                   'trained_size': completed_steps * conf.batch_size,
                   'success': completed_steps > 0}
        results['utility'] = math.sqrt(epoch_train_loss) * float(trained_unique_samples)

        if error_type is None:
            logging.info(f"Training of (CLIENT: {client_id}) completes, {results}")
        else:
            logging.info(f"Training of (CLIENT: {client_id}) failed as {error_type}")

        results['update_weight'] = model_param
        results['wall_duration'] = 0

        return results
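    # The results dict returned above summarizes the local update:
    # 'moving_loss' is the exponential moving average of the per-step squared
    # loss, and 'utility' = sqrt(moving_loss) * (local_steps * batch_size)
    # scales that loss by the number of samples processed locally, so clients
    # that train on more data with higher loss report a larger utility.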
    def test(self, client_data, model, conf):
        model = model.to(device=self.device)
        self.dqn.target_net.load_state_dict(model.state_dict())
        self.dqn.set_eval_mode()

        env = gym.make('CartPole-v0').unwrapped
        reward_sum = 0
        test_loss = 0

        s = env.reset()
        while True:
            a = self.dqn.choose_action(s)
            s_, r, done, info = env.step(a)

            # Same CartPole reward shaping as in train().
            x, x_dot, theta, theta_dot = s_
            r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
            r2 = (env.theta_threshold_radians - abs(theta)) / \
                env.theta_threshold_radians - 0.5
            new_r = r1 + r2

            self.dqn.store_transition(s, a, new_r, s_)
            reward_sum += new_r
            s = s_

            if self.dqn.memory_counter > conf['memory_capacity']:
                test_loss += self.dqn.learn()
            if done:
                break

        logging.info('Rank {}: Test set: Average loss: {}, Reward: {}'
                     .format(conf['rank'], test_loss, reward_sum))

        return 0, 0, 0, {'top_1': reward_sum, 'top_5': reward_sum,
                         'test_loss': test_loss, 'test_len': 1}
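Both train() and test() replace CartPole's default unit reward with a shaped reward that favors keeping the cart near the center of the track and the pole near upright. The snippet below is a minimal standalone sketch of that shaping, assuming an older gym release in which CartPole-v0 is registered and env.step() returns a 4-tuple, as the code above expects; it is illustrative and not part of FedScale's API.

# Standalone sketch of the reward shaping used by RLClient (illustrative only).
import gym

def shaped_reward(env, observation):
    """Reward keeping the cart near the center and the pole near upright."""
    x, _x_dot, theta, _theta_dot = observation
    r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
    r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
    return r1 + r2

env = gym.make('CartPole-v0').unwrapped
s = env.reset()
s_, r, done, info = env.step(env.action_space.sample())
print(shaped_reward(env, s_))  # typically in [-1.3, 0.7]; larger when balanced near the center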