1. Introducción

Este ejemplo muestra cómo hacer un agrupamiento (clustering) para el dataset de de imágenes de Cifar-10. Nótese que esto no es un problema de clasificación, sino un clustering con un tipo de entrenamiento NO supervisado.

Diferencias entre clasificación y agrupamiento ó clustering:

  • En la clasificación, las clases resultantes son dadas como parte del set de entrenamiento. Esta información es realmente usada durante el entrenamiento apra construir el clasificador. Posteriormente, se aplica el clasificador resultante sobre imágenes nuevas (sin clasificar previamente).
  • En el agrupamiento, se particionan las imágenes en varios grupos (clases resultantes). No se conoce el significado de esas clases,simplemente se sabe que estadísticamente son parecidas. Ejemplos de redes para clasificación es una red convolucional (aprendizaje supervisado). Ejemplo de algoritmos de agrupamiento es kMeans.

En este notebook se va a usar k-Means, que es un algoritmo de clasificación no supervisado (clustering) que agrupa objetos en k grupos basándose en sus características. El clustering in k grupos se realiza minimizando la suma de distancias (puede ser media ó cuadrática) entre cada objeto y el centroide de su cluster.

Este ejemplo se ha probado y funciona en Colab.

2. Setup

Importamos las librerías que vamos a usar. Usaremos la función experimental tensorflow.numpy para aprovechar las GPUs durante operaciones con funciones numpy (por ejemplo, durante la inferencia):

!pip3 install --user --upgrade tf-nightly
!pip3 install --user scikit-learn
import matplotlib.pyplot as plt

import tensorflow as tf
import numpy as np
from sklearn.utils import shuffle

from tensorflow.keras import datasets, layers, models

Comprobamos si tenemos GPUs. En caso contrario, no notaremos diferencia de velocidad:

print("All logical devices:", tf.config.list_logical_devices())
print("All physical devices:", tf.config.list_physical_devices())
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))
All logical devices: [LogicalDevice(name='/device:CPU:0', device_type='CPU')]
All physical devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU')]
Num GPUs Available:  0

3. Carga de datos

Cargamos el dataset desde tensorflow.keras.datasets:

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Usando scikit-learn, Hacemos shuffle y usamos sólo 2000 imágenes para entrenar y 100 para test
train_images, train_labels = shuffle(train_images, train_labels)
test_images, test_labels = shuffle(test_images, test_labels)

train_images = train_images[:2000]
train_labels = train_labels[:2000]
test_images = test_images[:100]
test_labels = test_labels[:100]

# Normalizamos valores de píxeles entre 0 y 1
train_images, test_images = train_images / 255.0, test_images / 255.

Visualizamos los primeros 25 elementos del dataset:

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    # The CIFAR labels happen to be arrays, 
    # which is why you need the extra index
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:12:48.343799 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/

4. Preparar datos

El set de entrenamiento original era de tamaño train_images.shape=(50000,32,32,3) y sus etiquetas train_labels.shape=(50000, 1). Pero usamos uno más pequeño de 2000. Aplanamos el set de entrenamiento y también el de pruebas:

train_images_rows = train_images.reshape(train_images.shape[0], 32 * 32 * 3) # train_images_rows.shape 2000 x 3072
test_images_rows = test_images.reshape(test_images.shape[0], 32 * 32 * 3) # test_images_rows.shape 100 x 3072

5. Construcción del modelo

El algoritmo k-Means consta de tres pasos:

  1. Inicialización: se selecciona el número de centroides (grupos, NUM_CENTROIDS) y se crean con el tamaño de las imágenes, aleatorios.
  2. Asignación: cada imagen es asignada a su centroide más cercano.
  3. Actualización centroides: se actualiza la posición del centroide de cada grupo tomando como nuevo centroide la posición del promedio de los objetos pertenecientes a dicho grupo.

Se repiten los pasos 2 y 3 hasta que los centroides no cambian.

Para medir la precisión, se usa la métrica Purity, que es una medida de evaluación muy sencilla y transparente, en donde se cuenta la etiqueta más frecuente en cada cluster, y se calcula la pureza dividiendo ese valor por el número total de elementos del cluster. La pureza del algoritmo es simplemente la suma de las purezas de cada cluster. En el mejor caso, para un cluster de 10 centroides, la pureza debería ser 10 (NUM_CENTROIDS). Una definición más formal de pureza se puede ver aqui(Stanford), donde define la pureza (purity) de la siguiente manera:

To compute purity , each cluster is assigned to the class which is most frequent in the cluster, and then the accuracy of this assignment is measured by counting the number of correctly assigned documents and dividing by N.

