# Source code for model

"""Interfaces for ClientModel and ServerModel."""

from abc import ABC, abstractmethod
import numpy as np
import os
import sys
import tensorflow as tf

from baseline_constants import ACCURACY_KEY

from utils.model_utils import batch_data
from utils.tf_utils import graph_size


class Model(ABC):
    """Abstract base for a TensorFlow model that owns its own graph and session.

    Subclasses implement :meth:`create_model` to build the graph. This base
    class then provides training, evaluation and clean-up on top of the
    placeholders and ops that ``create_model`` returns.
    """

    def __init__(self, lr):
        """Build the graph, open a session, and record model size and FLOPs.

        Args:
            lr: Learning rate handed to the default optimizer.
        """
        self.lr = lr
        self._optimizer = None

        self.graph = tf.Graph()
        with self.graph.as_default():
            (self.features,
             self.labels,
             self.train_op,
             self.eval_metric_ops) = self.create_model()
            # The saver is created inside the graph context, after
            # create_model() has built the model.
            self.saver = tf.train.Saver()
        self.sess = tf.Session(graph=self.graph)

        self.size = graph_size(self.graph)

        with self.graph.as_default():
            self.sess.run(tf.global_variables_initializer())

            metadata = tf.RunMetadata()
            opts = tf.profiler.ProfileOptionBuilder.float_operation()
            self.flops = tf.profiler.profile(
                self.graph, run_meta=metadata, cmd='scope', options=opts
            ).total_float_ops

    @property
    def optimizer(self):
        """Optimizer to be used by the model.

        Created lazily on first access as plain gradient descent with
        learning rate ``self.lr``; the same instance is returned afterwards.
        """
        if self._optimizer is None:
            self._optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.lr)
        return self._optimizer

    @abstractmethod
    def create_model(self):
        """Creates the model for the task.

        Returns:
            A 4-tuple consisting of:
                features: A placeholder for the samples' features.
                labels: A placeholder for the samples' labels.
                train_op: A Tensorflow operation that, when run with the
                    features and the labels, trains the model.
                eval_metric_ops: A Tensorflow operation that, when run with
                    features and labels, returns the evaluation metric.
                    NOTE(review): test() divides this op's result by the
                    number of samples, so it appears to be expected to yield
                    the count of correct predictions rather than an accuracy
                    ratio - confirm against the concrete subclasses.
        """
        return None, None, None, None

    def train(self, data, num_epochs=1, batch_size=10):
        """Trains the client model.

        Args:
            data: Dict of the form {'x': [list], 'y': [list]}.
            num_epochs: Number of epochs to train.
            batch_size: Size of training batches.

        Return:
            comp: Number of FLOPs computed while training given data,
                estimated as
                ``num_epochs * number_of_batches * batch_size * self.flops``.
                NOTE(review): every batch is counted as ``batch_size``
                samples; if batch_data can yield a shorter final batch this
                slightly over-counts - confirm.
            update: List of np.ndarray, one per trainable variable in the
                graph, holding the change in that variable's value caused by
                this call (value after training minus value before).
        """
        with self.graph.as_default():
            # A single run() call fetches every variable at once instead of
            # paying one session round-trip per variable.
            init_values = self.sess.run(tf.trainable_variables())

        batched_x, batched_y = batch_data(data, batch_size)
        for _ in range(num_epochs):
            for raw_x_batch, raw_y_batch in zip(batched_x, batched_y):
                input_data = self.process_x(raw_x_batch)
                target_data = self.process_y(raw_y_batch)
                self.sess.run(
                    self.train_op,
                    feed_dict={self.features: input_data, self.labels: target_data}
                )

        with self.graph.as_default():
            final_values = self.sess.run(tf.trainable_variables())
        update = [
            np.subtract(final, init)
            for final, init in zip(final_values, init_values)
        ]

        comp = num_epochs * len(batched_y) * batch_size * self.flops
        return comp, update

    def test(self, data):
        """Tests the current model on the given data.

        Args:
            data: dict of the form {'x': [list], 'y': [list]}

        Return:
            dict of metrics that will be recorded by the simulation. It holds
            a single entry, ACCURACY_KEY, whose value is the result of
            ``eval_metric_ops`` divided by the number of samples.
        """
        x_vecs = self.process_x(data['x'])
        labels = self.process_y(data['y'])
        with self.graph.as_default():
            tot_acc = self.sess.run(
                self.eval_metric_ops,
                feed_dict={self.features: x_vecs, self.labels: labels}
            )
        # x_vecs.shape[0] is the number of evaluated samples.
        acc = float(tot_acc) / x_vecs.shape[0]
        return {ACCURACY_KEY: acc}

    def close(self):
        """Closes the TensorFlow session owned by this model."""
        self.sess.close()

    def process_x(self, raw_x_batch):
        """Pre-processes each batch of features before being fed to the model."""
        return np.asarray(raw_x_batch)

    def process_y(self, raw_y_batch):
        """Pre-processes each batch of labels before being fed to the model."""
        return np.asarray(raw_y_batch)
