import random
from baseline_constants import BYTES_WRITTEN_KEY, BYTES_READ_KEY, LOCAL_COMPUTATIONS_KEY
class Server:
    """Server side of a federated training loop.

    Samples clients, sends them the global model to train on, collects their
    (num_samples, update) results, and hands those results to
    ``self.model.update``.
    """

    def __init__(self, model):
        self.model = model  # global model of the server.
        self.selected_clients = []
        # (num_samples, update) pairs appended by train_model and passed to
        # self.model.update (then cleared) by update_model.
        self.updates = []

    def select_clients(self, possible_clients, num_clients=20):
        """Selects num_clients clients randomly from possible_clients.

        Note that within function, num_clients is set to
        min(num_clients, len(possible_clients)).

        Args:
            possible_clients: Clients from which the server can select. Any
                finite iterable is accepted (list, tuple, set, ...); it is
                copied to a list before sampling.
            num_clients: Number of clients to select; default 20
        Return:
            list of (num_train_samples, num_test_samples), one entry per
            selected client, in selection order.
        """
        # random.sample only accepts sequences (set support was removed in
        # Python 3.11). Sampling a list copy gives the same draw as sampling
        # the original list, so existing callers are unaffected.
        population = list(possible_clients)
        num_clients = min(num_clients, len(population))
        self.selected_clients = random.sample(population, num_clients)
        return [(len(c.train_data['y']), len(c.eval_data['y']))
                for c in self.selected_clients]

    def train_model(self, num_epochs=1, batch_size=10, minibatch=None, clients=None):
        """Trains self.model on given clients.

        Trains model on self.selected_clients if clients=None;
        each client's data is trained with the given number of epochs
        and batches. Each client's (num_samples, update) result is appended
        to self.updates; call update_model() to apply them.

        Args:
            clients: list of Client objects.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.
            minibatch: fraction of client's data to apply minibatch sgd,
                None to use FedAvg
        Return:
            dict mapping each client id to a dict of system metrics:
                BYTES_WRITTEN_KEY: bytes written by the client to the server
                    (counted as self.model.size per trained client).
                BYTES_READ_KEY: bytes read by the client from the server
                    (self.model.size per model download).
                LOCAL_COMPUTATIONS_KEY: computation cost returned by
                    client.train (number of FLOPs).
        """
        if clients is None:
            clients = self.selected_clients
        sys_metrics = {
            c.id: {BYTES_WRITTEN_KEY: 0,
                   BYTES_READ_KEY: 0,
                   LOCAL_COMPUTATIONS_KEY: 0} for c in clients}
        for c in clients:
            # Push the current global model to this client before it trains.
            self.model.send_to([c])
            sys_metrics[c.id][BYTES_READ_KEY] += self.model.size
            comp, num_samples, update = c.train(num_epochs, batch_size, minibatch)
            sys_metrics[c.id][LOCAL_COMPUTATIONS_KEY] = comp
            self.updates.append((num_samples, update))
            # NOTE(review): assumes the uploaded update is the same size as
            # the model itself.
            sys_metrics[c.id][BYTES_WRITTEN_KEY] += self.model.size
        return sys_metrics

    def update_model(self):
        """Applies the collected client updates to self.model and clears them."""
        self.model.update(self.updates)
        self.updates = []

    def test_model(self, clients_to_test=None):
        """Tests self.model on given clients.

        Tests model on self.selected_clients if clients_to_test=None.

        Args:
            clients_to_test: list of Client objects.
        Return:
            dict mapping each client id to the value returned by
            client.test(self.model.cur_model).
        """
        if clients_to_test is None:
            clients_to_test = self.selected_clients
        metrics = {}
        # Send the current model to every client being evaluated.
        self.model.send_to(clients_to_test)
        for client in clients_to_test:
            c_metrics = client.test(self.model.cur_model)
            metrics[client.id] = c_metrics
        return metrics

    def get_clients_test_info(self, clients=None):
        """Returns the ids, hierarchies and num_test_samples for the given clients.

        Returns info about self.selected_clients if clients=None.

        Args:
            clients: list of Client objects.
        Return:
            (ids, groups, num_samples): a list of client ids, a dict
            id -> client.group, and a dict id -> client.num_test_samples.
        """
        if clients is None:
            clients = self.selected_clients
        ids = [c.id for c in clients]
        groups = {c.id: c.group for c in clients}
        num_samples = {c.id: c.num_test_samples for c in clients}
        return ids, groups, num_samples