========================
Custom Model Integration
========================

Introduction
============

Qlib's Model Zoo includes models such as LightGBM, MLP, and LSTM. These models are examples of Forecast Model. In addition to the default models that Qlib provides, users can integrate their own custom models into Qlib.

Users can integrate a custom model by following these steps:

  • Define a custom model class, which should be a subclass of `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_.
  • Write a configuration file that describes the path and parameters of the custom model.
  • Test the custom model.

Custom Model Class
==================

Custom models need to inherit from `qlib.model.base.Model <../reference/api.html#module-qlib.model.base>`_ and override its methods.

  • Override the ``__init__`` method

    • Qlib passes the initialization parameters from the configuration to the ``__init__`` method.

    • The model hyperparameters in the configuration must be consistent with those defined in the ``__init__`` method.

    • Code Example: In the following example, the model hyperparameters in the configuration file should contain parameters such as ``loss: mse``.

      .. code-block:: Python

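        # Assumes: from sklearn.metrics import mean_squared_error, roc_auc_score
        def __init__(self, loss='mse', **kwargs):
            # Only these two objectives are supported in this example
            if loss not in {'mse', 'binary'}:
                raise NotImplementedError
            self._scorer = mean_squared_error if loss == 'mse' else roc_auc_score
            # Store the objective together with any extra hyperparameters
            self._params = dict(objective=loss, **kwargs)
            # The trained model is set by fit()
            self._model = None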
      
  • Override the ``fit`` method

    • Qlib calls the ``fit`` method to train the model.

    • The parameters must include the training dataset ``dataset``, as required by the interface.

    • The parameters may also include optional parameters with default values, such as ``num_boost_round = 1000`` for GBDT.

    • Code Example: In the following example, ``num_boost_round = 1000`` is an optional parameter.

      .. code-block:: Python

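        # Assumes: import numpy as np, import lightgbm as lgb,
        # from qlib.data.dataset import DatasetH,
        # from qlib.data.dataset.handler import DataHandlerLP
        def fit(self, dataset: DatasetH, num_boost_round=1000, early_stopping_rounds=50,
                verbose_eval=20, evals_result=None, **kwargs):
            # The default values above are illustrative.
            # Collect evaluation results in a fresh dict unless the caller supplies one
            evals_result = {} if evals_result is None else evals_result

            # Prepare the datasets for LightGBM training and evaluation
            df_train, df_valid = dataset.prepare(
                ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
            )
            x_train, y_train = df_train["feature"], df_train["label"]
            x_valid, y_valid = df_valid["feature"], df_valid["label"]

            # LightGBM needs a 1D array as its label
            if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
                y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
            else:
                raise ValueError("LightGBM doesn't support multi-label training")

            dtrain = lgb.Dataset(x_train.values, label=y_train)
            dvalid = lgb.Dataset(x_valid.values, label=y_valid)

            # Fit the model. The early_stopping_rounds, verbose_eval and evals_result
            # keywords are accepted by lgb.train before LightGBM 4.0; later versions
            # expect the equivalent callbacks instead.
            self._model = lgb.train(
                self._params,
                dtrain,
                num_boost_round=num_boost_round,
                valid_sets=[dtrain, dvalid],
                valid_names=["train", "valid"],
                early_stopping_rounds=early_stopping_rounds,
                verbose_eval=verbose_eval,
                evals_result=evals_result,
                **kwargs
            )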
      
  • Override the ``predict`` method

    • The parameters must include ``dataset``, which is used to obtain the test dataset.

    • Return the prediction scores.

    • Please refer to `Model API <../reference/api.html#module-qlib.model.base>`_ for the parameter types of the ``fit`` and ``predict`` methods.

    • Code Example: In the following example, LightGBM predicts the label (such as ``preds``) of the test data ``x_test``, and the method returns the result.

      .. code-block:: Python

        # Assumes: import pandas as pd
        def predict(self, dataset: DatasetH, **kwargs) -> pd.Series:
            if self._model is None:
                raise ValueError("model is not fitted yet!")
            # Use the inference data (DK_I) of the "test" segment
            x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
            return pd.Series(self._model.predict(x_test.values), index=x_test.index)
      
  • Override the ``finetune`` method (Optional)

    • This method is optional. To support it, a custom model should inherit from the ``ModelFT`` base class, which declares the ``finetune`` interface.

    • The parameters must include ``dataset``.

    • Code Example: In the following example, LightGBM is used as the model and is finetuned.

      .. code-block:: Python

        def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
            # Finetune the existing model by training it for more rounds.
            # _prepare_data is a helper that is not shown here; it is assumed to
            # return the training and validation lgb.Dataset objects.
            dtrain, _ = self._prepare_data(dataset)
            self._model = lgb.train(
                self._params,
                dtrain,
                num_boost_round=num_boost_round,
                init_model=self._model,  # continue from the fitted model
                valid_sets=[dtrain],
                valid_names=["train"],
                verbose_eval=verbose_eval,
            )
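
The requirement that the hyperparameters in the configuration match the ``__init__`` signature can be checked before any training starts. The following is a minimal sketch that uses only the standard library; ``ToyModel``, ``StrictModel`` and ``args_match`` are hypothetical names for illustration and are not part of Qlib.

```python
import inspect


class ToyModel:
    """Hypothetical model whose __init__ accepts extra hyperparameters."""

    def __init__(self, loss="mse", **kwargs):
        if loss not in {"mse", "binary"}:
            raise NotImplementedError(f"unsupported loss: {loss}")
        # Extra hyperparameters are collected through **kwargs.
        self._params = dict(objective=loss, **kwargs)


class StrictModel:
    """Hypothetical model whose __init__ accepts only ``loss``."""

    def __init__(self, loss="mse"):
        self.loss = loss


def args_match(cls, args):
    """Return True if ``args`` can be bound to ``cls.__init__``."""
    signature = inspect.signature(cls.__init__)
    try:
        # None stands in for ``self``; bind() checks names without calling __init__.
        signature.bind(None, **args)
    except TypeError:
        return False
    return True


# Illustrative hyperparameters, as they might appear in a configuration file.
config_args = {"loss": "mse", "learning_rate": 0.0421, "max_depth": 8}

toy_ok = args_match(ToyModel, config_args)        # unknown keys fall into **kwargs
strict_ok = args_match(StrictModel, config_args)  # unknown keys are rejected
model = ToyModel(**config_args)
print(toy_ok, strict_ok, model._params)
```

Accepting ``**kwargs`` makes a model tolerant of extra configuration keys, whereas a fixed signature turns a misspelled hyperparameter into an early error.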
      

Configuration File
==================

The configuration file is described in detail in the `Workflow <../component/workflow.html#complete-example>`_ document. To integrate a custom model into Qlib, users need to modify the ``model`` field in the configuration file. This field describes which model to use and how to initialize it.

  • Example: The following example shows the ``model`` field of the configuration file for the custom LightGBM model mentioned above, where ``module_path`` is the module path, ``class`` is the class name, and ``args`` holds the hyperparameters passed to the ``__init__`` method. Every parameter in ``args`` except ``loss`` reaches ``__init__`` through ``**kwargs`` and is stored in ``self._params``; ``loss`` is matched to the named ``loss`` parameter.

    .. code-block:: YAML

      model:
          class: LGBModel
          module_path: qlib.contrib.model.gbdt
          args:
              loss: mse
              colsample_bytree: 0.8879
              learning_rate: 0.0421
              subsample: 0.8789
              lambda_l1: 205.6999
              lambda_l2: 580.9768
              max_depth: 8
              num_leaves: 210
              num_threads: 20
    

Users can find the configuration files of the baseline models in ``examples/benchmarks``. The configurations of the different models are listed under the corresponding model folders.
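
To make the roles of ``class``, ``module_path`` and ``args`` concrete, the sketch below resolves such a field into an object using only the standard library. It illustrates the idea and is not Qlib's own loading code; ``build_from_config`` is a hypothetical helper, and ``fractions.Fraction`` serves as a stand-in class so that the example runs without Qlib installed.

```python
import importlib


def build_from_config(config):
    """Import ``module_path``, look up ``class`` and call it with ``args``."""
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    return cls(**config.get("args", {}))


# Stand-in for the ``model`` field; a real configuration would name a model
# class and its module, as in the LGBModel entry shown above.
model_field = {
    "class": "Fraction",
    "module_path": "fractions",
    "args": {"numerator": 3, "denominator": 4},
}

obj = build_from_config(model_field)
print(type(obj).__name__, obj)
```

Because the class is looked up by name at run time, a custom model only needs to be importable from the given ``module_path``.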

Model Testing
=============

Assuming that the configuration file is examples/benchmarks/LightGBM/workflow_config_lightgbm.yaml, users can run the following command to test the custom model:

.. code-block:: bash

    cd examples  # Avoid running the program in a directory that contains `qlib`
    qrun benchmarks/LightGBM/workflow_config_lightgbm.yaml

.. note:: qrun is a built-in command of Qlib.

A model can also be tested as a standalone module. An example is given in ``examples/workflow_by_code.ipynb``.
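
Before running against real data, a quick way to exercise ``fit`` and ``predict`` in isolation is to hand the model a small stub in place of a dataset. The sketch below assumes only that the model calls ``dataset.prepare(...)``. ``StubDataset`` and ``MeanModel`` are hypothetical; they use plain lists instead of pandas objects, and in a real integration the model would subclass ``qlib.model.base.Model``.

```python
class StubDataset:
    """Hypothetical stand-in that exposes a simplified ``prepare`` method."""

    def __init__(self, segments):
        # Maps a segment name to {"feature": [...], "label": [...]}.
        self._segments = segments

    def prepare(self, segment, col_set="feature", **kwargs):
        data = self._segments[segment]
        if isinstance(col_set, str):
            return data[col_set]
        return {name: data[name] for name in col_set}


class MeanModel:
    """Toy model that predicts the mean training label for every test row."""

    def __init__(self):
        self._mean = None

    def fit(self, dataset, **kwargs):
        labels = dataset.prepare("train", col_set="label")
        self._mean = sum(labels) / len(labels)

    def predict(self, dataset, **kwargs):
        if self._mean is None:
            raise ValueError("model is not fitted yet!")
        features = dataset.prepare("test", col_set="feature")
        return [self._mean for _ in features]


dataset = StubDataset({
    "train": {"feature": [[1.0], [2.0], [3.0]], "label": [1.0, 2.0, 3.0]},
    "test": {"feature": [[4.0], [5.0]], "label": [0.0, 0.0]},
})

model = MeanModel()
model.fit(dataset)
scores = model.predict(dataset)
print(scores)  # one score per test row
```

A check of this kind only confirms that the method signatures and call order work; it does not replace a run with ``qrun`` on a real configuration.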

Reference
=========

To learn more about Forecast Model, please refer to `Forecast Model: Model Training & Prediction <../component/model.html>`_ and `Model API <../reference/api.html#module-qlib.model.base>`_.