Source code for fedscale.cloud.execution.tensorflow_client

import logging
import tensorflow as tf
from overrides import overrides
from fedscale.cloud.execution.client_base import ClientBase
import numpy as np

from fedscale.cloud.internal.tensorflow_model_adapter import TensorflowModelAdapter


[docs]class TensorflowClient(ClientBase): """Implements a TensorFlow-based client for training and evaluation.""" def __init__(self, args): """ Initializes a tf client. :param args: Job args """ self.args = args def _convert_np_to_tf_dataset(self, dataset): """ Converts the iterable numpy dataset to a tensorflow Dataset. :param dataset: numpy dataset :return: tf.data.Dataset """ def gen(): while True: for x, y in dataset: # Convert torch tensor to tf tensor nx, ny = tf.convert_to_tensor(x.swapaxes(1, 3).numpy()), \ tf.one_hot(tf.convert_to_tensor(y.numpy()), self.args.num_classes) yield nx, ny # Sample a batch to get tensor properties temp_x, temp_y = next(gen()) x_shape, y_shape = temp_x.shape.as_list(), temp_y.shape.as_list() x_shape[0], y_shape[0] = None, None return tf.data.Dataset.from_generator( gen, output_shapes=(tf.TensorShape(x_shape), tf.TensorShape(y_shape)), output_types=(temp_x.dtype, temp_y.dtype), )
[docs] @overrides def train(self, client_data, model, conf): """ Perform a training task. :param client_data: client training dataset :param model: the framework-specific model :param conf: job config :return: training results """ client_id = conf.client_id logging.info(f"Start to train (CLIENT: {client_id}) ...") tf_dataset = self._convert_np_to_tf_dataset(client_data).take(conf.local_steps) history = model.fit(tf_dataset, batch_size=conf.batch_size, verbose=1) # Report the training results results = {'client_id': client_id, 'moving_loss': sum(history.history['loss']) / (len(history.history['loss']) + 1e-4), 'trained_size': history.history['row_count'], 'success': True, 'utility': 1} logging.info(f"Training of (CLIENT: {client_id}) completes, {results}") results['update_weight'] = [np.asarray(layer.get_weights()) for layer in model.layers if layer.trainable] results['wall_duration'] = 0 return results
[docs] @overrides def test(self, client_data, model, conf): """ Perform a testing task. :param client_data: client evaluation dataset :param model: the framework-specific model :param conf: job config :return: testing results """ results = model.evaluate(self._convert_np_to_tf_dataset(client_data), batch_size=conf.batch_size, return_dict=True) for key, value in results.items(): if key != 'row_count': results[key] = results['row_count'] * value results['test_len'] = results['row_count'] return results
[docs] @overrides def get_model_adapter(self, model) -> TensorflowModelAdapter: """ Return framework-specific model adapter. :param model: the model :return: a model adapter containing the model """ return TensorflowModelAdapter(model)