fix memory leaks for xgboost
This commit is contained in:
@ -180,6 +180,7 @@ typedef struct XGBoostState {
|
|||||||
|
|
||||||
typedef struct SerializedModelXgboost {
|
typedef struct SerializedModelXgboost {
|
||||||
int ft_cols;
|
int ft_cols;
|
||||||
|
SerializedModel *model = nullptr;
|
||||||
BoosterHandle booster = nullptr;
|
BoosterHandle booster = nullptr;
|
||||||
}SerializedModelXgboost;
|
}SerializedModelXgboost;
|
||||||
|
|
||||||
@ -678,6 +679,7 @@ void xgboost_deserialize(SerializedModel *xg_model, SerializedModelXgboost *xgbo
|
|||||||
|
|
||||||
xg_model->raw_data = placeholder;
|
xg_model->raw_data = placeholder;
|
||||||
xg_model->size = avail;
|
xg_model->size = avail;
|
||||||
|
xgboostm->model = xg_model;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ModelPredictor xgboost_predict_prepare(AlgorithmAPI *, SerializedModel const *model, Oid return_type)
|
ModelPredictor xgboost_predict_prepare(AlgorithmAPI *, SerializedModel const *model, Oid return_type)
|
||||||
@ -691,17 +693,17 @@ ModelPredictor xgboost_predict_prepare(AlgorithmAPI *, SerializedModel const *mo
|
|||||||
auto xg_model = const_cast<SerializedModel *>(model);
|
auto xg_model = const_cast<SerializedModel *>(model);
|
||||||
SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)palloc0(sizeof(SerializedModelXgboost));
|
SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)palloc0(sizeof(SerializedModelXgboost));
|
||||||
xgboost_deserialize(xg_model, xgboostm);
|
xgboost_deserialize(xg_model, xgboostm);
|
||||||
/* init XGBoost predictor */
|
|
||||||
safe_xgboost(g_xgboostApi->XGBoosterCreate(nullptr, 0, &xgboostm->booster));
|
|
||||||
/* load the decoded model */
|
|
||||||
safe_xgboost(g_xgboostApi->XGBoosterUnserializeFromBuffer(xgboostm->booster, xg_model->raw_data, xg_model->size));
|
|
||||||
|
|
||||||
return reinterpret_cast<ModelPredictor>(xgboostm);
|
return reinterpret_cast<ModelPredictor>(xgboostm);
|
||||||
}
|
}
|
||||||
|
|
||||||
Datum xgboost_predict(AlgorithmAPI *, ModelPredictor model, Datum *values, bool *isnull, Oid *types, int ncolumns)
|
Datum xgboost_predict(AlgorithmAPI *, ModelPredictor model, Datum *values, bool *isnull, Oid *types, int ncolumns)
|
||||||
{
|
{
|
||||||
SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)model;
|
SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)model;
|
||||||
|
|
||||||
|
/* init XGBoost predictor */
|
||||||
|
safe_xgboost(g_xgboostApi->XGBoosterCreate(nullptr, 0, &xgboostm->booster));
|
||||||
|
/* load the decoded model */
|
||||||
|
safe_xgboost(g_xgboostApi->XGBoosterUnserializeFromBuffer(xgboostm->booster, xgboostm->model->raw_data, xgboostm->model->size));
|
||||||
/* sanity checks */
|
/* sanity checks */
|
||||||
Assert(xgboostm->booster != nullptr);
|
Assert(xgboostm->booster != nullptr);
|
||||||
if (ncolumns != xgboostm->ft_cols)
|
if (ncolumns != xgboostm->ft_cols)
|
||||||
@ -729,7 +731,7 @@ Datum xgboost_predict(AlgorithmAPI *, ModelPredictor model, Datum *values, bool
|
|||||||
|
|
||||||
/* release memory of xgboost dmatrix structure */
|
/* release memory of xgboost dmatrix structure */
|
||||||
safe_xgboost(g_xgboostApi->XGDMatrixFree(dmat));
|
safe_xgboost(g_xgboostApi->XGDMatrixFree(dmat));
|
||||||
|
safe_xgboost(g_xgboostApi->XGBoosterFree(xgboostm->booster));
|
||||||
return Float8GetDatum(prediction);
|
return Float8GetDatum(prediction);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user