fix memory leaks for xgboost

This commit is contained in:
nwen
2022-09-28 20:43:45 +08:00
parent 1068fe6df3
commit c17dedded4

View File

@ -180,6 +180,7 @@ typedef struct XGBoostState {
typedef struct SerializedModelXgboost {
int ft_cols;
SerializedModel *model = nullptr;
BoosterHandle booster = nullptr;
}SerializedModelXgboost;
@ -678,6 +679,7 @@ void xgboost_deserialize(SerializedModel *xg_model, SerializedModelXgboost *xgbo
xg_model->raw_data = placeholder;
xg_model->size = avail;
xgboostm->model = xg_model;
}
}
ModelPredictor xgboost_predict_prepare(AlgorithmAPI *, SerializedModel const *model, Oid return_type)
@ -691,17 +693,17 @@ ModelPredictor xgboost_predict_prepare(AlgorithmAPI *, SerializedModel const *mo
auto xg_model = const_cast<SerializedModel *>(model);
SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)palloc0(sizeof(SerializedModelXgboost));
xgboost_deserialize(xg_model, xgboostm);
/* init XGBoost predictor */
safe_xgboost(g_xgboostApi->XGBoosterCreate(nullptr, 0, &xgboostm->booster));
/* load the decoded model */
safe_xgboost(g_xgboostApi->XGBoosterUnserializeFromBuffer(xgboostm->booster, xg_model->raw_data, xg_model->size));
return reinterpret_cast<ModelPredictor>(xgboostm);
}
Datum xgboost_predict(AlgorithmAPI *, ModelPredictor model, Datum *values, bool *isnull, Oid *types, int ncolumns)
{
SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)model;
/* init XGBoost predictor */
safe_xgboost(g_xgboostApi->XGBoosterCreate(nullptr, 0, &xgboostm->booster));
/* load the decoded model */
safe_xgboost(g_xgboostApi->XGBoosterUnserializeFromBuffer(xgboostm->booster, xgboostm->model->raw_data, xgboostm->model->size));
/* sanity checks */
Assert(xgboostm->booster != nullptr);
if (ncolumns != xgboostm->ft_cols)
@ -729,7 +731,7 @@ Datum xgboost_predict(AlgorithmAPI *, ModelPredictor model, Datum *values, bool
/* release memory of xgboost dmatrix structure */
safe_xgboost(g_xgboostApi->XGDMatrixFree(dmat));
safe_xgboost(g_xgboostApi->XGBoosterFree(xgboostm->booster));
return Float8GetDatum(prediction);
}