Sleep Stage Classification

My work at the startup paq wearables has focused on developing the machine learning model behind our second product: a wearable device connected to an app that predicts sleep stages and, ultimately, estimates the likelihood of sleep disorders such as obstructive sleep apnea (OSA). Instead of taking obtrusive measurements of the brain's electrical activity (EEG), the device measures other biometric signals such as heart rate, respiratory rate, SpO2 levels, and motion.

This is a sequence classification problem: each subject sleeps for a different length of time, and we want to classify the sleep stage of every 30-second epoch of their night. My research led me to use a bidirectional LSTM-RNN, which processes each night's sequence both forward and backward so that every epoch's prediction can draw on context from before and after it.

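Each LSTM cell uses gates, including a “forget gate”, to selectively keep or discard information in its cell state. This selective memory mitigates the vanishing gradient problem that makes plain RNNs struggle with long sequences such as a full night of sleep.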
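To make the gating concrete, here is a minimal sketch of a single LSTM time step with an input size and hidden size of 1, in plain Python. The weights are hypothetical values chosen for illustration, not trained parameters. The key line is the additive cell-state update c = f * c_prev + i * g: when the forget gate f is near 1, the stored memory (and the gradient flowing back along it) passes through almost unchanged.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x, h_prev, c_prev, w):
    """One time step of a scalar LSTM cell (input size 1, hidden size 1).

    `w` maps each gate name to an (input weight, recurrent weight, bias) triple.
    """
    def pre(gate):
        wx, wh, b = w[gate]
        return wx * x + wh * h_prev + b

    f = sigmoid(pre("forget"))  # fraction of the old cell state to keep
    i = sigmoid(pre("input"))   # fraction of the candidate to write
    o = sigmoid(pre("output"))  # fraction of the cell state to expose
    g = math.tanh(pre("cell"))  # candidate values
    c = f * c_prev + i * g      # additive update: the direct path dc/dc_prev is just f
    h = o * math.tanh(c)
    return h, c, f

# Hypothetical weights: forget gate biased open (sigmoid(4) is about 0.98) and
# input gate biased nearly shut, so the cell mostly keeps what it already stored.
weights = {
    "forget": (0.0, 0.0, 4.0),
    "input": (0.0, 0.0, -4.0),
    "output": (0.0, 0.0, 0.0),
    "cell": (1.0, 0.0, 0.0),
}

h, c = 0.0, 1.0  # start with a stored value of 1.0 in the cell state
for x in [0.5, -0.3, 0.8]:
    h, c, f = lstm_step(x, h, c, weights)
print(f"cell state after 3 steps: {c:.3f}, forget gate: {f:.3f}")
```

With the forget gate near 1, the stored value decays only slightly per step. In a trained network the gate values are learned, so the cell decides for itself what to remember across the night.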

Data Preprocessing

I trained the LSTM-RNN on the MESA dataset, which contains sleep data from 2000 subjects who underwent both polysomnography and actigraphy, allowing us to extract heart rate, SpO2, and activity (motion count) features for one night of sleep. Missing values were imputed with the median value.
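Median imputation can be sketched in a few lines. This assumes the median is taken over one feature of one subject's night and that missing readings appear as None or NaN; the original pipeline may have used dataset-wide medians instead. In pandas the equivalent would be `df[col].fillna(df[col].median())`.

```python
import math
from statistics import median

def impute_median(values):
    """Replace missing readings (None or NaN) with the median of the observed ones."""
    observed = [v for v in values if v is not None and not math.isnan(v)]
    fill = median(observed)
    # `v is None` is checked first so math.isnan() never receives None
    return [fill if v is None or math.isnan(v) else v for v in values]

spo2 = [95.0, None, 96.0, float("nan"), 94.0]  # hypothetical SpO2 readings (%)
print(impute_median(spo2))  # [95.0, 95.0, 96.0, 95.0, 94.0]
```

The median is preferred over the mean here because a few sensor dropouts or artifact spikes would drag a mean but barely move the median.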

Respiratory rates needed to be reverse-engineered. Since respiratory rate correlates with heart rate and, more importantly, with sleep stage (e.g. REM corresponds to the lowest respiratory rate with the highest variability), I generated plausible respiratory rates for each subject, adding enough noise to prevent overfitting.
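A generator of this kind might look like the sketch below: each timestep draws a rate from a stage-specific distribution, shifted by the standardized heart rate. Everything numeric here is a hypothetical placeholder chosen only to follow the description above (REM lowest and most variable), and `coupling` is an assumed heart-rate coefficient; none of these are the values used in the original pipeline.

```python
import random

# Hypothetical per-stage (mean, std) in breaths per minute; placeholders only.
STAGE_PARAMS = {"Wake": (16.0, 1.5), "NREM": (14.5, 1.0), "REM": (13.5, 2.5)}

def synth_resp_rate(stages, hr_z, coupling=0.5, seed=0):
    """Generate a plausible respiratory-rate series from stage labels and heart rate.

    stages:   per-timestep stage labels ("Wake", "NREM", "REM")
    hr_z:     per-timestep heart rate as z-scores, same length as stages
    coupling: hypothetical breaths/min added per standard deviation of heart rate
    """
    if len(stages) != len(hr_z):
        raise ValueError("stages and hr_z must have the same length")
    rng = random.Random(seed)  # seeded so the synthetic series is reproducible
    out = []
    for stage, z in zip(stages, hr_z):
        mean, std = STAGE_PARAMS[stage]
        # stage-dependent baseline + heart-rate correlation + stage-dependent noise
        out.append(mean + coupling * z + rng.gauss(0.0, std))
    return out

rr = synth_resp_rate(["Wake", "NREM", "NREM", "REM"], [0.8, -0.2, -0.4, 0.1])
print([round(r, 1) for r in rr])
```

The noise term is what keeps the network from learning a trivial mapping from synthetic respiratory rate back to the label it was generated from.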

Update (5/10/20): even though 30-second epochs labeled by a sleep clinician are considered the “gold standard” in polysomnography, an LSTM-RNN is best trained on as much good data as possible. All of the other features are sampled at 1 Hz or faster (some as high as 50-100 Hz), so sampling them only every 30 seconds to match the sleep stage labels throws out a lot of valuable data, along with patterns that occur seconds apart.

Therefore I am upsampling the sleep stage labels to 1 Hz using the midpoints between epochs: someone in NREM at time = 30 seconds is also considered to be in NREM at time = 15.0, 16.0, … , 44.0 seconds. This is a reasonable assumption, and even where it is not entirely correct (perhaps the patient actually enters REM at 40.0 seconds), the benefit to the model's robustness of training on a data set 30 times larger (one labeled observation per second instead of one per 30-second epoch) is enormous.
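The midpoint rule above can be sketched as follows: second t takes the label of the nearest scored epoch time, which is epoch (t + 15) // 30. One assumption is mine, not the original's: seconds after the last midpoint keep the last epoch's label, so a night of n epochs yields n * 30 labels. Features sampled faster than 1 Hz would separately need to be aggregated down to 1 Hz (not shown).

```python
EPOCH_S = 30  # sleep stages are scored in 30-second epochs

def upsample_labels(epoch_labels, epoch_s=EPOCH_S):
    """Expand per-epoch labels to one label per second using midpoints.

    Epoch k is scored at time k * epoch_s; second t takes the label of the
    nearest scored time, so the label at 30 s covers 15 s through 44 s.
    Assumption: seconds past the last midpoint keep the last epoch's label.
    """
    n = len(epoch_labels)
    half = epoch_s // 2
    return [epoch_labels[min((t + half) // epoch_s, n - 1)] for t in range(n * epoch_s)]

labels_1hz = upsample_labels(["Wake", "NREM", "REM"])
print(len(labels_1hz))                 # 90
print(labels_1hz[14], labels_1hz[15])  # Wake NREM
print(labels_1hz[44], labels_1hz[45])  # NREM REM
```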

The sleep data for subject 0052 (shown above) is not normalized, but all features were normalized to the range 0 to 1 before being fed into the LSTM-RNN. Also shown are the Gaussian-smoothed heart rate and SpO2 features. Due to Chart Studio's size constraints, the chart shows only one observation every 20 seconds, whereas the data used to train the LSTM is upsampled to one observation every second.
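Both preprocessing steps are short enough to sketch directly: min-max scaling to [0, 1] and Gaussian smoothing by convolving with a normalized Gaussian kernel. The `sigma` value and the kernel truncation at 3 sigma are hypothetical choices, and the heart-rate values are made up. SciPy's `scipy.ndimage.gaussian_filter1d` does the smoothing in one call, though its default edge handling differs from the renormalization used here.

```python
import math

def min_max_scale(values):
    """Rescale a feature to [0, 1]; a constant feature maps to all zeros."""
    lo, hi = min(values), max(values)
    span = hi - lo
    return [0.0 if span == 0 else (v - lo) / span for v in values]

def gaussian_smooth(values, sigma=2.0):
    """Convolve with a Gaussian kernel truncated at 3 sigma.

    Near the edges the kernel is renormalized over the samples that exist,
    so the output has the same length as the input.
    """
    radius = max(1, int(3 * sigma))
    kernel = [math.exp(-0.5 * (k / sigma) ** 2) for k in range(-radius, radius + 1)]
    out = []
    for i in range(len(values)):
        total = weight = 0.0
        for k, w in zip(range(-radius, radius + 1), kernel):
            j = i + k
            if 0 <= j < len(values):
                total += w * values[j]
                weight += w
        out.append(total / weight)
    return out

hr = [62, 64, 90, 63, 61, 60, 62]  # hypothetical heart-rate readings (bpm) with one spike
smoothed = gaussian_smooth(hr, sigma=1.0)
scaled = min_max_scale(smoothed)
print([round(v, 1) for v in smoothed])
print([round(v, 2) for v in scaled])
```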

Building the Machine Learning Model

I built the bidirectional LSTM-RNN in Keras. There were several considerations to make:

  • padding the features and output labels to account for unequal timesteps

  • one hot encoding output labels

  • using a softmax output layer, since this is a classification problem and the model outputs a probability for each label, e.g. [Wake, NREM, REM] → [0.50, 0.20, 0.30]

  • calculating class weights, since subjects spend the majority of their time in Wake or NREM (REM is by far the minority class), which makes this a class imbalance problem as well. The weights make errors on minority classes cost more in the loss, so the LSTM-RNN gives those classes higher priority as it learns
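The four considerations above can be sketched in plain Python to show what each step computes. Keras provides its own utilities for some of them (for example `pad_sequences`, `to_categorical`, and the `Masking` layer), and this sketch is not the model's actual code. The class order, the toy sequences, and the helper names are hypothetical. The class-weight formula is the "balanced" heuristic also used by scikit-learn's `compute_class_weight`.

```python
import math
from collections import Counter

LABELS = ["Wake", "NREM", "REM"]  # class order assumed for illustration

def pad_sequences(seqs, pad_value=0.0):
    """Right-pad variable-length sequences to the longest one.

    Returns the padded sequences plus a mask (1 = real timestep, 0 = padding)
    so padded steps can be ignored by the loss and the metrics.
    """
    max_len = max(len(s) for s in seqs)
    padded = [list(s) + [pad_value] * (max_len - len(s)) for s in seqs]
    mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in seqs]
    return padded, mask

def one_hot(label, labels=LABELS):
    """Encode a stage name as a one-hot vector, e.g. 'REM' -> [0, 0, 1]."""
    vec = [0] * len(labels)
    vec[labels.index(label)] = 1
    return vec

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def balanced_class_weights(y, labels=LABELS):
    """weight_c = n_samples / (n_classes * count_c); assumes every class occurs in y."""
    counts = Counter(y)
    n, k = len(y), len(labels)
    return {lab: n / (k * counts[lab]) for lab in labels}

nights = [[0.2, 0.4, 0.6], [0.9]]  # hypothetical one-feature sequences of unequal length
padded, mask = pad_sequences(nights)
probs = softmax([1.2, -0.3, 0.5])
weights = balanced_class_weights(["Wake", "Wake", "NREM", "NREM", "NREM", "REM"])
print(padded, mask)
print([round(p, 2) for p in probs], one_hot("REM"))
print(weights)
```

In the toy label list, REM appears once in six timesteps and receives twice the weight of Wake, which is how the loss is nudged toward the minority class.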

Overall test accuracy of 89% is pretty good, but ‘Wake’ and ‘NREM’ make up the large majority of the output labels (38,998 of the 43,862 test labels in the classification report below, about 89%), so a model could post a high accuracy without ever correctly predicting ‘REM’. This leads us to the next section…
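The arithmetic behind this caveat uses only the support column of the classification report at the end of this section. A model that never outputs REM is wrong on every REM label, so its accuracy can at best equal the Wake + NREM share. Always predicting the most common class gives the simplest baseline.

```python
# Test-set supports from the classification report (number of labels per stage)
support = {"Wake": 18234, "NREM": 20764, "REM": 4864}
total = sum(support.values())

# Upper bound for a model that never predicts REM (reached only if every
# Wake and NREM label is classified correctly)
ceiling = (support["Wake"] + support["NREM"]) / total

# Baseline of always predicting the most common class (NREM)
majority_baseline = max(support.values()) / total

print(f"accuracy ceiling without REM: {ceiling:.3f}")   # 0.889
print(f"majority-class baseline: {majority_baseline:.3f}")  # 0.473
```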

Visualization & Metrics

Below we see that the model is capable of predicting the REM stages, although it misses some of the brief wakes.

For any classification problem we should look at a confusion matrix, but when the classes are imbalanced we should also consider a classification report, which includes per-class precision, recall, and F1 score.
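To make these metrics concrete, here is a plain-Python sketch that derives per-class precision, recall, and F1 from a confusion matrix, with rows as the true stage and columns as the predicted stage (the same convention as scikit-learn's `confusion_matrix`; its `classification_report` produces the table further down in one call). The counts are made up for illustration and are not this model's results.

```python
def per_class_metrics(cm, labels):
    """cm[i][j] = number of labels with true class i predicted as class j."""
    k = len(labels)
    report = {}
    for c in range(k):
        tp = cm[c][c]
        predicted = sum(cm[r][c] for r in range(k))  # column sum: everything predicted as c
        actual = sum(cm[c])                          # row sum: the class's support
        precision = tp / predicted if predicted else 0.0
        recall = tp / actual if actual else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[labels[c]] = {"precision": precision, "recall": recall,
                             "f1": f1, "support": actual}
    return report

# Hypothetical toy counts (rows: true Wake/NREM/REM, columns: predicted)
toy_cm = [[8, 1, 1],
          [1, 7, 2],
          [0, 1, 4]]
metrics = per_class_metrics(toy_cm, ["Wake", "NREM", "REM"])
for label, m in metrics.items():
    print(label, {k: round(v, 2) for k, v in m.items()})
```

Precision answers "when the model says REM, how often is it right?", while recall answers "of the true REM labels, how many did it find?". That distinction is exactly what overall accuracy hides for a minority class.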

Confusion Matrix:

A closer look at the visualizations of predicted vs. actual sleep stages for individual subjects (such as subject 0052) suggests that the LSTM-RNN tends to misclassify REM when a subject switches in and out of REM.

Classification Report:

                Precision  Recall  F1-score  Support
Wake                 0.94    0.91      0.92    18234
NREM                 0.91    0.88      0.90    20764
REM                  0.68    0.85      0.76     4864

accuracy                               0.89    43862
macro avg            0.84    0.88      0.86    43862
weighted avg         0.90    0.89      0.89    43862
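As a check on how the last two rows are built, this short sketch recomputes them from the per-class rows of the report. The macro average treats every class equally, so REM's lower precision pulls it down. The weighted average weights each class by its support and is dominated by Wake and NREM.

```python
# Per-class rows copied from the classification report: (precision, recall, f1, support)
rows = {
    "Wake": (0.94, 0.91, 0.92, 18234),
    "NREM": (0.91, 0.88, 0.90, 20764),
    "REM": (0.68, 0.85, 0.76, 4864),
}
total = sum(r[3] for r in rows.values())

def macro(idx):
    # unweighted mean over classes: each stage counts equally
    return sum(r[idx] for r in rows.values()) / len(rows)

def weighted(idx):
    # support-weighted mean: each stage counts in proportion to its label count
    return sum(r[idx] * r[3] for r in rows.values()) / total

for name, idx in [("precision", 0), ("recall", 1), ("f1", 2)]:
    print(f"{name}: macro={macro(idx):.2f} weighted={weighted(idx):.2f}")
```

This reproduces the macro avg and weighted avg rows. Weighted recall always equals accuracy, because it reduces to the total number of correct predictions divided by the total support, which is why both read 0.89.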