Sleep Stage Classification
My work at the startup paq wearables has been focused on developing the machine learning model behind our second product: a wearable device connected to an app that predicts sleep stages — and ultimately detects the likelihood of sleep disorders such as obstructive sleep apnea (OSA). The device forgoes taking obtrusive measurements of electrical activity in the brain, and instead measures other biometric signals such as heart rate, respiratory rate, SpO2 levels, and motion.
This is a sequence labeling problem: each subject sleeps for a different length of time, and we want to classify the sleep stage of every 30-second epoch of that night. My research led me to use a bidirectional LSTM-RNN, which reads each night both forward and backward so that every prediction can use context from before and after that point in time.
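A minimal Keras sketch of this kind of architecture follows. It assumes four input features per timestep (heart rate, respiratory rate, SpO2, activity) and three output stages (Wake, NREM, REM); the layer size of 64 units, the padding value of -1.0, and the helper name `build_model` are hypothetical choices for illustration, not the exact configuration of the production model.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

N_FEATURES = 4    # assumed: heart rate, respiratory rate, SpO2, activity
N_CLASSES = 3     # Wake, NREM, REM
PAD_VALUE = -1.0  # hypothetical: outside the [0, 1] range of normalized features

def build_model(n_features: int = N_FEATURES, n_classes: int = N_CLASSES,
                units: int = 64) -> keras.Model:
    """Bidirectional LSTM that emits one stage distribution per timestep."""
    inputs = keras.Input(shape=(None, n_features))  # nights of any length
    # Masking makes downstream layers skip timesteps whose features all equal PAD_VALUE.
    x = layers.Masking(mask_value=PAD_VALUE)(inputs)
    # return_sequences=True keeps one output per timestep (sequence labeling).
    x = layers.Bidirectional(layers.LSTM(units, return_sequences=True))(x)
    # The same softmax classifier is applied independently at every timestep.
    outputs = layers.TimeDistributed(layers.Dense(n_classes, activation="softmax"))(x)
    model = keras.Model(inputs, outputs)
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])
    return model

model = build_model()
batch = np.random.rand(2, 120, N_FEATURES).astype("float32")  # 2 nights, 120 timesteps each
probs = model.predict(batch, verbose=0)  # shape (2, 120, 3), one distribution per timestep
```

The bidirectional wrapper runs two LSTMs (one forward in time, one backward) and concatenates their outputs, which is why the classifier at each timestep can see the whole night.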
Data Preprocessing
I trained the LSTM-RNN using the MESA dataset, which contains sleep data from 2000 subjects who underwent both polysomnography and actigraphy, allowing us to extract heart rate, SpO2, and activity (motion count) features for one night of sleep. Missing values were imputed with the median value.
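The post does not say whether the median is taken per subject or across the whole dataset. The NumPy sketch below assumes a per-feature median computed within one subject's night; `impute_median` is a hypothetical helper name.

```python
import numpy as np

def impute_median(features: np.ndarray) -> np.ndarray:
    """Replace NaNs in each feature column with that column's median.

    features: array of shape (timesteps, n_features) for one subject.
    A column that is entirely NaN stays NaN (np.nanmedian warns in that case).
    """
    filled = features.astype(float).copy()
    medians = np.nanmedian(filled, axis=0)   # per-feature median, ignoring NaNs
    rows, cols = np.where(np.isnan(filled))  # positions of the missing values
    filled[rows, cols] = medians[cols]
    return filled

# Columns: heart rate (bpm), SpO2 (%)
night = np.array([[60.0, 97.0],
                  [np.nan, 96.0],
                  [64.0, np.nan],
                  [62.0, 95.0]])
result = impute_median(night)  # NaNs become 62.0 (heart rate) and 96.0 (SpO2)
```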
Respiratory rates needed to be reverse-engineered. Because respiratory rate correlates with heart rate and, more importantly, with sleep stage (e.g. REM corresponds to the lowest respiratory rate with the highest variability), I generated plausible respiratory rates for each subject, with enough noise to prevent overfitting.
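The exact generator used for this is not described, so the following is only a sketch of one way to do it: a stage-dependent mean and spread, a small linear coupling to the subject's heart rate, and Gaussian noise. Every number here (the per-stage means and standard deviations, `HR_COUPLING`) is an illustrative placeholder chosen to mirror the description above, not a clinical estimate, and `synthesize_resp_rate` is a hypothetical name.

```python
import numpy as np

WAKE, NREM, REM = 0, 1, 2

# Placeholder parameters in breaths/min, indexed by stage code (Wake, NREM, REM).
# They only mirror the description above (REM: lowest mean, largest spread).
STAGE_MEAN = np.array([16.0, 14.0, 13.0])
STAGE_STD = np.array([1.0, 0.7, 2.0])
HR_COUPLING = 0.1  # hypothetical: breaths/min added per bpm above the subject's mean HR

def synthesize_resp_rate(stages, heart_rate, rng):
    """Generate a plausible respiratory-rate series from stages and heart rate."""
    stages = np.asarray(stages)
    heart_rate = np.asarray(heart_rate, dtype=float)
    mean = STAGE_MEAN[stages]                               # stage-dependent baseline
    std = STAGE_STD[stages]                                 # stage-dependent variability
    hr_term = HR_COUPLING * (heart_rate - heart_rate.mean())  # correlation with heart rate
    noise = rng.normal(0.0, 1.0, size=stages.shape) * std
    return mean + hr_term + noise
```

Scaling the noise by a per-stage standard deviation is what lets REM be noisier than NREM while the heart-rate term keeps the two signals correlated.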
Update (5/10/20): even though sleep stages scored by a sleep clinician in 30-second epochs are considered the “gold standard” in polysomnography, an LSTM-RNN is best trained on as much good data as possible. All of the other features are sampled at a frequency of at least 1 Hz (some as high as 50-100 Hz), so sampling them only every 30 seconds to match the sleep stage labels throws away a lot of valuable data and patterns that occur seconds apart.
Therefore I am upsampling the sleep stage labels to 1 Hz, using the midpoints between consecutive epochs as boundaries: someone in NREM at time = 30 seconds is also considered to be in NREM at times 15.0, 16.0, … , 44.0 seconds. This is a reasonable assumption, and even where it is not entirely correct (perhaps the patient actually enters REM at 40.0 seconds), training the LSTM on a data set 30 times larger at 1 Hz (and potentially 100x larger at higher sampling rates) is an enormous benefit for the robustness of the model.
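A minimal NumPy sketch of this midpoint upsampling, assuming epoch k is labeled at t = 30·k seconds (so the label at t = 30 covers seconds 15 through 44, as in the example above) and that stages are integer codes; `upsample_labels` is a hypothetical helper name.

```python
import numpy as np

EPOCH_SECONDS = 30

def upsample_labels(epoch_labels, n_seconds):
    """Assign each 1 Hz timestep the label of the nearest 30-second epoch.

    Epoch k is assumed to be labeled at t = 30*k and to cover seconds
    30*k - 15 through 30*k + 14, i.e. boundaries at the midpoints.
    """
    epoch_labels = np.asarray(epoch_labels)
    seconds = np.arange(n_seconds)
    idx = (seconds + EPOCH_SECONDS // 2) // EPOCH_SECONDS  # nearest epoch index
    idx = np.minimum(idx, len(epoch_labels) - 1)           # clamp the end of the recording
    return epoch_labels[idx]

# Wake at t=0, NREM at t=30, REM at t=60 (codes 0, 1, 2)
labels_1hz = upsample_labels([0, 1, 2], n_seconds=90)
```

Seconds 0 to 14 get Wake, 15 to 44 get NREM, and 45 onward get REM; the first and last epochs are only half covered because there is no neighbouring epoch on one side.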
The sleep data for subject 0052 (shown above) is not normalized, but all features were normalized to the range 0 to 1 before being fed into the LSTM-RNN. Also shown are the Gaussian-smoothed heart rate and SpO2 features. Due to the size constraints of Chart Studio, observations are only shown every 20 seconds, but the data used to train the LSTM is upsampled to one observation per second.
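The two transformations mentioned here can be sketched in plain NumPy. This assumes min-max scaling per feature column and a Gaussian kernel truncated at four standard deviations with reflected edges; the smoothing width `sigma` is a hypothetical parameter, and `scipy.ndimage.gaussian_filter1d` is a ready-made alternative to the hand-written `gaussian_smooth`. In practice the min and max would normally be computed on the training set and reused for validation and test data.

```python
import numpy as np

def gaussian_smooth(signal, sigma):
    """Convolve a 1-D signal with a normalized Gaussian kernel (edges reflected)."""
    radius = int(4 * sigma)                      # truncate the kernel at 4 sigma
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (x / sigma) ** 2)
    kernel /= kernel.sum()                       # weights sum to 1, preserving the mean level
    padded = np.pad(signal, radius, mode="reflect")
    return np.convolve(padded, kernel, mode="valid")  # same length as the input

def min_max_scale(features):
    """Scale each column of a (timesteps, n_features) array to [0, 1]."""
    lo = features.min(axis=0)
    hi = features.max(axis=0)
    span = np.where(hi > lo, hi - lo, 1.0)       # avoid dividing by zero for constant columns
    return (features - lo) / span
```

Usage, with a synthetic heart-rate trace standing in for real data:

```python
rng = np.random.default_rng(1)
t = np.arange(600)  # 10 minutes at 1 Hz
hr = 60 + 5 * np.sin(2 * np.pi * t / 300) + rng.normal(0, 2, size=t.size)
smooth_hr = gaussian_smooth(hr, sigma=5.0)
scaled = min_max_scale(np.column_stack([hr, smooth_hr]))
```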
Building the Machine Learning Model
I built the bidirectional LSTM-RNN in Keras. There were several considerations to make:
padding the features and output labels to account for the unequal number of timesteps per subject
one hot encoding output labels
using a softmax output layer, since this is a classification problem and the model outputs a probability for each label, e.g. [Wake, NREM, REM] → [0.50, 0.20, 0.30]
calculating class weights, since subjects spend the majority of their time in Wake or NREM (REM is by far the minority class), which makes this a class imbalance problem as well. The weights scale the loss so that misclassifying a minority class, especially REM, costs more during training
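The padding, one-hot encoding, and class weighting steps above can be sketched in NumPy as follows. As far as I know, Keras' `model.fit(class_weight=...)` is rejected for 3-D (per-timestep) targets in TensorFlow 2, so this sketch assumes the common workaround of expanding class weights into a per-timestep `sample_weight` array. The padding value of -1.0, the "balanced" weighting formula (total / (n_classes × count), the same heuristic scikit-learn calls `"balanced"`), and the helper names are assumptions for illustration.

```python
import numpy as np

N_CLASSES = 3     # Wake, NREM, REM
PAD_VALUE = -1.0  # hypothetical: outside the [0, 1] range of normalized features

def pad_batch(feature_seqs, label_seqs, n_classes=N_CLASSES):
    """Pad variable-length nights and one-hot encode their labels."""
    max_len = max(len(s) for s in feature_seqs)
    n, n_features = len(feature_seqs), feature_seqs[0].shape[1]
    X = np.full((n, max_len, n_features), PAD_VALUE, dtype="float32")
    Y = np.zeros((n, max_len, n_classes), dtype="float32")
    mask = np.zeros((n, max_len), dtype=bool)   # True where a real timestep exists
    for i, (feats, labels) in enumerate(zip(feature_seqs, label_seqs)):
        t = len(labels)
        X[i, :t] = feats
        Y[i, np.arange(t), labels] = 1.0        # one-hot encoding of the stages
        mask[i, :t] = True
    return X, Y, mask

def balanced_class_weights(label_seqs, n_classes=N_CLASSES):
    """weight_c = total / (n_classes * count_c): rarer classes get larger weights."""
    counts = np.bincount(np.concatenate(label_seqs), minlength=n_classes)
    return counts.sum() / (n_classes * np.maximum(counts, 1))

def temporal_sample_weights(Y, mask, class_weights):
    """One weight per timestep (the weight of its true class); padding gets 0."""
    weights = class_weights[Y.argmax(axis=-1)]
    return np.where(mask, weights, 0.0).astype("float32")

# Two toy nights with 4 features each, labeled 0=Wake, 1=NREM, 2=REM
rng = np.random.default_rng(0)
seqs = [rng.random((5, 4)), rng.random((3, 4))]
labs = [np.array([0, 1, 1, 2, 1]), np.array([0, 0, 1])]
X, Y, mask = pad_batch(seqs, labs)
cw = balanced_class_weights(labs)
sw = temporal_sample_weights(Y, mask, cw)
```

The resulting arrays could then be passed as `model.fit(X, Y, sample_weight=sw)`, where a 2-D `sample_weight` should be interpreted as one weight per timestep. Zeroing the padded steps is redundant if a `Masking` layer already drops them from the loss, but it makes the intent explicit.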
Visualization & Metrics
Below we see that the model is capable of predicting the REM stages, although it misses some of the brief wakes.
For any classification problem we should look at a confusion matrix, but with imbalanced classes overall accuracy can look good even when the minority class is predicted poorly, so we should also consider a classification report with per-class precision, recall, and F1 score.
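For reference, a minimal NumPy sketch of how these metrics follow from a confusion matrix (rows are actual stages, columns are predicted stages). With padded sequences, the flattened labels would first be restricted to real timesteps, for example `y_true = Y.argmax(-1)[mask]` and `y_pred = probs.argmax(-1)[mask]`; those array names are hypothetical. scikit-learn's `confusion_matrix` and `classification_report` compute the same quantities.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, n_classes=3):
    """Count (actual, predicted) pairs: rows = actual stage, columns = predicted."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

def per_class_metrics(cm):
    """Precision, recall, F1, and support for each class."""
    tp = np.diag(cm).astype(float)
    precision = tp / np.maximum(cm.sum(axis=0), 1)  # of timesteps predicted c, share truly c
    recall = tp / np.maximum(cm.sum(axis=1), 1)     # of timesteps truly c, share predicted c
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom,
                   out=np.zeros_like(denom), where=denom > 0)  # harmonic mean
    support = cm.sum(axis=1)                         # number of true timesteps per class
    return precision, recall, f1, support

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
cm = confusion_matrix(y_true, y_pred)
precision, recall, f1, support = per_class_metrics(cm)
```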
Confusion Matrix:
Classification Report:
| | Precision | Recall | F1-score | Support |
|---|---|---|---|---|
| Wake | 0.94 | 0.91 | 0.92 | 18234 |
| NREM | 0.91 | 0.88 | 0.90 | 20764 |
| REM | 0.68 | 0.85 | 0.76 | 4864 |
| accuracy | | | 0.89 | 43862 |
| macro avg | 0.84 | 0.88 | 0.86 | 43862 |
| weighted avg | 0.90 | 0.89 | 0.89 | 43862 |
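The two averaged rows can be checked directly from the per-class numbers in the report: the macro average weights every class equally, while the weighted average weights each class by its support. The short script below uses only the values from the table; the lower macro F1 (0.86 vs. 0.89 weighted) reflects the weaker REM scores, since REM has the smallest support. Weighted recall always equals overall accuracy, because it reduces to the total number of correct timesteps divided by the total number of timesteps.

```python
import numpy as np

# Per-class values copied from the classification report (Wake, NREM, REM).
precision = np.array([0.94, 0.91, 0.68])
recall = np.array([0.91, 0.88, 0.85])
f1 = np.array([0.92, 0.90, 0.76])
support = np.array([18234, 20764, 4864])

def macro(x):
    return x.mean()                              # every class counts equally

def weighted(x):
    return (x * support).sum() / support.sum()   # classes weighted by support

for name, x in [("precision", precision), ("recall", recall), ("f1", f1)]:
    print(f"{name}: macro={macro(x):.2f} weighted={weighted(x):.2f}")
# precision: macro=0.84 weighted=0.90
# recall: macro=0.88 weighted=0.89
# f1: macro=0.86 weighted=0.89
```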