Set baseline expectations for categorical cross-entropy #472

Open: wants to merge 1 commit into main
10 changes: 9 additions & 1 deletion episodes/2-keras.Rmd
@@ -76,10 +76,10 @@
The `palmerpenguins` data contains size measurements for three penguin species observed on three islands in the Palmer Archipelago, Antarctica.
The physical attributes measured are flipper length, beak length, beak width, body mass, and sex.

![*Artwork by @allison_horst*][palmer-penguins]

![*Artwork by @allison_horst*][penguin-beaks]

These data were collected from 2007 to 2009 by Dr. Kristen Gorman with the [Palmer Station Long Term Ecological Research Program](https://pal.lternet.edu/), part of the [US Long Term Ecological Research Network](https://lternet.edu/). The data were imported directly from the [Environmental Data Initiative](https://environmentaldatainitiative.org/) (EDI) Data Portal and are available for use under the CC0 license ("No Rights Reserved") in accordance with the [Palmer Station Data Policy](https://pal.lternet.edu/data/policies).
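
For reference, seaborn bundles a copy of this dataset, so a minimal way to get it into a `penguins` DataFrame is the sketch below (an assumption: the lesson itself may load the data differently):

```python
import seaborn as sns

# seaborn ships a copy of the palmerpenguins data under the name 'penguins'
penguins = sns.load_dataset('penguins')
print(penguins.head())
```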
@@ -140,7 +140,7 @@
sns.pairplot(penguins, hue="species")
```

![][pairplot]

::: challenge

@@ -165,7 +165,7 @@
sns.pairplot(penguins, hue='sex')
```

![][sex_pairplot]

You can see that for each species, females have smaller bills and flippers, as well as a smaller body mass.
You would need a combination of the species and the numerical features to successfully distinguish males from females.
@@ -479,7 +479,7 @@
In Keras this is implemented in the `keras.losses.CategoricalCrossentropy` class.
This loss function works well in combination with the `softmax` activation function
we chose earlier.
The *categorical cross-entropy* works by comparing the probabilities that the
neural network predicts with the 'true' probabilities that we generated using the one-hot encoding.
It measures how closely the distribution of the three neural network outputs matches the distribution of the three values in the one-hot encoding:
the more similar the two distributions are, the lower the loss.
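
As a quick worked example with made-up numbers (not values from the lesson), the cross-entropy of a single sample reduces to the negative log of the probability the network assigns to the true class:

```python
import numpy as np

# hypothetical sample: the true species is the second of the three classes
y_true = np.array([0.0, 1.0, 0.0])      # one-hot encoded 'true' probabilities
y_pred = np.array([0.15, 0.70, 0.15])   # probabilities from the softmax output

loss = -np.sum(y_true * np.log(y_pred))
print(loss)  # ~0.36; a perfect prediction (probability 1 for the true class) gives 0
```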
@@ -519,12 +519,20 @@

The `fit` method returns a `History` object whose `history` attribute contains the training loss and
potentially other metrics for each training epoch.
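
For instance, you can check which metrics were recorded (assuming the model was compiled with no metrics beyond the loss):

```python
print(history.history.keys())  # e.g. dict_keys(['loss'])
```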

### Setting baseline expectations
What would be a good value for the categorical cross-entropy loss here? In a classification context, we can establish a baseline by determining the loss we would get by simply guessing each class at random. If predictions are completely random, the cross-entropy loss is approximately log(n), where n is the number of classes. Any useful model should outperform this baseline, that is, achieve a lower loss.
```python
import numpy as np

np.log(3)  # ~1.0986: the baseline loss for random guessing over three classes
```
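
To double-check this baseline, here is a small sketch (with hypothetical toy arrays, not the lesson's data) showing that uniform predictions indeed produce a loss of log(3) with Keras's own loss class:

```python
import numpy as np
from tensorflow import keras

y_true = np.eye(3)              # one one-hot encoded 'true' label per class
y_pred = np.full((3, 3), 1/3)   # a model that assigns equal probability to every class

cce = keras.losses.CategoricalCrossentropy()
print(float(cce(y_true, y_pred)))  # ~1.0986, i.e. log(3)
```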

It can be very insightful to plot the training loss to see how the training progresses.
Using seaborn, we can do this as follows:
```python
# plot the training loss recorded at each epoch
sns.lineplot(x=history.epoch, y=history.history['loss'])
```
![][training_curve]

This plot can be used to identify whether the training is well configured or whether there
are problems that need to be addressed.
@@ -543,7 +551,7 @@

3. (optional) Something went wrong here during training. What could be the problem, and how do you see that in the training curve?
Also compare the range on the y-axis with the previous training curve.
![](fig/02_bad_training_history_1.png){alt='Very jittery training curve with the loss value jumping back and forth between 2 and 4. The range of the y-axis is from 2 to 4, whereas in the previous training curve it was from 0 to 2. The loss seems to decrease a little bit, but not as much as compared to the previous plot where it dropped to almost 0. The minimum loss in the end is somewhere around 2.'}

:::: solution
## Solution
@@ -683,7 +691,7 @@
```python
# annot=True writes the count in each cell of the heatmap
sns.heatmap(confusion_df, annot=True)
```
![][confusion_matrix]
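
For context, `confusion_df` is a table of counts of true versus predicted classes. A minimal sketch of how such a table can be built with scikit-learn (using hypothetical labels, not the lesson's actual variables):

```python
import pandas as pd
from sklearn.metrics import confusion_matrix

# hypothetical stand-ins for the true and predicted species labels
y_true = ['Adelie', 'Adelie', 'Gentoo', 'Chinstrap', 'Gentoo']
y_pred = ['Adelie', 'Gentoo', 'Gentoo', 'Chinstrap', 'Gentoo']

species = ['Adelie', 'Chinstrap', 'Gentoo']
matrix = confusion_matrix(y_true, y_pred, labels=species)
confusion_df = pd.DataFrame(matrix, index=species, columns=species)
confusion_df.index.name = 'True label'
confusion_df.columns.name = 'Predicted label'
print(confusion_df)
```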

::: challenge
## Confusion Matrix