In questo articolo spiegherò come controllare il training di una rete neurale in TensorFlow attraverso l'utilizzo di callback.

Una callback è una funzione che viene chiamata ripetutamente durante un processo (ad esempio l'addestramento di una rete neurale) e che generalmente serve per validare o correggere determinati comportamenti.

Nel machine learning, possiamo usare delle callback per definire cosa succede prima, durante o alla fine di un'epoca.

Questo è utile soprattutto per fare logging delle performance oppure per interrompere il training se la nostra metrica di performance raggiunge una certa soglia. Questo meccanismo si chiama early stopping.

Ad esempio, se si impostano 1000 epoche e la precisione desiderata è già stata raggiunta all'epoca 200, l'addestramento si interromperà automaticamente. Vediamo come questo viene implementato in TensorFlow e Python.

Perché usare l'early stopping?

L'early stopping è uno dei metodi di regolarizzazione per reti neurali più comuni e efficaci.

Infatti, grazie ad esso, possiamo evitare overfitting e underfitting dei nostri dati. Senza andare nel dettaglio in questo articolo, quando il nostro modello overfitta i nostri dati, allora quel modello non è in grado di generalizzare la mondo reale poiché troppo sensibile ai dati di training.

In maniera complementare, un modello che underfitta sarà troppo generico e non sarà in grado mappare adeguatamente il nostro input al nostro output.

Il lettore interessato all'argomento può espandere la sua conoscenza su questi due fenomeni leggendo un articolo specifico sull'overfitting.

Gettiamo le basi andando a importare il dataset fashion_mnist da TensorFlow. Useremo questo dataset per spiegare come funzionano le callback.

import tensorflow as tf

# Importiamo il dataset dall'API di tensorflow
fmnist = tf.keras.datasets.fashion_mnist

# Carichiamo il dataset
(x_train, y_train),(x_test, y_test) = fmnist.load_data()

# Normalizziamo le immagini
x_train, x_test = x_train / 255.0, x_test / 255.0

Ecco come appare il nostro dataset

Sample images from Fashion-MNIST dataset.
Esempio del dataset Fashion MNIST. Un dataset di indumenti etichettati da sfruttare per task di computer vision

‍La classe EarlyStopping

Il secondo step è la creazione della classe dedicata all'early stopping. In questo caso creeremo una classe che erediterà da tf.keras.callbacks.Callback e permetterà di fermare il training al raggiungimento del 95% di accuracy. La callback userà la funzione on_epoch_end per fermare il training se la condizione è soddisfatta andando a guardare i log erogati dal modello di TensorFlow.

In pratica qui non facciamo altro che accedere al metodo on_epoch_end che è ereditato da tf.keras.callbacks.Callback e fare override il suo comportamento andando ad inserire la condizione che farà interrompere il training.

Continuiamo con l'implementazione del codice del nostro modello.

class EarlyStopping(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs={}):
    '''
    Interrompe l'addestramento al raggiungimento del 95% di accuracy
    '''

    # Controlliamo l'accuracy
    if(logs.get('accuracy') > 0.95):

      # Fermiamo il training se la condizione è soddisfatta
      print("\nSoglia di accuracy raggiunta. Training interrotto!")
      self.model.stop_training = True

# Creiamo un oggetto della nostra classe e assegnamolo ad una variabile
early_stopping = EarlyStopping()

Classificazione con deep neural network

Useremo una rete neurale con diversi strati per classificare gli indumenti del dataset. L'approccio migliore sarebbe quello di usare una rete neurale convoluzionale, ma per questo esempio una deep neural network andrà più che bene.

# Creiamo un modello sequenziale con tre strati
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)), # appiattiamo l'input
  tf.keras.layers.Dense(512, activation="relu"), # creiamo uno strato denso di 512 neuroni con attivazione ReLU
  # creiamo lo strato di output 10 neuroni, con attivazione Softmax per mappare 
  # ogni tipo di indumento nel dataset
  tf.keras.layers.Dense(10, activation="softmax") 
])

# Compiliamo il modello con Adam e settiamo le metriche canoniche per un task di classificazione
model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

Ora siamo pronti ad addestrare il modello. Per passare la callback, basta inserire il nostro oggetto nella lista da fornire all'argomento callbacks nella funzione .fit() del modello.

# Addestriamo il modello con la nostra callback!
model.fit(x_train, y_train, epochs=10, callbacks=[early_stopping])
L'addestramento viene interrotto perché abbiamo raggiunto il 95% di accuracy

Ecco come si imposta una callback per controllare il training di una rete neurale.

Codice completo

import tensorflow as tf

# Importiamo il dataset dall'API di tensorflow
fmnist = tf.keras.datasets.fashion_mnist

# Carichiamo il dataset
(x_train, y_train),(x_test, y_test) = fmnist.load_data()

# Normalizziamo le immagini
x_train, x_test = x_train / 255.0, x_test / 255.0


class EarlyStopping(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs={}):
    '''
    Interrompe l'addestramento al raggiungimento del 95% di accuracy
    '''

    # Controlliamo l'accuracy
    if(logs.get('accuracy') > 0.95):

      # Fermiamo il training se la condizione è soddisfatta
      print("\nSoglia di accuracy raggiunta. Training interrotto!")
      self.model.stop_training = True

# Creiamo un oggetto della nostra classe e assegnamolo ad una variabile
early_stopping = EarlyStopping()

# Creiamo un modello sequenziale con tre strati
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)), # appiattiamo l'input
  tf.keras.layers.Dense(512, activation="relu"), # creiamo uno strato denso di 512 neuroni con attivazione ReLU
  # creiamo lo strato di output 10 neuroni, con attivazione Softmax per mappare 
  # ogni tipo di indumento nel dataset
  tf.keras.layers.Dense(10, activation="softmax") 
])

# Compiliamo il modello con Adam e settiamo le metriche canoniche per un task di classificazione
model.compile(optimizer="adam",
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Addestriamo il modello con la nostra callback!
model.fit(x_train, y_train, epochs=10, callbacks=[early_stopping])