pgml-cms/docs/open-source/pgml/guides/supervised-learning/classification.md
This example trains models on the sklearn digits dataset, which is a copy of the test set of the UCI ML hand-written digits datasets. It demonstrates using a table with a single array feature column for classification. You could do something similar with a vector column.
-- Load the sklearn digits dataset.
SELECT pgml.load_dataset('digits');

-- Preview ten rows: the first 40 characters of each image array (with a
-- ',...}' marker appended to show the truncation) next to its label.
SELECT
    left(image::text, 40) || ',...}',
    target
FROM pgml.digits
LIMIT 10;

-- Train a simple classification model on pgml.digits, using `target` as the label.
SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    task => 'classification',
    relation_name => 'pgml.digits',
    y_column_name => 'target'
);

-- Compare the stored label with the class predicted for each image.
SELECT
    target,
    pgml.predict('Handwritten Digits', image) AS prediction
FROM pgml.digits
LIMIT 10;

-- Same rows, but returning the raw class probabilities instead of a single class.
SELECT
    target,
    pgml.predict_proba('Handwritten Digits', image) AS prediction
FROM pgml.digits
LIMIT 10;
We currently support classification algorithms from scikit-learn, XGBoost, LightGBM, and CatBoost.
| Algorithm | Reference |
|---|---|
| xgboost | XGBClassifier |
| xgboost_random_forest | XGBRFClassifier |
| lightgbm | LGBMClassifier |
| catboost | CatBoostClassifier |
-- Train one model per gradient boosting algorithm for the 'Handwritten Digits' project.
-- Only the algorithm and its hyperparameters are passed here; task and relation are
-- omitted (presumably taken from the existing project -- confirm against the pgml.train docs).
-- NOTE(review): n_estimators is kept very small, presumably so the examples run quickly.
SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'xgboost',
    hyperparams => '{"n_estimators": 10}'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'xgboost_random_forest',
    hyperparams => '{"n_estimators": 10}'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'lightgbm',
    hyperparams => '{"n_estimators": 1}'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'catboost',
    hyperparams => '{"n_estimators": 1}'
);
| Algorithm | Reference |
|---|---|
| ada_boost | AdaBoostClassifier |
| bagging | BaggingClassifier |
| extra_trees | ExtraTreesClassifier |
| gradient_boosting_trees | GradientBoostingClassifier |
| random_forest | RandomForestClassifier |
| hist_gradient_boosting | HistGradientBoostingClassifier |
-- Train one model per ensemble algorithm for the 'Handwritten Digits' project.
-- ada_boost and bagging are trained without explicit hyperparameters.
SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'ada_boost'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'bagging'
);

-- The tree-based ensembles below cap the number of estimators at 10.
SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'extra_trees',
    hyperparams => '{"n_estimators": 10}'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'gradient_boosting_trees',
    hyperparams => '{"n_estimators": 10}'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'random_forest',
    hyperparams => '{"n_estimators": 10}'
);

-- hist_gradient_boosting is limited through max_iter rather than n_estimators.
SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'hist_gradient_boosting',
    hyperparams => '{"max_iter": 2}'
);
| Algorithm | Reference |
|---|---|
| svm | SVC |
| nu_svm | NuSVC |
| linear_svm | LinearSVC |
-- Train one model per support vector machine algorithm for the
-- 'Handwritten Digits' project; no hyperparameters are passed.
SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'svm'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'nu_svm'
);

SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'linear_svm'
);
| Algorithm | Reference |
|---|---|
| linear | LogisticRegression |
| ridge | RidgeClassifier |
| stochastic_gradient_descent | SGDClassifier |
| perceptron | Perceptron |
| passive_aggressive | PassiveAggressiveClassifier |
-- Train one model per linear algorithm for the 'Handwritten Digits' project.
-- Every algorithm in the linear models table gets an example, including
-- `linear`, which is requested explicitly here.
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'ridge');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'stochastic_gradient_descent');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'perceptron');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'passive_aggressive');
| Algorithm | Reference |
|---|---|
| gaussian_process | GaussianProcessClassifier |
-- Train a Gaussian process classifier for the 'Handwritten Digits' project,
-- passing max_iter_predict and warm_start through as hyperparameters.
SELECT *
FROM pgml.train(
    project_name => 'Handwritten Digits',
    algorithm => 'gaussian_process',
    hyperparams => '{"max_iter_predict": 100, "warm_start": true}'
);