-
Understand the basics of machine learning techniques and why they are useful for solving clinical prediction problems.
-
Understand the intuition behind some machine learning models, including regression, decision trees, and support vector machines.
-
Understand how to apply these models to clinical prediction problems, using publicly available datasets via case studies.
12.1 Why Machine Learning?
12.2 General Concepts of Learning
12.2.1 Learning Scenario for Clinical Prediction
-
Define the outcome of your task.
-
Consult with domain experts to identify important features/variables.
-
Select an appropriate algorithm (or design a new machine learning algorithm) with a suitable parameter selection.
-
Train the algorithm on a subset of the data (the training data) to find an optimized model.
-
Evaluate the model on another subset of the data (the testing data) with appropriate metrics.
-
Deploy the prediction model on real-world data.
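The steps above can be sketched end to end in a few lines, assuming scikit-learn; the synthetic dataset and the choice of logistic regression below are placeholders for a real clinical dataset and algorithm.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Placeholder clinical dataset: 500 "patients", 10 features, binary outcome
# (steps 1-2: define the outcome, choose the features).
X, y = make_classification(n_samples=500, n_features=10, random_state=0)

# Step 4 prerequisite: hold out part of the data for testing.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

# Steps 3-4: select an algorithm and fit it on the training data.
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Step 5: evaluate on the held-out testing data.
acc = accuracy_score(y_test, model.predict(X_test))
```

The final deployment step (step 6) is omitted here, since it depends entirely on the target clinical environment.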
12.2.2 Machine Learning Scenarios
12.2.2.1 Supervised Learning
12.2.2.2 Unsupervised Learning
12.2.2.3 Other Scenario
12.2.3 Find the Best Function
Task | Error type | Loss function | Note |
---|---|---|---|
Regression | Mean squared error (MSE, L2 loss) | \(\frac{1}{n} \sum _{i=1}^n (y_i - \hat{y}_i)^2\) | Easy to learn but sensitive to outliers |
Regression | Mean absolute error (MAE, L1 loss) | \(\frac{1}{n} \sum _{i=1}^n |y_i - \hat{y}_i|\) | Robust to outliers but not differentiable |
Classification | Cross entropy (log loss) | \(- \frac{1}{n} \sum _{i=1}^n [y_i \log (\hat{y}_i) + (1-y_i) \log (1-\hat{y}_i)]\) (general form \(- \sum _{i} p_i \log q_i\)) | Quantifies the difference between two probability distributions |
Classification | Hinge loss | \(\frac{1}{n} \sum _{i=1}^n \max(0, 1 - y_i \hat{y}_i)\) | Used by support vector machines |
Classification | KL divergence | \(D_{KL}(p \| q) = \sum _{i} p_i \log \frac{p_i}{q_i}\) | Quantifies the difference between two probability distributions |
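Each of these losses is a few lines of NumPy; the toy predictions below are made up purely for illustration.

```python
import numpy as np

# Regression losses on a toy example.
y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5,  0.0, 2.0, 8.0])
mse = np.mean((y_true - y_pred) ** 2)     # L2 loss: squares each error
mae = np.mean(np.abs(y_true - y_pred))    # L1 loss: absolute errors

# Binary cross entropy on predicted probabilities.
y_bin = np.array([1, 0, 1, 1])
q     = np.array([0.9, 0.2, 0.8, 0.6])
log_loss = -np.mean(y_bin * np.log(q) + (1 - y_bin) * np.log(1 - q))

# KL divergence between two discrete distributions p and q.
p_dist = np.array([0.4, 0.6])
q_dist = np.array([0.5, 0.5])
kl = np.sum(p_dist * np.log(p_dist / q_dist))
```

Note that `log_loss` blows up when a confident prediction (probability near 0 or 1) is wrong, which is exactly what makes it a useful training signal for classifiers.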
 | Predicted: True | Predicted: False | Metric |
---|---|---|---|
Actual: True | True positive (TP) | False negative (FN), Type II error | Recall = Sensitivity = \(\frac{\mathrm {TP}}{\mathrm{TP}\,+\,\mathrm{FN}}\) |
Actual: False | False positive (FP), Type I error | True negative (TN) | Specificity = \(\frac{\mathrm {TN}}{\mathrm{TN}\,+\,\mathrm{FP}}\) |
Metric | Precision = \(\frac{\mathrm {TP}}{\mathrm{TP}\,+\,\mathrm{FP}}\) | | Accuracy = \(\frac{\mathrm{TP}\,+\,\mathrm{TN}}{\mathrm{TP}\,+\,\mathrm{TN}\,+\,\mathrm{FP}\,+\,\mathrm{FN}}\); F1 = \(\frac{2 \times \mathrm {Precision}\,\times \,\mathrm {Recall}}{{\mathrm{Precision}\,+\,\mathrm{Recall}}}\) |
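All of the metrics in the confusion matrix follow directly from the four cell counts; the counts below are invented for illustration.

```python
# Hypothetical confusion-matrix counts for a binary classifier.
TP, FN, FP, TN = 40, 10, 5, 45

recall = TP / (TP + FN)                         # a.k.a. sensitivity
specificity = TN / (TN + FP)
precision = TP / (TP + FP)
accuracy = (TP + TN) / (TP + TN + FP + FN)
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
```

F1 is the harmonic mean of precision and recall, so it is dragged down by whichever of the two is worse; this makes it a more honest summary than accuracy when the classes are imbalanced.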
12.2.4 Metrics
12.2.4.1 Supervised Learning
12.2.4.2 Unsupervised Learning
12.2.5 Model Validation
-
Training set for model training. You will run the selected machine learning algorithm only on this subset.
-
Development (a.k.a. dev, validation) set, also called hold-out, for parameter tuning and feature selection. This subset is used only for optimization and model validation.
-
Testing set for evaluating model performance. We only apply the model for prediction here, and do not change anything in the model at this stage.
-
It is better for your training, dev, and testing sets to come from the same data distribution, rather than being too different (e.g. training/dev on male patients but testing on female patients). Otherwise you may face the problem of overfitting: the model fits the training and dev sets well but struggles to generalize to the test data, and so cannot be applied to other cases.
-
It is important to avoid using any data from the dev or testing sets for model training. Test data leakage, i.e. having part of the testing data appear in the training data, may overfit the model to your test data and erroneously give you a high performance estimate for a bad model.
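One common way to realize this three-way partition, assuming scikit-learn, is to split twice: carve off the test set first, then divide the remainder into training and dev sets, so that no test example can leak into training or tuning.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Placeholder dataset of 1000 examples.
X, y = make_classification(n_samples=1000, random_state=0)

# First split: hold out 20% as the untouched testing set.
X_rest, X_test, y_rest, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0)

# Second split: 25% of the remaining 80% becomes the dev set,
# giving 60% train / 20% dev / 20% test overall.
X_train, X_dev, y_train, y_dev = train_test_split(
    X_rest, y_rest, test_size=0.25, random_state=0)
```

The exact 60/20/20 proportions are an assumption for illustration; very large datasets often use a much smaller dev/test fraction.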
12.2.5.1 Cross-Validation
12.2.6 Diagnostics
12.2.6.1 Bias and Variance
12.2.6.2 Regularization
 | Training error | Validation error | Approach |
---|---|---|---|
High bias | High | High | Increase model complexity |
High variance | Low | High | Decrease model complexity; add more data |
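The two failure modes in the table can be reproduced on a toy regression problem; the sine-curve data and polynomial model below are assumptions chosen only to make bias and variance visible.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

# Toy 1-D problem: y = sin(x) + noise.
rng = np.random.default_rng(0)
X = np.sort(rng.uniform(-3, 3, 60)).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(0, 0.1, 60)
X_train, y_train = X[::2], y[::2]   # every other point for training
X_val, y_val = X[1::2], y[1::2]     # the rest for validation

def train_val_error(degree):
    """Fit a polynomial of the given degree; return (train MSE, val MSE)."""
    model = make_pipeline(PolynomialFeatures(degree), LinearRegression())
    model.fit(X_train, y_train)
    return (np.mean((model.predict(X_train) - y_train) ** 2),
            np.mean((model.predict(X_val) - y_val) ** 2))

train_simple, val_simple = train_val_error(1)     # high bias: both errors high
train_complex, val_complex = train_val_error(15)  # high variance: low train error
```

A straight line cannot capture the sine curve (high training and validation error), while the degree-15 polynomial drives the training error down much further, matching the two rows of the table.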
Regularization | Equation |
---|---|
L1 (LASSO) | \(\sum _{i=1}^m (y_i - \sum _{j=1}^n \beta _j x_{ij})^2 + \lambda \sum _{j=1}^n | \beta _j|\) |
L2 (Ridge) | \(\sum _{i=1}^m (y_i - \sum _{j=1}^n \beta _j x_{ij})^2 + \lambda \sum _{j=1}^n \beta _j^2\) |
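The practical difference between the two penalties shows up in the fitted coefficients, as a sketch with scikit-learn's `Lasso` and `Ridge` illustrates; the sparse synthetic data and the `alpha` (the \(\lambda\) above) values are assumptions.

```python
import numpy as np
from sklearn.linear_model import Lasso, Ridge

# Synthetic data: 20 features, only the first 3 truly matter.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 20))
beta = np.zeros(20)
beta[:3] = [2.0, -1.5, 1.0]
y = X @ beta + rng.normal(0, 0.1, 100)

lasso = Lasso(alpha=0.1).fit(X, y)   # L1: drives many coefficients to exactly 0
ridge = Ridge(alpha=1.0).fit(X, y)   # L2: shrinks coefficients but keeps them nonzero

n_zero_lasso = int(np.sum(lasso.coef_ == 0))
n_zero_ridge = int(np.sum(ridge.coef_ == 0))
```

This is why LASSO doubles as a feature-selection tool: the noise features are zeroed out entirely, while Ridge merely shrinks them.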
12.2.7 Error Analysis
12.2.8 Ablation Analysis
12.3 Learning Algorithms
12.3.1 Supervised Learning
12.3.1.1 Linear Models
12.3.1.2 Tree-Based Models
-
It looks across all possible thresholds of all possible features and picks the single feature split that best separates the data.
-
The data are split on that feature at the specific threshold that yields the highest performance.
-
It iteratively repeats the above two steps until reaching the maximal tree depth, or until all the leaves are pure.
-
Splitting criterion: Gini index or entropy.
-
Tree size: maximum tree depth, tree pruning.
-
Number of samples: minimum samples in a leaf, or minimum samples required to split a node.
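These hyperparameters map directly onto scikit-learn's `DecisionTreeClassifier`; the Breast Cancer Wisconsin dataset bundled with scikit-learn serves as the example, and the specific parameter values are illustrative choices, not recommendations.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

tree = DecisionTreeClassifier(
    criterion="gini",     # splitting criterion: "gini" or "entropy"
    max_depth=4,          # tree size: cap the depth
    min_samples_leaf=5,   # minimum samples allowed in a leaf
    random_state=0,
).fit(X_train, y_train)

score = tree.score(X_test, y_test)
```

Capping `max_depth` and `min_samples_leaf` is the simplest guard against the tree memorizing the training data.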
-
Pick a random subset of features.
-
Create a bootstrap sample of the data (randomly resample the data with replacement).
-
Build a decision tree on this sample.
-
Iteratively perform the above steps until termination.
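This procedure is what scikit-learn's `RandomForestClassifier` automates: each of its trees is grown on a bootstrap sample, considering a random feature subset at every split. The synthetic dataset and parameter values below are assumptions for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=600, n_features=20,
                           n_informative=5, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

forest = RandomForestClassifier(
    n_estimators=100,      # number of trees to grow
    max_features="sqrt",   # random subset of features tried at each split
    bootstrap=True,        # bootstrap sample of the data for each tree
    random_state=0,
).fit(X_train, y_train)

score = forest.score(X_test, y_test)
```

Averaging over many decorrelated trees reduces the variance of a single deep decision tree, which is the main reason random forests typically outperform one tree.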
12.3.1.3 Support Vector Machine (SVM)
12.3.2 Unsupervised Learning
12.3.2.1 Clustering
-
Randomly initialize k points as the centroids of the k clusters.
-
Assign each data point to the nearest centroid, forming clusters.
-
Recompute and update each centroid as the mean of the data points in its cluster.
-
Repeat steps 2 and 3 until convergence.
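The four steps above can be written directly in NumPy as a minimal sketch of Lloyd's algorithm (this simplified version does not handle empty clusters); the two-blob demo data are invented for the check at the end.

```python
import numpy as np

def kmeans(X, k, n_iter=100, seed=0):
    """Plain NumPy k-means following the four steps above."""
    rng = np.random.default_rng(seed)
    # Step 1: randomly pick k data points as the initial centroids.
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        # Step 2: assign every point to its nearest centroid.
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 3: recompute each centroid as the mean of its cluster.
        new_centroids = np.array(
            [X[labels == j].mean(axis=0) for j in range(k)])
        # Step 4: repeat until the centroids stop moving.
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels, centroids

# Quick check on two well-separated synthetic clusters.
rng = np.random.default_rng(1)
X_demo = np.vstack([rng.normal(0.0, 0.5, (50, 2)),
                    rng.normal(5.0, 0.5, (50, 2))])
labels, centroids = kmeans(X_demo, k=2)
```

In practice scikit-learn's `KMeans` is preferable: it adds smarter initialization (k-means++) and multiple restarts, since the result of a single run depends on the random initial centroids.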
12.3.2.2 Dimensionality Reduction
12.4 Programming Exercise
-
Breast Cancer Wisconsin (Diagnostic) Database
-
Preprocessed ICU data from the PhysioNet Challenge 2012 Database.
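The Breast Cancer Wisconsin (Diagnostic) dataset ships with scikit-learn and loads in one call; the PhysioNet Challenge 2012 data must be downloaded separately from PhysioNet.

```python
from sklearn.datasets import load_breast_cancer

# scikit-learn's copy of the Breast Cancer Wisconsin (Diagnostic) dataset.
data = load_breast_cancer()
X, y = data.data, data.target

n_samples, n_features = X.shape   # 569 patients, 30 numeric features
classes = list(data.target_names) # 'malignant' vs 'benign'
```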
-
Learn how to use Google Colab/Jupyter notebooks.
-
Learn how to build and diagnose machine learning models for clinical classification and clustering tasks.