1. Introducción

Este tutorial usa el dataset de MNIST con el API de alto nivel de TFF, llamado FL API (Federated Layer API) cuya clase principal es tff.learning, la cual incluye interfaces de alto nivel para tareas típicas de federación como entrenamiento y otras.

Existe otra API de más bajo nivel, FC API (Federated Core API), que permite implementar tus propios algoritmos de federación.

2. Setup

Instalamos versiones concretas de TF y TFF, que deben ser compatibles entre ellas. Ver tabla

!pip install --quiet --upgrade tensorflow_federated==0.17
!pip install --quiet --upgrade tensorflow==2.3.0

import tensorflow as tf
import tensorflow_federated as tff

print(tf.__version__)
print(tff.__version__)

La instalación del paquete nest_asyncio evita errores de ejecución en los protocolos de federación en Jupyter (del tipo "can not join a run while another loop is running"):

import nest_asyncio
nest_asyncio.apply()

3. Carga de datos

Cargamos los datos. TFF ya tiene una clase para cargar datos y poder separarla en clientes. El dataset que devuelve load_data() son instancias de tff.simulation.ClientData, un interfaz que permite escoger un conjunto de clientes, para posteriormente construir un tf.data.Dataset que represente los datos de un cliente, y así poder entrenar.

Nótese que este interfaz permite iterar sobre ID de clientes, pero en este tutorial sólo los estamos simulando. Los ID de clientes no se usan en el algoritmo de federación, sólo se usan para escoger los datos de cada cliente.

# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()
def client_data(n):
  return source.create_tf_dataset_for_client(source.client_ids[n]).map(
      lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])
  ).repeat(10).batch(20)

len(source.client_ids)

Seleccionamos unos cuantos clientes que participarán en el entrenamiento:

# Pick a subset of client devices to participate in training.
train_data = [client_data(n) for n in range(3)]
source.element_type_structure

4. Construcción del modelo

Usamos un modelo estándar de Keras, el cual debe paquetizarse dentro de la clase tff.learning. Se usará la llamada tff.learning.from_keras_model, pasándole el modelo Keras original y los datos.

Nota: no hay que compilar ahora. Las métricas, pérdidas y optimizadores se configuran más tarde.

# Wrap a Keras model for use with TFF.
def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(10, tf.nn.softmax, input_shape=(784,),
                            kernel_initializer='zeros')
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

5. Simulación y evaluación

Ahora que tenemos el modelo empaquetado en una clase tff.learning.Model para usarlo en TFF, podemos ejecutar al algoritmo Federated Averaging usando la función tff.learning.build_federated_averaging_process.

El argumento debe ser un constructor (como la model_fn de arriba), no una instancia, para que la construcción y compilado del modelo lo pueda implementar TFF. Los detalles técnicos están en este tutorial.

El algoritmo Federated Averaging tiene dos optimizadores:

  1. _client_optimizer_: se usa para computar updates en cada cliente.
  2. _server_optimizer_: aplica la media al modelo global en el servidor.

Por tanto, la elección de los hiperparámeotrs de optimizador y learning rate serán distintos que los que se usan en el modelo Keras original de MNIST. Se recomienda empezar con SGD, con un learning rate pequeño.

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(
  model_fn,
  client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
state = trainer.initialize()

TFF implementa el algoritmo de Federated Averaging con dos computaciones empaquetadas en la clase tff.templates.IterativeProcess. Estas dos computaciones son initialize y next.

  1. La operación initialize es una función, como todas las computaciones federadas. No usa argumentos, y devuelve un resultado: la representación del estado del proceso de Federated Averaging en el servidor.
  2. La operación next es una función, y representa un ciclo completo de Federated Averaging, que consiste en desplegar el estado del servidor (incluyendo el modelo inicial) a los clientes, entrenamiento en los clientes, recogida de parámetros de vuelta, y generación de un modelo nuevo en el servidor.

Actualmente el algoritmo se soporta en local por ahora (se ejecuta toda en una máquina).

Se puede pensar en next como este flujo:

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

Es decir, next no es una función que se ejecuta en un servidor, sino que se ejecuta de manera distribuída, con un input del servidor (SERVER_STATE), y unas contribuciones de cada cliente con sus propios datos locales.

for _ in range(5):
  state, metrics = trainer.next(state, train_data)
  print(metrics['train']['loss'])

En este tutorial, se deberían escoger clientes aleatoriamente, pero reusamos los clientes para una covergencia más rápida. Las pérdidas deberían disminuir después de cada ciclo de entrenamiento federado, mostrando que el modelo converge.