Understanding Model Training: From Hyperparameter Tuning to Result Evaluation
Welcome back to our blog series about machine learning projects. Are you involved in planning an AI project? Then you're in the right place. We explain all the project phases in a series of blog posts.
Today, we're going to dive into the actual model training and finish phase 3 of our machine learning life cycle.
Our goal is to build the best possible model for our image classification task, in which we have to classify images as aesthetic or unaesthetic. The images show dance and should be suitable for a dance federation's marketing.
In our previous article about model evaluation, we chose our initial machine learning model. Now we want to train this model on our 5,550 images.
Hyperparameters control the training process, so they have to be set before training starts. The process of finding the hyperparameters that maximize model performance is called hyperparameter tuning.
Hyperparameters can be tuned manually or automatically using algorithms like Grid Search or Random Search. These algorithms test different hyperparameter combinations according to a certain strategy, train a model for each combination, and pick the best one. Because every combination means another full training run, this is very expensive for computer vision tasks when models are trained from scratch, since huge amounts of data are required. Although in our example we benefit from transfer learning and therefore only need a few thousand images, we focus on manual hyperparameter tuning in this blog post.
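To make the difference between the two automated strategies concrete, here is a minimal Python sketch. It assumes a hypothetical `toy_validation_score` function that stands in for a complete training run returning a validation accuracy. The candidate values and the log-uniform sampling range are illustrative choices, not settings from our project. Grid search tries every combination, while random search spends a fixed budget on random draws.

```python
import itertools
import math
import random


# HYPOTHETICAL stand-in for "train a model with these hyperparameters and
# return its validation accuracy". In a real project each call is a full
# training run; this toy formula simply peaks at lr=0.001, batch_size=50.
def toy_validation_score(learning_rate, batch_size):
    lr_penalty = (math.log10(learning_rate) + 3) ** 2
    bs_penalty = ((batch_size - 50) / 100) ** 2
    return 1.0 / (1.0 + lr_penalty + bs_penalty)


LEARNING_RATES = [0.1, 0.01, 0.001, 0.0001]
BATCH_SIZES = [16, 32, 50, 128]


def grid_search():
    """Evaluate every combination of candidate values (4 x 4 = 16 runs)."""
    best = None
    for lr, bs in itertools.product(LEARNING_RATES, BATCH_SIZES):
        score = toy_validation_score(lr, bs)
        if best is None or score > best[0]:
            best = (score, lr, bs)
    return best


def random_search(n_trials, seed=0):
    """Evaluate n_trials random combinations; the learning rate is log-uniform."""
    rng = random.Random(seed)
    best = None
    for _ in range(n_trials):
        lr = 10 ** rng.uniform(-4, -1)  # somewhere between 0.0001 and 0.1
        bs = rng.choice(BATCH_SIZES)
        score = toy_validation_score(lr, bs)
        if best is None or score > best[0]:
            best = (score, lr, bs)
    return best


print(grid_search())     # best (score, learning_rate, batch_size) of 16 runs
print(random_search(8))  # best of only 8 runs
```

In practice, each call of the scoring function would be a full training run. That is exactly why automated tuning gets expensive: even this tiny grid already costs 16 runs.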
Different models have different hyperparameters. Let's have a look at the most common hyperparameters we use in our projects.
1. Optimization algorithm
During training, the optimization algorithm calculates how the model weights need to change in order to minimize the loss, using the gradient of the loss with respect to the weights. The loss indicates how wrong the model's predictions are.
The most common optimization algorithm is stochastic gradient descent, or SGD for short. It calculates the changes to the model weights based on a small random sample of the training set (hence the name stochastic) and thereby moves step by step toward a local minimum of the loss. Since SGD uses only a few samples per update, it reduces the computation enormously and speeds up training.
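The following pure-Python sketch shows the core loop of mini-batch SGD. For readability, we assume a simple linear model y = w * x + b with a mean squared error loss instead of a neural network. The update rule is the same idea: compute the loss gradient on a small random batch and step against it. The function name and default values are illustrative.

```python
import random


def sgd_train(data, learning_rate=0.05, batch_size=4, epochs=200, seed=0):
    """Fit y = w * x + b with mini-batch stochastic gradient descent on MSE loss."""
    rng = random.Random(seed)
    data = list(data)  # copy so the caller's list is not reordered
    w, b = 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(data)  # new random batches in every epoch
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # Gradient of the mean squared error, computed on this batch only
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / len(batch)
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / len(batch)
            # Step against the gradient to reduce the loss
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
    return w, b


# Toy data generated from y = 2x + 1 (no noise)
points = [(x / 10, 2 * (x / 10) + 1) for x in range(20)]
w, b = sgd_train(points)
print(w, b)
```

Because the toy data contains no noise, w and b end up very close to 2 and 1.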
There's another optimization algorithm called Adam, which has a striking advantage: it automatically adapts the step size for each weight based on the history of its gradients. This makes it much less sensitive to the choice of the learning rate (the most important hyperparameter, see below), so good training typically needs fewer experiments. For our current classification task we choose Adam.
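Here is a minimal sketch of the Adam update for a single weight, minimizing the toy function f(w) = (w - 3)^2. Adam keeps running averages of the gradient and of the squared gradient, and it divides the one by the square root of the other. The effective step size therefore adapts automatically, while a base learning rate still scales it. The beta1, beta2 and eps defaults are the values commonly used with Adam. The learning rate of 0.1 and the step count are assumptions chosen for this toy problem.

```python
import math


def adam_minimize(grad_fn, w0, learning_rate=0.1, beta1=0.9, beta2=0.999,
                  eps=1e-8, steps=500):
    """Minimize a one-dimensional function with the Adam update rule."""
    w = w0
    m, v = 0.0, 0.0  # running averages of the gradient and the squared gradient
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # Bias correction, because m and v start at zero
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        # The step size adapts to the gradient history of this weight
        w -= learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return w


# Gradient of f(w) = (w - 3)**2 is 2 * (w - 3); the minimum is at w = 3
w_opt = adam_minimize(lambda w: 2 * (w - 3), w0=0.0)
print(w_opt)
```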
2. Batch size
The batch size defines the number of training samples propagated through the neural network before its weights are updated using the chosen optimization algorithm.
We can't pass all our training samples through the neural network at once. Therefore, we divide them into batches. In our example project we have 5,550 training samples and set a batch size of 50. That means 50 samples are propagated through the neural network, after which the model updates its weights based on the loss gradients of these 50 samples.
Then it takes the next 50 samples (the 51st to the 100th) and trains the network again. This is repeated until all samples have been propagated through the network. One such full pass over the training set is called an epoch; in our case it consists of 5,550 / 50 = 111 weight updates. The smaller the batch size, the more frequently the model weights are updated; the larger the batch size, the less frequently they are updated.
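The batching described above can be sketched in a few lines of Python. The function name `iterate_minibatches` is our own. Shuffling the sample order before splitting is a common practice that we assume here. The numbers are the 5,550 samples and the batch size of 50 from our project.

```python
import random


def iterate_minibatches(num_samples, batch_size, seed=0):
    """Yield one list of sample indices per weight update (one epoch)."""
    indices = list(range(num_samples))
    # Shuffle before splitting; a real training loop reshuffles every epoch
    random.Random(seed).shuffle(indices)
    for start in range(0, num_samples, batch_size):
        yield indices[start:start + batch_size]


batches = list(iterate_minibatches(5_550, 50))
print(len(batches))     # 111 weight updates per epoch
print(len(batches[0]))  # 50 samples per batch
```

If the number of samples were not divisible by the batch size, the last batch would simply be smaller.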
A very large batch size can reduce the quality of the model, measured by its ability to generalize. In addition, it needs a lot of memory.
A small batch size, on the other hand, usually lets the network train faster because it updates the weights more often, and it also requires less memory.
3. Learning rate
The learning rate defines how much the model weights will be adjusted after each batch with respect to the loss gradient.
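Usually we use an initial learning rate of 0.01 or 0.001. There is no perfect learning rate and also no perfect value to start with. If the learning rate is far too low, the model learns so slowly that it barely makes progress. If it is too high, the weight updates overshoot and training becomes unstable or even diverges. In both cases the model does not learn anything useful.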
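A toy experiment illustrates the effect. The sketch below runs plain gradient descent on f(w) = w^2, whose minimum is at w = 0. It uses three learning rates that we picked purely for illustration.

```python
def gradient_descent(learning_rate, steps=50, w0=1.0):
    """Minimize f(w) = w**2 (gradient 2 * w) with plain gradient descent."""
    w = w0
    for _ in range(steps):
        w -= learning_rate * 2 * w
    return w


for lr in (0.001, 0.1, 1.1):  # too low, reasonable, too high
    print(f"learning rate {lr}: w after 50 steps = {gradient_descent(lr):.6g}")
```

After 50 steps, the rate of 0.001 has barely moved w (about 0.90), 0.1 has brought it very close to 0, and 1.1 overshoots the minimum further with every step, so |w| grows to roughly 9,100.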
Additionally, the learning rate can be reduced over time using a decay rate. The model then adjusts its weights strongly at the beginning, and the longer it has been training, the less it adjusts them.
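Exponential decay is one common way to implement this. The article does not specify which schedule we use, so treat the following as an illustrative sketch with an assumed decay rate of 0.9 per epoch.

```python
def exponential_decay(initial_lr, decay_rate, epoch):
    """Learning rate after `epoch` epochs: initial_lr * decay_rate ** epoch."""
    return initial_lr * decay_rate ** epoch


# Assumed values: start at 0.01 and shrink the rate by 10% per epoch
schedule = [exponential_decay(0.01, 0.9, epoch) for epoch in range(5)]
print(schedule)
```

With an initial learning rate of 0.01, this yields roughly 0.01, 0.009, 0.0081, 0.00729 and 0.006561 for the first five epochs. Deep learning frameworks ship ready-made schedulers for this, so in practice you would rarely write it yourself.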
We want our model neither to overfit nor to underfit. If the model is not complex enough for the learning problem, it underfits, and we need to choose a more complex model.
If the model is too complex, it overfits the training data, which leads to a poor ability to generalize to unseen data. In this case we can apply regularization techniques that force the learning algorithm to build a less complex model.
The dropout regularizer temporarily excludes randomly chosen units of the network during training. This prevents the model from simply memorizing the training data, which would lead to overfitting.
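The mechanics of dropout fit into a short function. This sketch implements the common "inverted dropout" variant. During training it scales the surviving activations by 1 / (1 - drop_prob), so nothing needs to change at inference time. The function and its parameters are hypothetical and only meant for illustration. Framework layers such as Keras `Dropout` or PyTorch `nn.Dropout` handle this for you.

```python
import random


def dropout(activations, drop_prob, training=True, rng=None):
    """Inverted dropout: zero each unit with probability drop_prob during training.

    Surviving units are scaled by 1 / (1 - drop_prob) so the expected
    activation stays the same; at inference time inputs pass through unchanged.
    """
    if not 0.0 <= drop_prob < 1.0:
        raise ValueError("drop_prob must be in [0, 1)")
    if not training or drop_prob == 0.0:
        return list(activations)
    rng = rng or random.Random()
    keep_prob = 1.0 - drop_prob
    return [a / keep_prob if rng.random() < keep_prob else 0.0 for a in activations]


out = dropout([1.0, 2.0, 3.0, 4.0], drop_prob=0.5, rng=random.Random(42))
print(out)  # each value is either 0.0 or doubled
```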
There are other regularizers specific to neural networks, such as early stopping and batch normalization.
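Early stopping is easy to sketch as well. The function below takes a list of validation losses, one per epoch, and uses an assumed `patience` of 3 epochs. It reports when training would stop and which epoch had the lowest loss, whose weights one would typically keep. The function name and the loss values are made up for illustration.

```python
def early_stopping(val_losses, patience=3, min_delta=0.0):
    """Return (stop_epoch, best_epoch) for a list of per-epoch validation losses.

    Training stops once the loss has not improved by more than min_delta
    for `patience` consecutive epochs; otherwise it runs to the last epoch.
    """
    best_loss = float("inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:
            best_loss, best_epoch = loss, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_losses) - 1, best_epoch


# Hypothetical validation losses: they improve at first, then rise (overfitting)
losses = [0.90, 0.70, 0.55, 0.50, 0.52, 0.56, 0.61, 0.67]
print(early_stopping(losses))  # (6, 3)
```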
After setting the hyperparameters as described above, we start training the model. After just a few epochs, the model reaches a training accuracy of 86%. This is already pretty good and signals that the model is complex enough and does not underfit.
Result Evaluation on Unseen Data
Now it's time to test how well the model generalizes. We do this by comparing the training accuracy to the accuracy on unseen data, namely the validation set.
Both accuracies should be close to each other; then the model is good at transferring what it has learned from the training data to unknown data. If the validation accuracy is far below the training accuracy, the model overfits and cannot generalize well. Since our validation accuracy is only 60%, compared to 86% on the training data, our model overfits, and we have to tune the hyperparameters accordingly.
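The comparison can be expressed as a small helper. The `max_gap` threshold of 5 percentage points is purely an assumption for this sketch, since what counts as "close" depends on the project. The example labels are made up.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that match the true labels."""
    if not labels or len(predictions) != len(labels):
        raise ValueError("need two non-empty sequences of equal length")
    return sum(p == y for p, y in zip(predictions, labels)) / len(labels)


def diagnose(train_acc, val_acc, max_gap=0.05):
    """Rough overfitting check; max_gap is an assumed threshold, not a standard."""
    gap = train_acc - val_acc
    return "overfitting" if gap > max_gap else "generalizes reasonably"


preds = ["aesthetic", "unaesthetic", "aesthetic", "aesthetic"]
truth = ["aesthetic", "unaesthetic", "unaesthetic", "aesthetic"]
print(accuracy(preds, truth))  # 0.75
print(diagnose(0.86, 0.60))    # overfitting
```

With our 86% training accuracy and 60% validation accuracy, the gap of 26 percentage points is far above that threshold.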
In our case we introduce dropout as a regularizer to avoid overfitting. Then we train again and evaluate the results.
In the next article we'll face the reality check: is our model already good enough for production use? We'll try to understand the behavior of our model, which aspects it already solves very well and which it doesn't, and, most importantly, why. Stay tuned!
If you need help with your own AI project, just get in touch.