diff --git a/src/gausskernel/dbmind/db4ai/executor/algorithms/xgboost.cpp b/src/gausskernel/dbmind/db4ai/executor/algorithms/xgboost.cpp index ccd09d3b9..48d69d3e1 100644 --- a/src/gausskernel/dbmind/db4ai/executor/algorithms/xgboost.cpp +++ b/src/gausskernel/dbmind/db4ai/executor/algorithms/xgboost.cpp @@ -180,6 +180,7 @@ typedef struct XGBoostState { typedef struct SerializedModelXgboost { int ft_cols; + SerializedModel *model = nullptr; BoosterHandle booster = nullptr; }SerializedModelXgboost; @@ -678,6 +679,7 @@ void xgboost_deserialize(SerializedModel *xg_model, SerializedModelXgboost *xgbo xg_model->raw_data = placeholder; xg_model->size = avail; + xgboostm->model = xg_model; } } ModelPredictor xgboost_predict_prepare(AlgorithmAPI *, SerializedModel const *model, Oid return_type) @@ -691,17 +693,17 @@ ModelPredictor xgboost_predict_prepare(AlgorithmAPI *, SerializedModel const *mo auto xg_model = const_cast(model); SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)palloc0(sizeof(SerializedModelXgboost)); xgboost_deserialize(xg_model, xgboostm); - /* init XGBoost predictor */ - safe_xgboost(g_xgboostApi->XGBoosterCreate(nullptr, 0, &xgboostm->booster)); - /* load the decoded model */ - safe_xgboost(g_xgboostApi->XGBoosterUnserializeFromBuffer(xgboostm->booster, xg_model->raw_data, xg_model->size)); - return reinterpret_cast(xgboostm); } Datum xgboost_predict(AlgorithmAPI *, ModelPredictor model, Datum *values, bool *isnull, Oid *types, int ncolumns) { SerializedModelXgboost *xgboostm = (SerializedModelXgboost *)model; + + /* init XGBoost predictor */ + safe_xgboost(g_xgboostApi->XGBoosterCreate(nullptr, 0, &xgboostm->booster)); + /* load the decoded model */ + safe_xgboost(g_xgboostApi->XGBoosterUnserializeFromBuffer(xgboostm->booster, xgboostm->model->raw_data, xgboostm->model->size)); /* sanity checks */ Assert(xgboostm->booster != nullptr); if (ncolumns != xgboostm->ft_cols) @@ -729,7 +731,7 @@ Datum xgboost_predict(AlgorithmAPI *, ModelPredictor model, Datum *values, bool /* release memory of xgboost dmatrix structure */ safe_xgboost(g_xgboostApi->XGDMatrixFree(dmat)); - + safe_xgboost(g_xgboostApi->XGBoosterFree(xgboostm->booster)); return Float8GetDatum(prediction); }