# Paso 1: Inicialización
NUM_IMAGES = train_images.shape[0] # 2000
NUM_CENTROIDS = 10 # We know number of centroids beforehand

# Paso 2: Asignación a centroide más cercano
def assignment(centroids):
  closest = []

  for j in range(NUM_IMAGES):
    distances_from_centroid = []
    for i in range(NUM_CENTROIDS):
      distances_from_centroid = np.append(distances_from_centroid, np.sum(np.abs(centroids[i, :] - train_images_rows[j,:])))
    #print(distances_from_centroid)
    closest.append(np.argmin(distances_from_centroid))
  #print(closest)
  return closest

# Paso 3: Actualización centroides
def update(closest):
  purity = []
  new_centroids = np.zeros((NUM_CENTROIDS, 32*32*3))
  for c in range(NUM_CENTROIDS):
    num = 0 # num of IMAGES in the cluster
    best = []
    # Takes all images assigned to the centroid and calculates average
    for i in range(NUM_IMAGES):
      if (closest[i] == c):
        num = num + 1
        #print(train_images_rows[i, :])
        new_centroids[c, :] += train_images_rows[i, :]
        best = np.append(best, train_labels[i])

    # Discard if there are no images (num=0) assigned to a centroid
    # Calculate average
    if (num>0):
      new_centroids[c, :] = new_centroids[c, :] / num

    # Calculate PURITY for each cluster separately
    # Note clusters do not follow label ordering of train images, so cluster 0 does not equal to label_0 (airplane)
    unique, counts = np.unique(best, return_counts=True)
    #print("Cluster ", c ," has ", num, " images assigned. Most frequent label is: ", unique[counts == counts.max()])
    #print(dict(zip(unique, counts)))
    correct_label_count = 0
    # Fix case where two or more labels are the most frequent in a cluster simultaneously
    if (unique[counts == counts.max()].size > 1):
      best_label = unique[counts == counts.max()][0]
    else:
      best_label = unique[counts == counts.max()]
    # Calculate accuracy of each cluster
    for i in range(NUM_IMAGES):
        if (closest[i] == c):
            if (best_label != train_labels[i]):
                correct_label_count += 1
    purity = np.append(purity, correct_label_count/num)
  return new_centroids, purity  
# Ejecución del algoritmo k-Means: bucle hasta que los centroides no cambian
centroides = np.random.rand(NUM_CENTROIDS, 32*32*3)
while True:
    cercanos = assignment(centroides)
    nuevos_centroides, pureza = update(cercanos)
    print("Purity: ", np.sum(pureza), " Centroids convergence: ", np.sum(nuevos_centroides)-np.sum(centroides))
    if (np.allclose(nuevos_centroides, centroides)):
      break
    centroides = nuevos_centroides
Purity:  8.143573240427953  Centroids convergence:  -1011.9732804228533
Purity:  8.117762662055938  Centroids convergence:  74.80783684575181
Purity:  7.915912560407891  Centroids convergence:  121.73650328388248
Purity:  7.796077104366344  Centroids convergence:  92.51681301552344
Purity:  7.7482000340978505  Centroids convergence:  99.12812322555874
Purity:  7.6869370898431315  Centroids convergence:  65.22459724608962
Purity:  7.622074432728336  Centroids convergence:  40.922291395248976
Purity:  7.603127348295604  Centroids convergence:  42.50720856783482
Purity:  7.6029885866448135  Centroids convergence:  35.76182789657105
Purity:  7.586641777683099  Centroids convergence:  20.81112928343464
Purity:  7.578510156816363  Centroids convergence:  9.413361443677786
Purity:  7.587036260912649  Centroids convergence:  13.923756300480818
Purity:  7.590103527738729  Centroids convergence:  12.267108683232436
Purity:  7.573613677399101  Centroids convergence:  13.31292766352999
Purity:  7.5810286257399815  Centroids convergence:  11.1740192087982
Purity:  7.584263425481467  Centroids convergence:  12.531641388104617
Purity:  7.583250467666781  Centroids convergence:  8.272761105219615
Purity:  7.5942143516161495  Centroids convergence:  12.022298838617644
Purity:  7.586518825764817  Centroids convergence:  3.870052221696824
Purity:  7.586031829046814  Centroids convergence:  2.3619532464254007
Purity:  7.5854165899455435  Centroids convergence:  5.006124093717517
Purity:  7.58502887424958  Centroids convergence:  3.0531509890861344
Purity:  7.584230071156004  Centroids convergence:  2.689206672284854
Purity:  7.584077839949038  Centroids convergence:  1.1983034891727584
Purity:  7.583391882373221  Centroids convergence:  1.8873040262260474
Purity:  7.57875427949938  Centroids convergence:  -0.31221975871449104
Purity:  7.57141359323422  Centroids convergence:  6.717176250793273
Purity:  7.570639863543898  Centroids convergence:  1.6986865756844054
Purity:  7.558790402858739  Centroids convergence:  2.8300456869510526
Purity:  7.5666057351014855  Centroids convergence:  5.180336900644761
Purity:  7.566982556559942  Centroids convergence:  1.7703563416289398
Purity:  7.566982556559942  Centroids convergence:  0.0

