Classification

Example

This example trains models on the sklearn digits dataset, which is a copy of the test set of the UCI ML hand-written digits datasets. It demonstrates using a table with a single array feature column for classification. You could do something similar with a vector column.

```postgresql
-- load the sklearn digits dataset
SELECT pgml.load_dataset('digits');

-- view the dataset
SELECT left(image::text, 40) || ',...}', target FROM pgml.digits LIMIT 10;

-- train a simple model to classify the data
SELECT * FROM pgml.train('Handwritten Digits', 'classification', 'pgml.digits', 'target');

-- check out the predictions
SELECT target, pgml.predict('Handwritten Digits', image) AS prediction
FROM pgml.digits
LIMIT 10;

-- view raw class probabilities
SELECT target, pgml.predict_proba('Handwritten Digits', image) AS prediction
FROM pgml.digits
LIMIT 10;
```
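
Beyond eyeballing ten rows, a rough sanity check is to count how many predictions match the label across the whole table. The query below is a minimal sketch that uses only the `pgml.predict` call and the `pgml.digits` columns shown above. It assumes the predicted value and `target` can be compared directly as numbers. Because it scores rows the model may have been trained on, treat the result as optimistic rather than as a held-out accuracy.

```postgresql
-- Rough agreement check over the full table (includes training rows,
-- so this is an optimistic estimate, not a proper test-set metric).
SELECT
    count(*) AS total,
    count(*) FILTER (
        WHERE pgml.predict('Handwritten Digits', image) = target
    ) AS correct
FROM pgml.digits;
```

Dividing `correct` by `total` gives the fraction of rows where the deployed model agrees with the label.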

Algorithms

We currently support classification algorithms from scikit-learn, XGBoost, LightGBM, and CatBoost.
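
The per-algorithm examples in this section pass only the project name, an `algorithm`, and sometimes `hyperparams`. That appears to rely on the project already having been trained once, as in the Example above, so the task, relation, and target column are reused. As a sketch, the fully spelled-out equivalent might look like the call below. The parameter names `project_name`, `task`, `relation_name`, and `y_column_name` are assumptions here, since this page only passes those arguments positionally.

```postgresql
-- Explicit form of a training call. The named parameters project_name,
-- task, relation_name and y_column_name are assumed, not shown on this page.
SELECT * FROM pgml.train(
    project_name  => 'Handwritten Digits',
    task          => 'classification',
    relation_name => 'pgml.digits',
    y_column_name => 'target',
    algorithm     => 'xgboost',
    hyperparams   => '{"n_estimators": 10}'
);
```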

Gradient Boosting

| Algorithm | Reference |
| --- | --- |
| `xgboost` | XGBClassifier |
| `xgboost_random_forest` | XGBRFClassifier |
| `lightgbm` | LGBMClassifier |
| `catboost` | CatBoostClassifier |

Examples

```postgresql
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'xgboost_random_forest', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'lightgbm', hyperparams => '{"n_estimators": 1}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'catboost', hyperparams => '{"n_estimators": 1}');
```

Scikit Ensembles

| Algorithm | Reference |
| --- | --- |
| `ada_boost` | AdaBoostClassifier |
| `bagging` | BaggingClassifier |
| `extra_trees` | ExtraTreesClassifier |
| `gradient_boosting_trees` | GradientBoostingClassifier |
| `random_forest` | RandomForestClassifier |
| `hist_gradient_boosting` | HistGradientBoostingClassifier |

Examples

```postgresql
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'ada_boost');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'bagging');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'extra_trees', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'gradient_boosting_trees', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'random_forest', hyperparams => '{"n_estimators": 10}');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'hist_gradient_boosting', hyperparams => '{"max_iter": 2}');
```

Support Vector Machines

| Algorithm | Reference |
| --- | --- |
| `svm` | SVC |
| `nu_svm` | NuSVC |
| `linear_svm` | LinearSVC |

Examples

```postgresql
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'svm');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'nu_svm');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear_svm');
```

Linear Models

| Algorithm | Reference |
| --- | --- |
| `linear` | LogisticRegression |
| `ridge` | RidgeClassifier |
| `stochastic_gradient_descent` | SGDClassifier |
| `perceptron` | Perceptron |
| `passive_aggressive` | PassiveAggressiveClassifier |

Examples

```postgresql
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'ridge');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'stochastic_gradient_descent');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'perceptron');
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'passive_aggressive');
```
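
The Linear Models table lists `linear` (LogisticRegression), but the examples above do not exercise it. By analogy with the other calls, selecting it explicitly would presumably look like the sketch below. No hyperparameters are passed, so whatever defaults the extension applies are used.

```postgresql
-- Explicitly select the logistic regression algorithm listed in the table.
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'linear');
```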

Other

| Algorithm | Reference |
| --- | --- |
| `gaussian_process` | GaussianProcessClassifier |

Examples

```postgresql
SELECT * FROM pgml.train('Handwritten Digits', algorithm => 'gaussian_process', hyperparams => '{"max_iter_predict": 100, "warm_start": true}');
```