DLP Ch3 examples
Ch 3 in Deep Learning with Python (DLP) includes 3 examples:
- Classify IMDb movie reviews as positive or negative
- Classify Reuters newswires into 46 different topics
- Predict median home prices in suburbs of Boston
My Python code and output can be found here.
The 1st example is a binary classification problem. For this type of problem, one should end the network with a Dense layer of 1 unit and a sigmoid activation. The usual loss function is binary_crossentropy. The plots make it easy to see that our model quickly overfits the training data: after a few epochs, validation loss and accuracy stop improving.
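Here is a minimal sketch of such a network in Keras, assuming the reviews have been multi-hot encoded into 10,000-dimensional vectors as in the book; the 16-unit layer sizes are illustrative:

```python
from tensorflow import keras
from tensorflow.keras import layers

# Binary classifier: small Dense stack ending in a single sigmoid
# unit, which outputs a probability between 0 and 1.
model = keras.Sequential([
    layers.Dense(16, activation="relu"),
    layers.Dense(16, activation="relu"),
    layers.Dense(1, activation="sigmoid"),  # 1 unit + sigmoid for binary output
])
model.compile(optimizer="rmsprop",
              loss="binary_crossentropy",   # standard loss for binary classification
              metrics=["accuracy"])
```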
The 2nd example is a single-label, multiclass classification problem. For this type of problem, one should end the network with a softmax layer whose number of units equals the number of classes. Also, avoid intermediate layers with fewer units than the number of classes, since they create an information bottleneck. The usual loss function is categorical_crossentropy for one-hot label encoding and sparse_categorical_crossentropy for integer label encoding. Again, the plots show our model quickly overfits the training data.
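A minimal sketch for the Reuters case, again assuming 10,000-dimensional multi-hot inputs; the 64-unit hidden layers follow the book's example and stay wider than the 46 output classes:

```python
from tensorflow import keras
from tensorflow.keras import layers

num_classes = 46  # Reuters topics

model = keras.Sequential([
    layers.Dense(64, activation="relu"),  # intermediate layers kept wider than 46
    layers.Dense(64, activation="relu"),
    layers.Dense(num_classes, activation="softmax"),  # one probability per class
])
model.compile(optimizer="rmsprop",
              loss="categorical_crossentropy",  # use sparse_categorical_crossentropy for integer labels
              metrics=["accuracy"])
```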
While we use validation data to spot overfitting, we also need a way to tell how useful our model is at solving the problem of interest. One way is to create a simple, non-DL baseline and compare its results with our DL results. For the Reuters example, we simply shuffle the labels and randomly assign them to newswires. This non-DL baseline reaches about 18% accuracy, much worse than the DL results, which justifies our DL approach.
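The baseline takes only a few lines to compute; this sketch assumes the labels come from keras.datasets.reuters:

```python
import numpy as np
from tensorflow.keras.datasets import reuters

# Shuffled-label baseline: permute the true topic ids at random and
# measure how often the random assignment happens to be correct.
(_, _), (_, test_labels) = reuters.load_data(num_words=10000)
shuffled = np.random.permutation(test_labels)
print((test_labels == shuffled).mean())  # roughly 0.18 for 46 imbalanced classes
```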
The last example is a regression problem. Unlike the previous two, home price can be any positive number, so our network ends with a Dense layer of 1 unit and no activation function. The popular loss function for regression is mean squared error (MSE). There are 3 things worth pointing out. 1st, normalize input data before feeding it to any type of DL network. 2nd, use a shallower network to avoid overfitting on a small dataset. 3rd, use K-fold cross validation when the dataset is small. The sketch below illustrates all 3.
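This sketch follows the book's Boston Housing setup; the 64-unit layers, 100 epochs, and k = 4 are illustrative choices:

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

(train_data, train_targets), _ = keras.datasets.boston_housing.load_data()

# 1st: normalize features using training-set statistics only.
mean, std = train_data.mean(axis=0), train_data.std(axis=0)
train_data = (train_data - mean) / std

def build_model():
    # 2nd: a shallow network, since there are only ~400 training samples.
    model = keras.Sequential([
        layers.Dense(64, activation="relu"),
        layers.Dense(64, activation="relu"),
        layers.Dense(1),  # 1 unit, no activation: unconstrained scalar output
    ])
    model.compile(optimizer="rmsprop", loss="mse", metrics=["mae"])
    return model

# 3rd: K-fold cross validation, so every sample serves as validation once.
k = 4
fold = len(train_data) // k
scores = []
for i in range(k):
    val_x = train_data[i * fold:(i + 1) * fold]
    val_y = train_targets[i * fold:(i + 1) * fold]
    part_x = np.concatenate([train_data[:i * fold], train_data[(i + 1) * fold:]])
    part_y = np.concatenate([train_targets[:i * fold], train_targets[(i + 1) * fold:]])
    model = build_model()
    model.fit(part_x, part_y, epochs=100, batch_size=16, verbose=0)
    scores.append(model.evaluate(val_x, val_y, verbose=0)[1])  # validation MAE
print("mean validation MAE:", np.mean(scores))
```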