model module

Interfaces for ClientModel and ServerModel.

class model.Model(lr)

Bases: abc.ABC

close()
create_model()

Creates the model for the task.

Returns:A 4-tuple consisting of:
  • features – A placeholder for the samples’ features.
  • labels – A placeholder for the samples’ labels.
  • train_op – A TensorFlow operation that, when run with the features and the labels, trains the model.
  • eval_metric_ops – A TensorFlow operation that, when run with the features and the labels, returns the accuracy of the model.
optimizer

Optimizer to be used by the model.

process_x(raw_x_batch)

Pre-processes each batch of features before being fed to the model.

process_y(raw_y_batch)

Pre-processes each batch of labels before being fed to the model.
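
The two pre-processing hooks can be pictured with a minimal sketch. The transformations shown here (scaling pixel values to [0, 1] and one-hot encoding labels) and the num_classes parameter are assumptions chosen for illustration; the interface itself does not prescribe what a subclass does in these methods.

```python
def process_x(raw_x_batch):
    """Hypothetical feature pre-processing: scale 0-255 values to [0, 1]."""
    return [[value / 255.0 for value in sample] for sample in raw_x_batch]


def process_y(raw_y_batch, num_classes=3):
    """Hypothetical label pre-processing: one-hot encode integer labels.

    num_classes is an illustrative parameter, not part of the documented API.
    """
    one_hot = []
    for label in raw_y_batch:
        row = [0.0] * num_classes
        row[label] = 1.0
        one_hot.append(row)
    return one_hot


features = process_x([[0, 255], [51, 102]])
labels = process_y([2, 0])
```

A concrete model would typically apply whatever conversion its input placeholders expect, so the bodies above should be read as placeholders.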

test(data)

Tests the current model on the given data.

Parameters:data – dict of the form {‘x’: [list], ‘y’: [list]}
Returns:dict of metrics that will be recorded by the simulation.
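
To make the shapes concrete, the following sketch builds a data dict of the documented form and computes a metrics dict from it. The helper accuracy_metrics, the predict callable, and the metric names 'accuracy' and 'num_samples' are hypothetical; the keys actually returned depend on the concrete model.

```python
def accuracy_metrics(data, predict):
    """Return a metrics dict for data of the form {'x': [...], 'y': [...]}.

    predict is a hypothetical callable mapping one feature vector to a label.
    """
    xs, ys = data['x'], data['y']
    if len(xs) != len(ys):
        raise ValueError("'x' and 'y' must have the same length")
    if not xs:
        return {'accuracy': 0.0, 'num_samples': 0}
    correct = sum(1 for x, y in zip(xs, ys) if predict(x) == y)
    return {'accuracy': correct / len(xs), 'num_samples': len(xs)}


# Three samples with two features each, and one integer label per sample.
data = {'x': [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], 'y': [1, 0, 0]}
metrics = accuracy_metrics(data, predict=lambda x: int(x[1] > x[0]))
print(metrics)
```

The two lists are parallel: the i-th entry of 'y' is the label for the i-th entry of 'x'.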
train(data, num_epochs=1, batch_size=10)

Trains the client model.

Parameters:
  • data – Dict of the form {‘x’: [list], ‘y’: [list]}.
  • num_epochs – Number of epochs to train.
  • batch_size – Size of training batches.
Returns:
  • comp – Number of FLOPs computed while training on the given data.
  • update – List of np.ndarray weights, with each weight array corresponding to a variable in the resulting graph.
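
How num_epochs and batch_size interact can be sketched as a plain mini-batch loop. This is an illustration under assumptions, not the module's training code: the helpers iterate_batches and count_steps are hypothetical, and the per-epoch shuffle is an assumed detail that the interface does not document.

```python
import random


def iterate_batches(data, batch_size, seed=None):
    """Yield (batch_x, batch_y) pairs covering every sample once."""
    xs, ys = data['x'], data['y']
    order = list(range(len(xs)))
    random.Random(seed).shuffle(order)  # assumed: samples are shuffled each epoch
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [xs[i] for i in idx], [ys[i] for i in idx]


def count_steps(data, num_epochs=1, batch_size=10):
    """Count the optimizer steps a loop with these settings would take."""
    steps = 0
    for epoch in range(num_epochs):
        for batch_x, batch_y in iterate_batches(data, batch_size, seed=epoch):
            # A real model would run its training operation on the batch here.
            steps += 1
    return steps


data = {'x': [[float(i)] for i in range(25)], 'y': [i % 2 for i in range(25)]}
steps = count_steps(data, num_epochs=2, batch_size=10)
```

With 25 samples and a batch size of 10, each epoch has three batches, the last one holding the remaining five samples.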

class model.ServerModel(model)

Bases: object

close()
cur_model
save(path='checkpoints/model.ckpt')
send_to(clients)

Copies server model variables to each of the given clients.

Parameters:clients – list of Client objects
size
update(updates)

Updates server model using given client updates.

Parameters:updates – list of (num_samples, update) tuples, where num_samples is the number of training samples corresponding to the update, and update is a list of variable weights.
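
The updates argument pairs each client's weights with its sample count, which suggests a sample-weighted average in the style of federated averaging. The sketch below assumes that aggregation rule; the documentation above does not state how the server combines updates, and the function name weighted_average is hypothetical.

```python
import numpy as np


def weighted_average(updates):
    """Average per-variable weights, weighting each client by num_samples.

    updates is a list of (num_samples, update) tuples, where update is a
    list of np.ndarray weights with one array per model variable.
    """
    total = sum(num_samples for num_samples, _ in updates)
    if total == 0:
        raise ValueError("updates must contain at least one training sample")
    num_vars = len(updates[0][1])
    averaged = []
    for v in range(num_vars):
        acc = np.zeros_like(updates[0][1][v], dtype=np.float64)
        for num_samples, weights in updates:
            acc += num_samples * np.asarray(weights[v], dtype=np.float64)
        averaged.append(acc / total)
    return averaged


# Two clients, two variables each; the second client has three times the data.
updates = [
    (10, [np.array([1.0, 1.0]), np.array([[0.0]])]),
    (30, [np.array([3.0, 5.0]), np.array([[4.0]])]),
]
new_weights = weighted_average(updates)
```

Under this assumed rule a client with more training samples pulls the averaged weights closer to its own values.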