En el paso 2 anterior (asignación de centroides), se puede usar la distancia L1 ó L2 (cuadrática). Para la diferencia cuadrática, simplemente cambiar la línea por la siguiente: distances_from_centroid.append(np.sqrt(np.sum(np.abs(centroids[i, :] - train_images_rows[j,:]))))

6. Evaluación

La pureza (purity) la obtuvimos en el paso anterior. Cada cluster no sigue el orden de las etiquetas del dataset, es decir, el cluster 0 no corresponde a la clase airplane. Mostramos las 25 primeras imágenes del Cluster 0, aunque esto no es representativo ya que la precisión es baja y no estamos viendo todas las imágenes del cluster 0:

CENTROID_TO_EVALUATE = 0 # Cluster a mostrar
train_images_rows_img = train_images_rows.reshape(train_images_rows.shape[0], 32, 32, 3)

plt.figure(figsize=(10,10))
num = 0
for i in range(NUM_IMAGES):
    if (num == 25):
        break
    if (cercanos[i] == CENTROID_TO_EVALUATE):
        plt.subplot(5,5,num+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(train_images_rows_img[i], cmap=plt.cm.binary)
        plt.xlabel(cercanos[i])
        num += 1
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:13:08.355791 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/

Como curiosidad, vemos los 10 centroides generados y mostramos sus imágenes:

centroides_img = nuevos_centroides.reshape(nuevos_centroides.shape[0], 32, 32, 3)

plt.figure(figsize=(10,10))
for c in range(NUM_CENTROIDS):
    plt.subplot(5,5,c+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(centroides_img[c], cmap=plt.cm.binary)
    # The CIFAR labels happen to be arrays, 
    # which is why you need the extra index
    plt.xlabel(c)
plt.show()
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:13:11.155897 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/

7. Comprobación con scikit-learn

Vamos a comprobar el resultado y la calcular la pureza (purity) con la librería scikit-learn, usando la clase sklearn.cluster.KMeans, para compararla con nuestro resultado anterior

from sklearn.cluster import KMeans
n = 10
k_means = KMeans(n_clusters=n)
k_means.fit(train_images_rows)
KMeans(n_clusters=10)
centroids = k_means.cluster_centers_
labels= k_means.labels_
Z = k_means.predict(train_images_rows)

purity = 0
best = []
for i in range(0,n):

    row = np.where(Z==i)[0]  # row in Z for elements of cluster i
    num = row.shape[0]       #  number of elements for each cluster
    r = np.floor(num/10.)    # number of rows in the figure of the cluster 

    print("Cluster "+str(i)+" has "+str(num)+" elements")

    plt.figure(figsize=(10,10))
    for k in range(0, num):
        plt.subplot(r+1, 10, k+1)
        image = train_images_rows[row[k], ]
        image = image.reshape(32, 32, 3)
        plt.imshow(image, cmap='gray')
        plt.axis('off')
        best = np.append(best, train_labels[row[k]])
    unique, counts = np.unique(best, return_counts=True)
    #print(str(counts.max())+" out of "+str(num))
    purity += counts.max()/num
    print("Purity: "+str(purity))
  
    plt.show()
Cluster 0 has 194 elements
Purity: 0.22164948453608246
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:13:21.549469 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 1 has 207 elements
Purity: 0.525997310623039
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:13:30.297551 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 2 has 273 elements
Purity: 0.8776456622713906
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:13:41.986646 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 3 has 169 elements
Purity: 1.6113734729222782
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:13:49.847821 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 4 has 113 elements
Purity: 2.806063738409004
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:13:54.986668 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 5 has 177 elements
Purity: 3.591374472872281
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:14:02.281463 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 6 has 161 elements
Purity: 4.541685031878492
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:14:09.478339 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 7 has 288 elements
Purity: 5.159740587434047
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:14:20.742189 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 8 has 196 elements
Purity: 6.175046709883027
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:14:29.637466 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/
Cluster 9 has 222 elements
Purity: 7.161533196369514
<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> 2020-11-02T23:14:39.156964 image/svg+xml Matplotlib v3.3.2, https://matplotlib.org/