# tensorflow2_customloops.py (gist created by @nuzrub, Apr 30, 2020)
# Author: Ygor Rebouças
#


### The Training Loop
#
# 0) Imports
import tensorflow as tf
import numpy as np

# 1) Dataset loading and preparation
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
X_train = X_train.astype('float32') / 255
X_test = X_test.astype('float32') / 255
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
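
# Optional sanity check (an addition, not in the original gist): CIFAR-10
# gives 50,000 training and 10,000 test images of shape 32x32x3, and the
# one-hot labels should be 10 wide.
print(X_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(X_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)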

# 2) Model loading / creation
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Input(shape=X_train.shape[1:]))
for n_filters in [32, 64, 128]:
    model.add(tf.keras.layers.Conv2D(n_filters, (3, 3), padding='same', use_bias=False))
    model.add(tf.keras.layers.BatchNormalization())
    model.add(tf.keras.layers.Activation('elu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='elu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
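
# Optional (not in the original gist): inspect the architecture. The three
# 2x2 poolings take the 32x32 input down to 4x4 before flattening.
model.summary()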

# 3) Compile and fit
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(x=X_train, y=y_train, validation_data=(X_test, y_test), batch_size=128, epochs=10, shuffle=True)
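
# Optional (an addition, not in the original gist): a final held-out
# evaluation with the built-in API, as a baseline for the custom loop below.
test_loss = model.evaluate(X_test, y_test, batch_size=128)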




### The Custom Loop
# The train_on_batch function
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()

def train_on_batch(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X, training=True)
        loss_value = loss(y, ŷ)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

train_on_batch(X_train[0:128], y_train[0:128])
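
# A common variant (an assumption, not part of the original gist): return
# the mean batch loss so the training loop can log it.
# def train_on_batch(X, y):
#     with tf.GradientTape() as tape:
#         ŷ = model(X, training=True)
#         loss_value = loss(y, ŷ)
#     grads = tape.gradient(loss_value, model.trainable_weights)
#     optimizer.apply_gradients(zip(grads, model.trainable_weights))
#     return tf.reduce_mean(loss_value)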

# The validate_on_batch function
def validate_on_batch(X, y):
    ŷ = model(X, training=False)
    loss_value = loss(y, ŷ)
    return loss_value

validate_on_batch(X_test[0:128], y_test[0:128])
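
# Optional sketch (not in the original gist): tracking accuracy alongside
# the loss with a stateful Keras metric object.
accuracy = tf.keras.metrics.CategoricalAccuracy()
accuracy.update_state(y_test[0:128], model(X_test[0:128], training=False))
print('Batch accuracy: %.4f' % float(accuracy.result()))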

# Putting it all together
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam(0.001)
batch_size = 1024
epochs = 10

for epoch in range(0, epochs):
    # Note: floor division drops the final partial batch of each epoch.
    for i in range(0, len(X_train) // batch_size):
        X = X_train[i * batch_size:min(len(X_train), (i+1) * batch_size)]
        y = y_train[i * batch_size:min(len(y_train), (i+1) * batch_size)]
        train_on_batch(X, y)

    val_loss = []
    for i in range(0, len(X_test) // batch_size):
        X = X_test[i * batch_size:min(len(X_test), (i+1) * batch_size)]
        y = y_test[i * batch_size:min(len(y_test), (i+1) * batch_size)]
        val_loss.append(validate_on_batch(X, y))

    print('Validation Loss: ' + str(np.mean(val_loss)))
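
# A hedged fix for the dropped tail (an assumption about intent): rounding
# up visits the final partial batch too; the min() in the slices above
# already guards against overrunning the arrays.
# n_batches = int(np.ceil(len(X_train) / batch_size))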




## Improving the Loop
# The Dataset API
train_data = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(buffer_size=len(X_train)).batch(batch_size)
test_data = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(buffer_size=len(X_test)).batch(batch_size)
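
# Optional performance tweak (an addition, not in the original gist):
# prefetch lets the pipeline prepare the next batch while the current one
# trains. Newer TF also exposes this constant as tf.data.AUTOTUNE.
train_data = train_data.prefetch(tf.data.experimental.AUTOTUNE)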

# Enumerating the Dataset
for epoch in range(0, epochs):
    for batch, (X, y) in enumerate(train_data):
        train_on_batch(X, y)

    val_loss = []
    for batch, (X, y) in enumerate(test_data):
        # Reduce each batch to a scalar first: the dataset's final batch is
        # smaller, and np.mean cannot average unequal-length loss vectors.
        val_loss.append(np.mean(validate_on_batch(X, y)))

    print('Validation Loss: ' + str(np.mean(val_loss)))

# Model Checkpointing and better prints
best_loss = 99999
for epoch in range(0, epochs):
    for batch, (X, y) in enumerate(train_data):
        train_on_batch(X, y)
        print('\rEpoch [%d/%d] Batch: %d%s' % (epoch + 1, epochs, batch, '.' * (batch % 10)), end='')

    val_loss = np.mean([np.mean(validate_on_batch(X, y)) for (X, y) in test_data])
    print('. Validation Loss: ' + str(val_loss))
    if val_loss < best_loss:
        model.save_weights('model.h5')
        best_loss = val_loss
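
# Usage sketch (not in the original gist): 'model.h5' is the checkpoint
# saved above; after training, reload the best weights before inference.
model.load_weights('model.h5')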


### The tf.function
@tf.function
def train_on_batch(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X, training=True)
        loss_value = loss(y, ŷ)

    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

@tf.function
def validate_on_batch(X, y):
    ŷ = model(X, training=False)
    loss_value = loss(y, ŷ)
    return loss_value
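
# Rough timing sketch (an addition, not in the original gist). The first
# call traces the function into a graph, so warm up before timing.
import time
train_on_batch(X_train[0:1024], y_train[0:1024])  # trace + warm-up
start = time.time()
train_on_batch(X_train[0:1024], y_train[0:1024])
print('One traced batch: %.4f s' % (time.time() - start))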