class ServerModel:
    """Server-side wrapper around a model.

    It pushes the wrapped model's trainable variables to clients, folds
    client updates back into the wrapped model, and saves it to disk.
    """

    def __init__(self, model):
        """Wraps ``model``, which is used as the server's global model.

        Args:
            model: Object exposing ``graph``, ``sess``, ``saver``, ``size``
                and ``close()``, as used by the methods below.
        """
        self.model = model

    @property
    def size(self):
        """Size of the wrapped model, as reported by ``model.size``."""
        return self.model.size

    @property
    def cur_model(self):
        """The wrapped model instance."""
        return self.model

    def send_to(self, clients):
        """Copies server model variables to each of the given clients

        Variables are matched by name, so each client graph's trainable
        variables must use names that also exist in the server graph;
        otherwise a KeyError is raised.

        Args:
            clients: list of Client objects
        """
        with self.model.graph.as_default():
            server_vars = tf.trainable_variables()
            # One run() call fetches all values instead of one per variable.
            server_vals = self.model.sess.run(server_vars)
        var_vals = {v.name: val for v, val in zip(server_vars, server_vals)}

        for c in clients:
            with c.model.graph.as_default():
                for v in tf.trainable_variables():
                    v.load(var_vals[v.name], c.model.sess)

    def update(self, updates):
        """Updates server model using given client updates.

        The per-variable updates are averaged, weighted by each client's
        number of samples, and the average is added to the server model's
        current trainable variables.

        If ``updates`` is empty, or the sample counts sum to zero, the model
        is left unchanged. There is nothing meaningful to average in those
        cases, and dividing by a zero total would load NaN/inf values into
        the variables.

        Args:
            updates: list of (num_samples, update), where num_samples is the
                number of training samples corresponding to the update, and
                update is a list of variable weights
        """
        if len(updates) == 0:
            return

        tot_samples = np.sum([u[0] for u in updates])
        if tot_samples == 0:
            return

        # Accumulators take their shapes from the first client's update.
        weighted_vals = [np.zeros(np.shape(v), dtype=float) for v in updates[0][1]]
        for update in updates:
            num_samples, client_update = update[0], update[1]
            for j, weighted_val in enumerate(weighted_vals):
                weighted_vals[j] = np.add(weighted_val, num_samples * client_update[j])

        weighted_updates = [v / tot_samples for v in weighted_vals]

        with self.model.graph.as_default():
            all_vars = tf.trainable_variables()
            cur_vals = self.model.sess.run(all_vars)
            for i, v in enumerate(all_vars):
                v.load(np.add(cur_vals[i], weighted_updates[i]), self.model.sess)

    def save(self, path='checkpoints/model.ckpt'):
        """Saves the server model's session with its saver.

        Args:
            path: Checkpoint path handed to ``saver.save``.

        Return:
            Whatever ``saver.save`` returns.
        """
        return self.model.saver.save(self.model.sess, path)

    def close(self):
        """Closes the wrapped model."""
        self.model.close()