203 lines
7.1 KiB
C++
203 lines
7.1 KiB
C++
/*
 * Copyright (c) 2020 Huawei Technologies Co.,Ltd.
 *
 * openGauss is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 * ---------------------------------------------------------------------------------------
 *
 * nodeTrainModel.cpp
 *        Implementation of Model Training Operators
 *
 * IDENTIFICATION
 *        src/gausskernel/runtime/executor/nodeTrainModel.cpp
 *
 * ---------------------------------------------------------------------------------------
 */
|
|
|
|
#include "postgres.h"
|
|
#include "funcapi.h"
|
|
|
|
#include "executor/executor.h"
|
|
#include "executor/node/nodeTrainModel.h"
|
|
#include "db4ai/db4ai_api.h"
|
|
|
|
// Forward declaration: ExecInitTrainModel installs this function as the node's
// ExecProcNode callback before its definition appears further down in the file.
static TupleTableSlot* ExecTrainModel(PlanState* state);
|
|
|
|
static bool ExecFetchTrainModel(void *callback_data, ModelTuple * tuple)
|
|
{
|
|
TrainModelState *pstate = (TrainModelState*)callback_data;
|
|
PlanState *outer_plan = outerPlanState(pstate);
|
|
TupleTableSlot *slot = ExecProcNode(outer_plan);
|
|
if (TupIsNull(slot))
|
|
return false;
|
|
|
|
if (tuple != &pstate->tuple) {
|
|
// make sure the output tuple has all information
|
|
tuple->ncolumns = pstate->tuple.ncolumns;
|
|
tuple->typid = pstate->tuple.typid;
|
|
tuple->typlen = pstate->tuple.typlen;
|
|
tuple->typbyval = pstate->tuple.typbyval;
|
|
}
|
|
|
|
// support of tuples that are (physical) - i.e., not virtual
|
|
if (slot->tts_tuple != nullptr) {
|
|
if (!pstate->row_allocated) {
|
|
tuple->values = (Datum *)palloc(sizeof(Datum) * tuple->ncolumns);
|
|
tuple->isnull = (bool *)palloc(sizeof(bool) *tuple->ncolumns);
|
|
pstate->row_allocated = true;
|
|
}
|
|
/*
|
|
* When all or most of a tuple's fields need to be extracted,
|
|
* this routine will be significantly quicker than a loop around
|
|
* heap_getattr; the loop will become O(N^2) as soon as any
|
|
* noncacheable attribute offsets are involved.
|
|
*/
|
|
heap_deform_tuple((HeapTuple)slot->tts_tuple, slot->tts_tupleDescriptor,
|
|
tuple->values, tuple->isnull);
|
|
} else {
|
|
Assert(!pstate->row_allocated);
|
|
tuple->values = slot->tts_values;
|
|
tuple->isnull = slot->tts_isnull;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
static void ExecReScanTrainModel(void *callback_data)
|
|
{
|
|
TrainModelState *pstate = (TrainModelState*)callback_data;
|
|
PlanState *outer_plan = outerPlanState(pstate);
|
|
ExecReScan(outer_plan);
|
|
}
|
|
|
|
/*
 * ExecInitTrainModel
 *      Build and initialize the executor state for a TrainModel plan node.
 *
 * pnode:  the TrainModel plan node.
 * estate: executor state the new node is attached to.
 * eflags: executor flags; REWIND, BACKWARD and MARK are asserted to be unset.
 *
 * Returns the state object produced by the algorithm's 'create' callback,
 * with its outer plan initialized, the input column metadata filled in and
 * a single-column result descriptor ("model") assigned.
 */
TrainModelState* ExecInitTrainModel(TrainModel* pnode, EState* estate, int eflags)
{
    Plan *child_plan = outerPlan(pnode);

    // none of these execution modes is supported by this node
    Assert(!(eflags & (EXEC_FLAG_REWIND | EXEC_FLAG_BACKWARD | EXEC_FLAG_MARK)));

    // the state structure itself is created by the selected algorithm
    AlgorithmAPI *algo = get_algorithm_api(pnode->algorithm);
    Assert(algo->create != nullptr);
    TrainModelState *node_state = algo->create(algo, pnode);
    PlanState *ps = &node_state->ss.ps;

    ps->plan = (Plan *)pnode;
    ps->state = estate;
    node_state->config = pnode;
    node_state->algorithm = algo;
    node_state->finished = 0;
    ps->ExecProcNode = ExecTrainModel;

    // tuple table slots for scanned input and produced result
    ExecInitScanTupleSlot(estate, &node_state->ss);
    ExecInitResultTupleSlot(estate, ps);

    // expression context and target list of this node
    ExecAssignExprContext(estate, ps);
    ps->targetlist = (List *)ExecInitExpr((Expr *)pnode->plan.targetlist, (PlanState *)node_state);

    // initialize the child that supplies the training rows
    PlanState *child_state = ExecInitNode(child_plan, estate, eflags);
    outerPlanState(node_state) = child_state;

    // scan type comes from the child, result type (for now) from the target list
    ExecAssignScanTypeFromOuterPlan(&node_state->ss);
    ExecAssignResultTypeFromTL(ps);
    ExecAssignProjectionInfo(ps, NULL);
    ps->ps_TupFromTlist = false;

    // record per-column type metadata of the input rows
    TupleDesc in_desc = ExecGetResultType(child_state);
    node_state->tuple.ncolumns = in_desc->natts;
    node_state->tuple.typid = (Oid *)palloc(sizeof(Oid) * node_state->tuple.ncolumns);
    node_state->tuple.typbyval = (bool *)palloc(sizeof(bool) * node_state->tuple.ncolumns);
    node_state->tuple.typlen = (int16 *)palloc(sizeof(int16) * node_state->tuple.ncolumns);
    for (int col = 0; col < node_state->tuple.ncolumns; ++col) {
        node_state->tuple.typid[col] = in_desc->attrs[col].atttypid;
        node_state->tuple.typbyval[col] = in_desc->attrs[col].attbyval;
        node_state->tuple.typlen[col] = in_desc->attrs[col].attlen;
    }

    // callbacks handed to the algorithm for reading and restarting the input
    node_state->row_allocated = false;
    node_state->fetch = ExecFetchTrainModel;
    node_state->rescan = ExecReScanTrainModel;
    node_state->callback_data = node_state;

    // the node finally reports a single column named "model"; this replaces
    // the result type derived from the target list above
    TupleDesc out_desc = CreateTemplateTupleDesc(1, false);
    TupleDescInitEntry(out_desc, (AttrNumber)1, "model", BYTEARRAYOID, -1, 0);
    BlessTupleDesc(out_desc);
    ExecAssignResultType(ps, out_desc);
    ExecAssignProjectionInfo(ps, nullptr);
    ps->ps_TupFromTlist = false;
    ps->ps_ProjInfo = nullptr;   // no projection is applied to the output

    return node_state;
}
|
|
|
|
/*
 * ExecTrainModel
 *      ExecProcNode callback of the training node: runs the algorithm once
 *      and returns the resulting model as a one-column virtual tuple.
 *
 * state: the node's PlanState (a TrainModelState).
 *
 * Returns the result slot holding a pointer to the trained Model in column 0,
 * or NULL when all configurations are already finished, when the scan
 * direction is not forward, or when the algorithm did not report
 * ERRCODE_SUCCESSFUL_COMPLETION.
 */
static TupleTableSlot* ExecTrainModel(PlanState* state)
{
    TrainModelState *pstate = castNode(TrainModelState, state);

    // check if already finished
    if (pstate->finished == pstate->config->configurations) {
        return NULL;
    }

    // If backwards scan, just return NULL without changing state
    if (!ScanDirectionIsForward(pstate->ss.ps.state->es_direction)) {
        return NULL;
    }

    // the model is allocated in the configuration's memory context
    MemoryContext oldcxt = MemoryContextSwitchTo(pstate->config->cxt);
    Model *model = (Model *)palloc0(sizeof(Model));
    model->status = ERRCODE_INVALID_STATUS;
    model->memory_context = pstate->config->cxt;
    MemoryContextSwitchTo(oldcxt);

    Assert(pstate->algorithm->run != nullptr);
    // 'run' receives the address of the pointer and may therefore replace it
    pstate->algorithm->run(pstate->algorithm, pstate, &model);
    if (model == nullptr || model->status != ERRCODE_SUCCESSFUL_COMPLETION) {
        // pfree does not depend on the current memory context, so no
        // context switch is needed around it
        if (model != nullptr) {
            pfree(model);
        }
        return NULL;
    }

    // Virtual tuple protocol: clear the slot, fill values AND null flags,
    // then mark it as holding a valid virtual tuple. The null flag must be
    // set explicitly; otherwise the column may be read as NULL depending on
    // stale slot contents.
    TupleTableSlot *slot = pstate->ss.ps.ps_ResultTupleSlot;
    ExecClearTuple(slot);
    slot->tts_values[0] = PointerGetDatum(model);
    slot->tts_isnull[0] = false;
    ExecStoreVirtualTuple(slot);

    return slot;
}
|
|
|
|
/*
 * ExecEndTrainModel
 *      Shut down a training node: lets the algorithm release its resources,
 *      frees the input tuple metadata, ends the outer plan and finally
 *      frees the state object itself.
 *
 * pstate: the node state created by ExecInitTrainModel; invalid on return.
 */
void ExecEndTrainModel(TrainModelState* pstate)
{
    // algorithm-specific cleanup runs first, while the state is still intact
    AlgorithmAPI *algo = get_algorithm_api(pstate->config->algorithm);
    Assert(algo->end != nullptr);
    algo->end(algo, pstate);

    // row buffers exist only if the fetch callback had to allocate them
    if (pstate->row_allocated) {
        pfree(pstate->tuple.values);
        pfree(pstate->tuple.isnull);
    }

    // per-column type metadata allocated at initialization
    pfree(pstate->tuple.typid);
    pfree(pstate->tuple.typbyval);
    pfree(pstate->tuple.typlen);

    ExecClearTuple(pstate->ss.ps.ps_ResultTupleSlot);

    ExecFreeExprContext(&pstate->ss.ps);
    ExecEndNode(outerPlanState(pstate));

    pfree(pstate);
}
|
|
|