{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## The Effect of Dropout\n", "\n", "Let's see for ourselves how dropout actually affects training. We will use the MNIST dataset and a simple convolutional network to do that:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from tensorflow import keras\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "\n", "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", "\n", "# Scale pixel values from [0, 255] to [0, 1]\n", "x_train = x_train.astype(\"float32\") / 255\n", "x_test = x_test.astype(\"float32\") / 255\n", "\n", "# Add a channel axis, (N, 28, 28) -> (N, 28, 28, 1), as Conv2D expects\n", "x_train = np.expand_dims(x_train, -1)\n", "x_test = np.expand_dims(x_test, -1)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will define a `train` function that takes care of the whole training process, including:\n", "* Defining the neural network architecture with a given dropout rate `d`\n", "* Specifying suitable training parameters (optimizer and loss function)\n", "* Running the training and collecting the history\n", "\n", "We will then run this function for several different dropout values:" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training with dropout = 0\n", "Epoch 1/5\n", "938/938 [==============================] - 26s 27ms/step - loss: 0.1949 - acc: 0.9435 - val_loss: 0.0596 - val_acc: 0.9802\n", "Epoch 2/5\n", "938/938 [==============================] - 27s 29ms/step - loss: 0.0592 - acc: 0.9816 - val_loss: 0.0433 - val_acc: 0.9857\n", "Epoch 3/5\n", "938/938 [==============================] - 26s 28ms/step - loss: 0.0438 - acc: 0.9867 - val_loss: 0.0472 - val_acc: 0.9849\n", "Epoch 4/5\n", "938/938 [==============================] - 27s 28ms/step - loss: 0.0355 - acc: 0.9890 - val_loss: 0.0353 - val_acc: 0.9882\n", "Epoch 5/5\n", "938/938 [==============================] - 26s 28ms/step - loss: 0.0294 - acc: 0.9910 - val_loss: 
0.0305 - val_acc: 0.9894\n", "Training with dropout = 0.2\n", "Epoch 1/5\n", "938/938 [==============================] - 29s 31ms/step - loss: 0.2097 - acc: 0.9377 - val_loss: 0.0655 - val_acc: 0.9781\n", "Epoch 2/5\n", "938/938 [==============================] - 31s 33ms/step - loss: 0.0676 - acc: 0.9792 - val_loss: 0.0409 - val_acc: 0.9852\n", "Epoch 3/5\n", "938/938 [==============================] - 28s 30ms/step - loss: 0.0514 - acc: 0.9837 - val_loss: 0.0384 - val_acc: 0.9871\n", "Epoch 4/5\n", "938/938 [==============================] - 28s 29ms/step - loss: 0.0424 - acc: 0.9871 - val_loss: 0.0343 - val_acc: 0.9889\n", "Epoch 5/5\n", "938/938 [==============================] - 30s 32ms/step - loss: 0.0356 - acc: 0.9893 - val_loss: 0.0343 - val_acc: 0.9885\n", "Training with dropout = 0.5\n", "Epoch 1/5\n", "938/938 [==============================] - 30s 31ms/step - loss: 0.2586 - acc: 0.9212 - val_loss: 0.0666 - val_acc: 0.9797\n", "Epoch 2/5\n", "938/938 [==============================] - 28s 30ms/step - loss: 0.0860 - acc: 0.9734 - val_loss: 0.0441 - val_acc: 0.9860\n", "Epoch 3/5\n", "938/938 [==============================] - 29s 31ms/step - loss: 0.0674 - acc: 0.9792 - val_loss: 0.0414 - val_acc: 0.9868\n", "Epoch 4/5\n", "938/938 [==============================] - 30s 32ms/step - loss: 0.0564 - acc: 0.9822 - val_loss: 0.0326 - val_acc: 0.9886\n", "Epoch 5/5\n", "938/938 [==============================] - 29s 31ms/step - loss: 0.0511 - acc: 0.9843 - val_loss: 0.0298 - val_acc: 0.9899\n", "Training with dropout = 0.8\n", "Epoch 1/5\n", "938/938 [==============================] - 31s 32ms/step - loss: 0.3832 - acc: 0.8766 - val_loss: 0.0849 - val_acc: 0.9732\n", "Epoch 2/5\n", "938/938 [==============================] - 29s 31ms/step - loss: 0.1563 - acc: 0.9521 - val_loss: 0.0686 - val_acc: 0.9797\n", "Epoch 3/5\n", "938/938 [==============================] - 32s 34ms/step - loss: 0.1253 - acc: 0.9616 - val_loss: 0.0490 - val_acc: 0.9854\n", "Epoch 
4/5\n", "938/938 [==============================] - 33s 35ms/step - loss: 0.1105 - acc: 0.9658 - val_loss: 0.0395 - val_acc: 0.9872\n", "Epoch 5/5\n", "938/938 [==============================] - 34s 36ms/step - loss: 0.1022 - acc: 0.9680 - val_loss: 0.0363 - val_acc: 0.9878\n" ] } ], "source": [ "def train(d):\n", "    print(f\"Training with dropout = {d}\")\n", "    # Small CNN; the Dropout layer with rate d sits between Flatten and the classifier\n", "    model = keras.Sequential([\n", "        keras.layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\", input_shape=(28, 28, 1)),\n", "        keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", "        keras.layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n", "        keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", "        keras.layers.Flatten(),\n", "        keras.layers.Dropout(d),\n", "        keras.layers.Dense(10, activation=\"softmax\")\n", "    ])\n", "    # Labels are integers 0-9, so use sparse categorical cross-entropy\n", "    model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])\n", "    # The test set doubles as validation data, so accuracy is reported after every epoch\n", "    hist = model.fit(x_train, y_train, validation_data=(x_test, y_test), epochs=5, batch_size=64)\n", "    return hist\n", "\n", "# Train one model per dropout rate and keep its History object\n", "res = {d: train(d) for d in [0, 0.2, 0.5, 0.8]}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's plot the validation accuracy for each dropout value to see how fast training progresses:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.legend.Legend at 0x235bc70f0d0>" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# One validation-accuracy curve per dropout rate\n", "for d, h in res.items():\n", "    plt.plot(h.history['val_acc'], label=str(d))\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Validation accuracy\")\n", "plt.legend()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "From this graph, you will probably observe the following:\n", "* With dropout values in the 0.2-0.5 range, you will see the fastest training and the best overall results\n", "* Without dropout ($d=0$), you are likely 
to see a less stable and slower training process\n", "* High dropout (0.8) makes things worse: training is slower and validation accuracy tends to end up lower" ] } ], "metadata": { "interpreter": { "hash": "86193a1ab0ba47eac1c69c1756090baa3b420b3eea7d4aafab8b85f8b312f0c5" }, "kernelspec": { "display_name": "Python 3.9.5 64-bit ('base': conda)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.12" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }