Come usare una callback per stoppare il training a performance adeguata

In questo articolo spiegherò come controllare il training di una rete neurale in Tensorflow attraverso l'utilizzo di callback. Una callback è una funzione che viene chiamata ripetutamente durante un processo (ad esempio l'addestramento di una rete neurale) e che generalmente serve per validare o correggere determinati comportamenti.

Nel machine learning, possiamo usare delle callback per definire cosa succede prima, durante o alla fine di un'epoca. Questo è utile soprattutto per fare logging delle performance oppure per interrompere il training se la nostra metrica di performance raggiunge una certa soglia. Questo meccanismo si chiama early stopping. Ad esempio, se si impostano 1000 epoche e la precisione desiderata è già stata raggiunta all'epoca 200, l'addestramento si interromperà automaticamente. Vediamo come questo viene implementato in Tensorflow e Python.

Gettiamo le basi andando a importare il dataset fashion_mnist da Tensorflow. Useremo questo dataset per spiegare come funzionano le callback.

Ecco come appare il nostro dataset

Sample images from Fashion-MNIST dataset.
Esempio del dataset Fashion MNIST. Un dataset di indumenti etichettati da sfruttare per task di computer vision

‍

La classe EarlyStopping

Il secondo step è la creazione della classe dedicata all'early stopping. In questo caso creeremo una classe che erediterà da tf.keras.callbacks.Callback e permetterà di fermare il training al raggiungimento del 95% di accuracy. La callback userà la funzione on_epoch_end per fermare il training se la condizione è soddisfatta andando a guardare i log erogati dal modello di Tensorflow.

In pratica qui non facciamo altro che accedere al metodo on_epoch_end che è ereditato da tf.keras.callbacks.Callback e fare override il suo comportamento andando ad inserire la condizione che farà interrompere il training.

Continuiamo con l'implementazione del codice del nostro modello.

Classificazione con deep neural network

Useremo una rete neurale con diversi strati per classificare gli indumenti del dataset. L'approccio migliore sarebbe quello di usare una rete neurale convoluzionale, ma per questo esempio una deep neural network andrà più che bene.

Ora siamo pronti ad addestrare il modello. Per passare la callback, basta inserire il nostro oggetto nella lista da fornire all'argomento callbacks nella funzione .fit() del modello.

L'addestramento viene interrotto perché abbiamo raggiunto il 95% di accuracy

Ecco come si imposta una callback per controllare il training di una rete neurale.

Codice completo

Andrea D'Agostino
Ciao, sono Andrea D'Agostino e sono un data scientist con 6 anni di esperienza nel campo della business intelligence. Applico tecniche statistiche e di machine learning per aiutare i clienti a trovare e risolvere problemi nei loro asset digitali e a sfruttare le debolezze dei competitor a loro vantaggio.

Sono il fondatore e l'autore di questo blog, il cui obiettivo è raccogliere le informazioni più importanti che ho imparato durante il mio percorso lavorativo e accademico al fine di poter aiutare il lettore a migliorare le sue analisi.