diff --git a/src/gausskernel/storage/mot/jit_exec/jit_common.cpp b/src/gausskernel/storage/mot/jit_exec/jit_common.cpp index 62325c5a1..9e45a977d 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_common.cpp +++ b/src/gausskernel/storage/mot/jit_exec/jit_common.cpp @@ -242,6 +242,40 @@ extern bool IsTypeSupported(int resultType) } } +extern bool IsStringType(int type) +{ + switch (type) { + case VARCHAROID: + case BPCHAROID: + case TEXTOID: + case BYTEAOID: + return true; + default: + return false; + } +} + +extern bool IsPrimitiveType(int type) +{ + switch (type) { + case BOOLOID: + case CHAROID: + case INT1OID: + case INT2OID: + case INT4OID: + case INT8OID: + case TIMEOID: + case TIMESTAMPOID: + case DATEOID: + case FLOAT4OID: + case FLOAT8OID: + return true; + + default: + return false; + } +} + static bool IsEqualsWhereOperator(int whereOp) { bool result = false; @@ -623,4 +657,234 @@ extern bool PrepareSubQueryData(JitContext* jitContext, JitCompoundPlan* plan) return result; } + +static bool CloneStringDatum(Datum source, Datum* target, JitContextUsage usage) +{ + bytea* value = DatumGetByteaP(source); + size_t len = VARSIZE(value); // includes header len VARHDRSZ + char* src = VARDATA(value); + + // special case: empty string + if (len == 0) { + len = VARHDRSZ; + } + size_t strSize = len - VARHDRSZ; + MOT_LOG_TRACE("CloneStringDatum(): len = %u, src = %*.*s", (unsigned)len, strSize, strSize, src); + + bytea* copy = nullptr; + if (usage == JIT_CONTEXT_GLOBAL) { + copy = (bytea*)MOT::MemGlobalAlloc(len); + } else { + copy = (bytea*)MOT::MemSessionAlloc(len); + } + if (copy == nullptr) { + MOT_REPORT_ERROR( + MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum string constant", (unsigned)len); + return false; + } + + if (strSize > 0) { + errno_t erc = memcpy_s(VARDATA(copy), strSize, (uint8_t*)src, strSize); + securec_check(erc, "\0", "\0"); + } + + SET_VARSIZE(copy, len); + + *target = PointerGetDatum(copy); + return true; +} + +static bool CloneTimeTzDatum(Datum source, Datum* target, JitContextUsage usage) +{ + MOT::TimetzSt* value = (MOT::TimetzSt*)DatumGetPointer(source); + size_t allocSize = sizeof(MOT::TimetzSt); + MOT::TimetzSt* copy = nullptr; + if (usage == JIT_CONTEXT_GLOBAL) { + copy = (MOT::TimetzSt*)MOT::MemGlobalAlloc(allocSize); + } else { + copy = (MOT::TimetzSt*)MOT::MemSessionAlloc(allocSize); + } + if (copy == nullptr) { + MOT_REPORT_ERROR( + MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum TimeTZ constant", (unsigned)allocSize); + return false; + } + copy->m_time = value->m_time; + copy->m_zone = value->m_zone; + + *target = PointerGetDatum(copy); + return true; +} + +static bool CloneIntervalDatum(Datum source, Datum* target, JitContextUsage usage) +{ + MOT::IntervalSt* value = (MOT::IntervalSt*)DatumGetPointer(source); + size_t allocSize = sizeof(MOT::IntervalSt); + MOT::IntervalSt* copy = nullptr; + if (usage == JIT_CONTEXT_GLOBAL) { + copy = (MOT::IntervalSt*)MOT::MemGlobalAlloc(allocSize); + } else { + copy = (MOT::IntervalSt*)MOT::MemSessionAlloc(allocSize); + } + if (copy == nullptr) { + MOT_REPORT_ERROR(MOT_ERROR_OOM, + "JIT Compile", + "Failed to allocate %u bytes for datum Interval constant", + (unsigned)allocSize); + return false; + } + copy->m_day = value->m_day; + copy->m_month = value->m_month; + copy->m_time = value->m_time; + + *target = PointerGetDatum(copy); + return true; +} + +static bool CloneTIntervalDatum(Datum source, Datum* target, JitContextUsage usage) +{ + MOT::TintervalSt* value = (MOT::TintervalSt*)DatumGetPointer(source); + size_t allocSize = sizeof(MOT::TintervalSt); + MOT::TintervalSt* copy = nullptr; + if (usage == JIT_CONTEXT_GLOBAL) { + copy = (MOT::TintervalSt*)MOT::MemGlobalAlloc(allocSize); + } else { + copy = (MOT::TintervalSt*)MOT::MemSessionAlloc(allocSize); + } + if (copy == nullptr) { + MOT_REPORT_ERROR(MOT_ERROR_OOM, + "JIT Compile", + "Failed to allocate %u bytes for datum TInterval constant", + (unsigned)allocSize); + return false; + } + copy->m_status = value->m_status; + copy->m_data[0] = value->m_data[0]; + copy->m_data[1] = value->m_data[1]; + + *target = PointerGetDatum(copy); + return true; +} + +static bool CloneNumericDatum(Datum source, Datum* target, JitContextUsage usage) +{ + varlena* var = (varlena*)DatumGetPointer(source); + Size len = VARSIZE(var); + struct varlena* result = nullptr; + if (usage == JIT_CONTEXT_GLOBAL) { + result = (varlena*)MOT::MemGlobalAlloc(len); + } else { + result = (varlena*)MOT::MemSessionAlloc(len); + } + if (result == nullptr) { + MOT_REPORT_ERROR( + MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum Numeric constant", (unsigned)len); + return false; + } + + errno_t rc = memcpy_s(result, len, var, len); + securec_check(rc, "\0", "\0"); + + *target = NumericGetDatum((Numeric)result); + return true; +} + +static bool CloneCStringDatum(Datum source, Datum* target, JitContextUsage usage) +{ + char* src = DatumGetCString(source); + size_t len = strlen(src) + 1; // includes terminating null + + char* copy = nullptr; + if (usage == JIT_CONTEXT_GLOBAL) { + copy = (char*)MOT::MemGlobalAlloc(len); + } else { + copy = (char*)MOT::MemSessionAlloc(len); + } + if (copy == nullptr) { + MOT_REPORT_ERROR( + MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum string constant", (unsigned)len); + return false; + } + + errno_t erc = memcpy_s(copy, len, src, len); + securec_check(erc, "\0", "\0"); + + *target = PointerGetDatum(copy); + return true; +} + +extern bool CloneDatum(Datum source, int type, Datum* target, JitContextUsage usage) +{ + bool result = true; + if (IsStringType(type)) { + result = CloneStringDatum(source, target, usage); + } else { + switch (type) { + case TIMETZOID: + result = CloneTimeTzDatum(source, target, usage); + break; + + case INTERVALOID: + result = CloneIntervalDatum(source, target, usage); + break; + + case TINTERVALOID: + result = CloneTIntervalDatum(source, target, usage); + break; + + case NUMERICOID: + result = CloneNumericDatum(source, target, usage); + break; + + case UNKNOWNOID: + result = CloneCStringDatum(source, target, usage); + break; + + default: + MOT_LOG_TRACE("Unsupported non-primitive constant type: %d", type); + result = false; + break; + } + } + + return result; +} + +extern bool PrepareDatumArray(Const* constArray, uint32_t constCount, JitDatumArray* datumArray) +{ + size_t allocSize = sizeof(JitDatum) * constCount; + JitDatum* datums = (JitDatum*)MOT::MemGlobalAlloc(allocSize); + if (datums == nullptr) { + MOT_REPORT_ERROR( + MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for constant datum array", (unsigned)allocSize); + return false; + } + + for (uint32_t i = 0; i < constCount; ++i) { + Const* constValue = &constArray[i]; + datums[i].m_isNull = constValue->constisnull; + datums[i].m_type = constValue->consttype; + if (!datums[i].m_isNull) { + if (IsPrimitiveType(constValue->constvalue)) { + datums[i].m_datum = constValue->constvalue; + } else { + if (!CloneDatum( + constValue->constvalue, constValue->consttype, &datums[i].m_datum, JIT_CONTEXT_GLOBAL)) { + MOT_LOG_TRACE("Failed to prepare datum value"); + for (uint32_t j = 0; j < i; ++j) { + if (!datums[j].m_isNull && !IsPrimitiveType(datums[j].m_type)) { + MOT::MemGlobalFree(DatumGetPointer(datums[j].m_datum)); + } + } + MOT::MemGlobalFree(datums); + return false; + } + } + } + } + + datumArray->m_datumCount = constCount; + datumArray->m_datums = datums; + return true; +} } // namespace JitExec diff --git a/src/gausskernel/storage/mot/jit_exec/jit_common.h b/src/gausskernel/storage/mot/jit_exec/jit_common.h index 128aecca3..920397f26 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_common.h +++ b/src/gausskernel/storage/mot/jit_exec/jit_common.h @@ -31,6 +31,9 @@ #include "mot_engine.h" +/** @define The maximum number of constant objects that can be used in a query. */ +#define MOT_JIT_MAX_CONST 1024 + // This file contains definitions used both by LLVM and TVM jitted code namespace JitExec { // forward declaration @@ -64,6 +67,12 @@ extern int BuildIndexColumnOffsets(MOT::Table* table, const MOT::Index* index, i /** @brief Queries whether a PG type is supported by MOT tables. */ extern bool IsTypeSupported(int resultType); +/** @brief Queries whether a PG type represents a string. */ +extern bool IsStringType(int type); + +/** @brief Queries whether a PG type represents a primitive type. */ +extern bool IsPrimitiveType(int type); + /** @brief Queries whether a WHERE clause operator is supported. */ extern bool IsWhereOperatorSupported(int whereOp); @@ -125,6 +134,12 @@ extern void DestroyTableInfo(TableInfo* table_info); * @return True if operations succeeded, otherwise false. */ extern bool PrepareSubQueryData(JitContext* jitContext, JitCompoundPlan* plan); + +/** @brief Prepares array of global datum objects from array of constants. */ +extern bool PrepareDatumArray(Const* constArray, uint32_t constCount, JitDatumArray* datumArray); + +/** @brief Clones an interval datum into global memory. */ +extern bool CloneDatum(Datum source, int type, Datum* target, JitContextUsage usage); } // namespace JitExec #endif diff --git a/src/gausskernel/storage/mot/jit_exec/jit_context.cpp b/src/gausskernel/storage/mot/jit_exec/jit_context.cpp index e3366da42..da0a868d8 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_context.cpp +++ b/src/gausskernel/storage/mot/jit_exec/jit_context.cpp @@ -97,6 +97,63 @@ extern void FreeJitContext(JitContext* jitContext) } } +static bool CloneDatumArray(JitDatumArray* source, JitDatumArray* target, JitContextUsage usage) +{ + uint32_t datumCount = source->m_datumCount; + if (datumCount == 0) { + target->m_datumCount = 0; + target->m_datums = nullptr; + return true; + } + + size_t allocSize = sizeof(JitDatum) * datumCount; + JitDatum* datumArray = nullptr; + if (usage == JIT_CONTEXT_GLOBAL) { + datumArray = (JitDatum*)MOT::MemGlobalAlloc(allocSize); + } else { + datumArray = (JitDatum*)MOT::MemSessionAlloc(allocSize); + } + if (datumArray == nullptr) { + MOT_REPORT_ERROR( + MOT_ERROR_OOM, "JIT Compile", "Failed to allocate %u bytes for datum array", (unsigned)allocSize); + return false; + } + + for (uint32_t i = 0; i < datumCount; ++i) { + JitDatum* datum = (JitDatum*)&source->m_datums[i]; + datumArray[i].m_isNull = datum->m_isNull; + datumArray[i].m_type = datum->m_type; + if (!datum->m_isNull) { + if (IsPrimitiveType(datum->m_type)) { + datumArray[i].m_datum = datum->m_datum; + } else { + if (!CloneDatum(datum->m_datum, datum->m_type, &datumArray[i].m_datum, usage)) { + MOT_REPORT_ERROR(MOT_ERROR_OOM, "JIT Compile", "Failed to clone datum array entry"); + for (uint32_t j = 0; j < i; ++j) { + if (!IsPrimitiveType(datumArray[j].m_type)) { + if (usage == JIT_CONTEXT_GLOBAL) { + MOT::MemGlobalFree(DatumGetPointer(datumArray[j].m_datum)); + } else { + MOT::MemGlobalFree(DatumGetPointer(datumArray[j].m_datum)); + } + } + } + if (usage == JIT_CONTEXT_GLOBAL) { + MOT::MemGlobalFree(datumArray); + } else { + MOT::MemSessionFree(datumArray); + } + return false; + } + } + } + } + + target->m_datums = datumArray; + target->m_datumCount = datumCount; + return true; +} + extern JitContext* CloneJitContext(JitContext* sourceJitContext) { MOT_LOG_TRACE("Cloning JIT context %p of query: %s", sourceJitContext, sourceJitContext->m_queryString); @@ -109,6 +166,11 @@ extern JitContext* CloneJitContext(JitContext* sourceJitContext) result->m_llvmFunction = sourceJitContext->m_llvmFunction; result->m_tvmFunction = sourceJitContext->m_tvmFunction; result->m_commandType = sourceJitContext->m_commandType; + if (!CloneDatumArray(&sourceJitContext->m_constDatums, &result->m_constDatums, JIT_CONTEXT_LOCAL)) { + MOT_REPORT_ERROR(MOT_ERROR_OOM, "JIT Compile", "Failed to clone constant datum array"); + DestroyJitContext(result); + return nullptr; + } result->m_table = sourceJitContext->m_table; result->m_index = sourceJitContext->m_index; result->m_indexId = sourceJitContext->m_indexId; @@ -492,6 +554,29 @@ extern bool PrepareJitContext(JitContext* jitContext) return true; } +static void DestroyDatumArray(JitDatumArray* datumArray, JitContextUsage usage) +{ + if (datumArray->m_datumCount > 0) { + MOT_ASSERT(datumArray->m_datums != nullptr); + for (uint32_t i = 0; i < datumArray->m_datumCount; ++i) { + if (!datumArray->m_datums[i].m_isNull && !IsPrimitiveType(datumArray->m_datums[i].m_type)) { + if (usage == JIT_CONTEXT_GLOBAL) { + MOT::MemGlobalFree(DatumGetPointer(datumArray->m_datums[i].m_datum)); + } else { + MOT::MemSessionFree(DatumGetPointer(datumArray->m_datums[i].m_datum)); + } + } + } + if (usage == JIT_CONTEXT_GLOBAL) { + MOT::MemGlobalFree(datumArray->m_datums); + } else { + MOT::MemSessionFree(datumArray->m_datums); + } + datumArray->m_datums = nullptr; + datumArray->m_datumCount = 0; + } +} + extern void DestroyJitContext(JitContext* jitContext) { if (jitContext != nullptr) { @@ -513,6 +598,9 @@ extern void DestroyJitContext(JitContext* jitContext) jitContext->m_jitSource = nullptr; } + // cleanup constant datum array + DestroyDatumArray(&jitContext->m_constDatums, jitContext->m_usage); + // cleanup sub-query data array CleanupJitContextSubQueryDataArray(jitContext); diff --git a/src/gausskernel/storage/mot/jit_exec/jit_context.h b/src/gausskernel/storage/mot/jit_exec/jit_context.h index de35beffc..c9b2e6613 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_context.h +++ b/src/gausskernel/storage/mot/jit_exec/jit_context.h @@ -34,6 +34,26 @@ namespace JitExec { struct JitContextPool; struct JitSource; +/** @struct Array of constant datum objects used in JIT execution. */ +struct JitDatum { + /** @var The constant value. */ + Datum m_datum; + + /** @var The constant type. */ + int m_type; + + /** @var The constant is-null property. */ + int m_isNull; +}; + +struct JitDatumArray { + /** @var The number of constant datum objects used by the jitted function (global copy for all contexts). */ + uint64_t m_datumCount; + + /** @var The array of constant datum objects used by the jitted function (global copy for all contexts). */ + JitDatum* m_datums; +}; + /** * @typedef The context for executing a jitted function. */ @@ -97,6 +117,9 @@ struct JitContext { /** @var The source query string. */ const char* m_queryString; // L1 offset 40 (constant) + /** @var The array of constant datum objects used by the jitted function (global copy for all contexts). */ + JitDatumArray m_constDatums; + /*---------------------- Range Scan execution state -------------------*/ /** @var Begin iterator for range select (stateful execution). */ MOT::IndexIterator* m_beginIterator; // L1 offset 48 diff --git a/src/gausskernel/storage/mot/jit_exec/jit_helpers.cpp b/src/gausskernel/storage/mot/jit_exec/jit_helpers.cpp index edc89ae9d..2aa3fd7a3 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_helpers.cpp +++ b/src/gausskernel/storage/mot/jit_exec/jit_helpers.cpp @@ -392,6 +392,22 @@ int getExprArgIsNull(int arg_pos) return result; } +Datum GetConstAt(int constId, int argPos) +{ + MOT_LOG_DEBUG("Retrieving constant datum by id %d", constId); + Datum result = PointerGetDatum(nullptr); + JitExec::JitContext* ctx = u_sess->mot_cxt.jit_context; + if (constId < (int)ctx->m_constDatums.m_datumCount) { + JitExec::JitDatum* datum = &ctx->m_constDatums.m_datums[constId]; + result = datum->m_datum; + setExprArgIsNull(argPos, datum->m_isNull); + DBG_PRINT_DATUM("Retrieved constant datum", datum->m_type, datum->m_datum, datum->m_isNull); + } else { + MOT_LOG_ERROR("Invalid constant identifier: %d", constId); + } + return result; +} + Datum getDatumParam(ParamListInfo params, int paramid, int arg_pos) { MOT_LOG_DEBUG("Retrieving datum param at index %d", paramid); diff --git a/src/gausskernel/storage/mot/jit_exec/jit_helpers.h b/src/gausskernel/storage/mot/jit_exec/jit_helpers.h index 670bf8739..543fb249a 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_helpers.h +++ b/src/gausskernel/storage/mot/jit_exec/jit_helpers.h @@ -96,6 +96,14 @@ void setExprArgIsNull(int arg_pos, int isnull); */ int getExprArgIsNull(int arg_pos); +/** + * @brief Retrieves a pooled constant by its identifier. + * @param constId The identifier of the constant value. + * @param argPos The ordinal position of the enveloping parameter expression. + * @return The constant value. + */ +Datum GetConstAt(int constId, int argPos); + /** * @brief Retrieves a datum parameter from parameters array. * @param params The parameter array. diff --git a/src/gausskernel/storage/mot/jit_exec/jit_llvm_blocks.cpp b/src/gausskernel/storage/mot/jit_exec/jit_llvm_blocks.cpp index eac5706fd..b3e91d387 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_llvm_blocks.cpp +++ b/src/gausskernel/storage/mot/jit_exec/jit_llvm_blocks.cpp @@ -454,10 +454,21 @@ static llvm::Value* ProcessConstExpr( if (IsTypeSupported(const_value->consttype)) { result_type = const_value->consttype; AddSetExprArgIsNull(ctx, arg_pos, const_value->constisnull); // mark expression null status - result = llvm::ConstantInt::get(ctx->INT64_T, const_value->constvalue, true); + if (IsPrimitiveType(result_type)) { + result = llvm::ConstantInt::get(ctx->INT64_T, const_value->constvalue, true); + } else { + int constId = AllocateConstId(ctx, result_type, const_value->constvalue, const_value->constisnull); + if (constId == -1) { + MOT_LOG_TRACE("Failed to allocate constant identifier"); + } else { + result = AddGetConstAt(ctx, constId, arg_pos); + } + } if (max_arg && (arg_pos > *max_arg)) { *max_arg = arg_pos; } + } else { + MOT_LOG_TRACE("Failed to process const expression: type %d unsupported", (int)result_type); } MOT_LOG_DEBUG("%*s <-- Processing CONST expression result: %p", depth, "", result); @@ -769,8 +780,18 @@ static llvm::Value* ProcessExpr( static llvm::Value* ProcessConstExpr(JitLlvmCodeGenContext* ctx, const JitConstExpr* expr, int* max_arg) { + llvm::Value* result = nullptr; AddSetExprArgIsNull(ctx, expr->_arg_pos, expr->_is_null); // mark expression null status - llvm::Value* result = llvm::ConstantInt::get(ctx->INT64_T, expr->_value, true); + if (IsPrimitiveType(expr->_const_type)) { + result = llvm::ConstantInt::get(ctx->INT64_T, expr->_value, true); + } else { + int constId = AllocateConstId(ctx, expr->_const_type, expr->_value, expr->_is_null); + if (constId == -1) { + MOT_LOG_TRACE("Failed to allocate constant identifier"); + } else { + result = AddGetConstAt(ctx, constId, expr->_arg_pos); + } + } if (max_arg && (expr->_arg_pos > *max_arg)) { *max_arg = expr->_arg_pos; } diff --git a/src/gausskernel/storage/mot/jit_exec/jit_llvm_funcs.h b/src/gausskernel/storage/mot/jit_exec/jit_llvm_funcs.h index 101fc0b25..7b272af90 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_llvm_funcs.h +++ b/src/gausskernel/storage/mot/jit_exec/jit_llvm_funcs.h @@ -632,6 +632,11 @@ inline void DefineGetSubQueryEndIteratorKey(JitLlvmCodeGenContext* ctx, llvm::Mo defineFunction(module, ctx->KeyType->getPointerTo(), "GetSubQueryEndIteratorKey", ctx->INT32_T, nullptr); } +inline void DefineGetConstAt(JitLlvmCodeGenContext* ctx, llvm::Module* module) +{ + ctx->GetConstAtFunc = defineFunction(module, ctx->DATUM_T, "GetConstAt", ctx->INT32_T, ctx->INT32_T, nullptr); +} + /*--------------------------- End of LLVM Helper Prototypes ---------------------------*/ /*--------------------------- Helpers to generate calls to Helper function via LLVM ---------------------------*/ @@ -1326,6 +1331,13 @@ inline llvm::Value* AddGetSubQueryEndIteratorKey(JitLlvmCodeGenContext* ctx, int return AddFunctionCall(ctx, ctx->GetSubQueryEndIteratorKeyFunc, subQueryIndexValue, nullptr); } +inline llvm::Value* AddGetConstAt(JitLlvmCodeGenContext* ctx, int constId, int argPos) +{ + llvm::ConstantInt* constIdValue = llvm::ConstantInt::get(ctx->INT32_T, constId, true); + llvm::ConstantInt* argPosValue = llvm::ConstantInt::get(ctx->INT32_T, argPos, true); + return AddFunctionCall(ctx, ctx->GetConstAtFunc, constIdValue, argPosValue, nullptr); +} + /** @brief Adds a call to issueDebugLog(function, msg). */ #ifdef MOT_JIT_DEBUG inline void IssueDebugLogImpl(JitLlvmCodeGenContext* ctx, const char* function, const char* msg) diff --git a/src/gausskernel/storage/mot/jit_exec/jit_llvm_query.h b/src/gausskernel/storage/mot/jit_exec/jit_llvm_query.h index 8ef7e1c54..6f996fcc2 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_llvm_query.h +++ b/src/gausskernel/storage/mot/jit_exec/jit_llvm_query.h @@ -171,6 +171,7 @@ struct JitLlvmCodeGenContext { llvm::FunctionCallee GetSubQueryIndexFunc; llvm::FunctionCallee GetSubQuerySearchKeyFunc; llvm::FunctionCallee GetSubQueryEndIteratorKeyFunc; + llvm::FunctionCallee GetConstAtFunc; // builtins #define APPLY_UNARY_OPERATOR(funcid, name) llvm::FunctionCallee _builtin_##name; @@ -224,10 +225,16 @@ struct JitLlvmCodeGenContext { TableInfo _inner_table_info; TableInfo* m_subQueryTableInfo; + // non-primitive constants + uint32_t m_constCount; + Const* m_constValues; + dorado::GsCodeGen* _code_gen; dorado::GsCodeGen::LlvmBuilder* _builder; llvm::Function* m_jittedQuery; }; + +extern int AllocateConstId(JitLlvmCodeGenContext* ctx, int type, Datum value, bool isNull); } // namespace JitExec #endif /* JIT_LLVM_QUERY_H */ diff --git a/src/gausskernel/storage/mot/jit_exec/jit_llvm_query_codegen.cpp b/src/gausskernel/storage/mot/jit_exec/jit_llvm_query_codegen.cpp index d931d30c2..465debd7f 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_llvm_query_codegen.cpp +++ b/src/gausskernel/storage/mot/jit_exec/jit_llvm_query_codegen.cpp @@ -142,6 +142,7 @@ void InitCodeGenContextFuncs(JitLlvmCodeGenContext* ctx) DefineGetSubQueryIndex(ctx, module); DefineGetSubQuerySearchKey(ctx, module); DefineGetSubQueryEndIteratorKey(ctx, module); + DefineGetConstAt(ctx, module); } #define APPLY_UNARY_OPERATOR(funcid, name) \ @@ -258,6 +259,18 @@ static bool InitCodeGenContext(JitLlvmCodeGenContext* ctx, GsCodeGen* code_gen, return false; } + ctx->m_constCount = 0; + size_t allocSize = sizeof(Const) * MOT_JIT_MAX_CONST; + ctx->m_constValues = (Const*)MOT::MemSessionAlloc(allocSize); + if (ctx->m_constValues == nullptr) { + MOT_REPORT_ERROR(MOT_ERROR_OOM, + "JIT Compile", + "Failed to allocate %u bytes for constant array in code-generation context", + allocSize); + DestroyCodeGenContext(ctx); + return false; + } + InitCodeGenContextTypes(ctx); InitCodeGenContextFuncs(ctx); InitCodeGenContextBuiltins(ctx); @@ -351,6 +364,9 @@ static void DestroyCodeGenContext(JitLlvmCodeGenContext* ctx) MOT::MemSessionFree(ctx->m_subQueryTableInfo); ctx->m_subQueryTableInfo = nullptr; } + if (ctx->m_constValues != nullptr) { + MOT::MemSessionFree(ctx->m_constValues); + } if (ctx->_code_gen != nullptr) { ctx->_code_gen->releaseResource(); delete ctx->_code_gen; @@ -358,6 +374,24 @@ static void DestroyCodeGenContext(JitLlvmCodeGenContext* ctx) } } +extern int AllocateConstId(JitLlvmCodeGenContext* ctx, int type, Datum value, bool isNull) +{ + int res = -1; + if (ctx->m_constCount == MOT_JIT_MAX_CONST) { + MOT_REPORT_ERROR(MOT_ERROR_RESOURCE_LIMIT, + "JIT Compile", + "Cannot allocate constant identifier, reached limit of %u", + ctx->m_constCount); + } else { + res = ctx->m_constCount++; + ctx->m_constValues[res].consttype = type; + ctx->m_constValues[res].constvalue = value; + ctx->m_constValues[res].constisnull = isNull; + MOT_LOG_TRACE("Allocated constant id: %d", res); + } + return res; +} + /** @brief Wraps up an LLVM function (compiles it and prepares a funciton pointer). */ static JitContext* FinalizeCodegen(JitLlvmCodeGenContext* ctx, int max_arg, JitCommandType command_type) { @@ -381,6 +415,15 @@ static JitContext* FinalizeCodegen(JitLlvmCodeGenContext* ctx, int max_arg, JitC #endif } + // prepare global constant array + JitDatumArray datumArray = {}; + if (ctx->m_constCount > 0) { + if (!PrepareDatumArray(ctx->m_constValues, ctx->m_constCount, &datumArray)) { + MOT_LOG_ERROR("Failed to generate jitted code for query: Failed to prepare constant datum array"); + return nullptr; + } + } + // that's it, we are ready JitContext* jit_context = AllocJitContext(JIT_CONTEXT_GLOBAL); if (jit_context == nullptr) { @@ -412,6 +455,8 @@ static JitContext* FinalizeCodegen(JitLlvmCodeGenContext* ctx, int max_arg, JitC MOT_LOG_TRACE("Installed inner index id: %" PRIu64, jit_context->m_innerIndexId); } jit_context->m_commandType = command_type; + jit_context->m_constDatums.m_datumCount = datumArray.m_datumCount; + jit_context->m_constDatums.m_datums = datumArray.m_datums; return jit_context; } diff --git a/src/gausskernel/storage/mot/jit_exec/jit_tvm_blocks.cpp b/src/gausskernel/storage/mot/jit_exec/jit_tvm_blocks.cpp index 2a39ba41f..654fc3ba3 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_tvm_blocks.cpp +++ b/src/gausskernel/storage/mot/jit_exec/jit_tvm_blocks.cpp @@ -424,7 +424,17 @@ static Expression* ProcessConstExpr( if (IsTypeSupported(const_value->consttype)) { result_type = const_value->consttype; - result = new (std::nothrow) ConstExpression(const_value->constvalue, arg_pos, (int)(const_value->constisnull)); + if (IsPrimitiveType(result_type)) { + result = + new (std::nothrow) ConstExpression(const_value->constvalue, arg_pos, (int)(const_value->constisnull)); + } else { + int constId = AllocateConstId(ctx, result_type, const_value->constvalue, const_value->constisnull); + if (constId == -1) { + MOT_LOG_TRACE("Failed to allocate constant identifier"); + } else { + result = AddGetConstAt(ctx, constId, arg_pos); + } + } if (max_arg && (arg_pos > *max_arg)) { *max_arg = arg_pos; } @@ -662,8 +672,18 @@ static Expression* ProcessExpr( static Expression* ProcessConstExpr(JitTvmCodeGenContext* ctx, const JitConstExpr* expr, int* max_arg) { - AddSetExprArgIsNull(ctx, expr->_arg_pos, (expr->_is_null ? 1 : 0)); // mark expression null status - Expression* result = new (std::nothrow) ConstExpression(expr->_value, expr->_arg_pos, (int)(expr->_is_null)); + Expression* result = nullptr; + AddSetExprArgIsNull(ctx, expr->_arg_pos, expr->_is_null); // mark expression null status + if (IsPrimitiveType(expr->_const_type)) { + result = new (std::nothrow) ConstExpression(expr->_value, expr->_arg_pos, (int)(expr->_is_null)); + } else { + int constId = AllocateConstId(ctx, expr->_const_type, expr->_value, expr->_is_null); + if (constId == -1) { + MOT_LOG_TRACE("Failed to allocate constant identifier"); + } else { + result = AddGetConstAt(ctx, constId, expr->_arg_pos); + } + } if (max_arg && (expr->_arg_pos > *max_arg)) { *max_arg = expr->_arg_pos; } diff --git a/src/gausskernel/storage/mot/jit_exec/jit_tvm_funcs.h b/src/gausskernel/storage/mot/jit_exec/jit_tvm_funcs.h index 4cdfb5136..4ddca70f9 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_tvm_funcs.h +++ b/src/gausskernel/storage/mot/jit_exec/jit_tvm_funcs.h @@ -2761,6 +2761,31 @@ private: int m_subQueryIndex; }; +/** @class GetConstAtExpression */ +class GetConstAtExpression : public tvm::Expression { +public: + explicit GetConstAtExpression(int constId, int argPos) + : Expression(tvm::Expression::CanFail), m_constId(constId), m_argPos(argPos) + {} + + ~GetConstAtExpression() final + {} + + Datum eval(tvm::ExecContext* exec_context) final + { + return (uint64_t)GetConstAt(m_constId, m_argPos); + } + + void dump() final + { + (void)fprintf(stderr, "GetConstAt(constId=%d, argPos=%d)", m_constId, m_argPos); + } + +private: + int m_constId; + int m_argPos; +}; + inline tvm::Instruction* AddIsSoftMemoryLimitReached(JitTvmCodeGenContext* ctx) { return ctx->_builder->addInstruction(new (std::nothrow) IsSoftMemoryLimitReachedInstruction()); @@ -3148,6 +3173,11 @@ inline void AddCopyAggregateToSubQueryResult(JitTvmCodeGenContext* ctx, int subQ ctx->_builder->addInstruction(new (std::nothrow) CopyAggregateToSubQueryResultInstruction(subQueryIndex)); } +inline tvm::Expression* AddGetConstAt(JitTvmCodeGenContext* ctx, int constId, int argPos) +{ + return new (std::nothrow) GetConstAtExpression(constId, argPos); +} + #ifdef MOT_JIT_DEBUG inline void IssueDebugLogImpl(JitTvmCodeGenContext* ctx, const char* function, const char* msg) { diff --git a/src/gausskernel/storage/mot/jit_exec/jit_tvm_query.h b/src/gausskernel/storage/mot/jit_exec/jit_tvm_query.h index 344352aa6..f9d17567d 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_tvm_query.h +++ b/src/gausskernel/storage/mot/jit_exec/jit_tvm_query.h @@ -61,7 +61,13 @@ struct JitTvmCodeGenContext { /** @var The resulting jitted function. */ tvm::Function* m_jittedQuery; + + // non-primitive constants + uint32_t m_constCount; + Const* m_constValues; }; + +extern int AllocateConstId(JitTvmCodeGenContext* ctx, int type, Datum value, bool isNull); } // namespace JitExec #endif /* JIT_TVM_QUERY_H */ diff --git a/src/gausskernel/storage/mot/jit_exec/jit_tvm_query_codegen.cpp b/src/gausskernel/storage/mot/jit_exec/jit_tvm_query_codegen.cpp index a83ed313f..974cab4c5 100644 --- a/src/gausskernel/storage/mot/jit_exec/jit_tvm_query_codegen.cpp +++ b/src/gausskernel/storage/mot/jit_exec/jit_tvm_query_codegen.cpp @@ -39,6 +39,8 @@ using namespace tvm; namespace JitExec { DECLARE_LOGGER(JitTvmQueryCodegen, JitExec) +static void DestroyCodeGenContext(JitTvmCodeGenContext* ctx); + /** @brief Initializes a context for compilation. */ static bool InitCodeGenContext(JitTvmCodeGenContext* ctx, Builder* builder, MOT::Table* table, MOT::Index* index, MOT::Table* inner_table = nullptr, MOT::Index* inner_index = nullptr) @@ -59,6 +61,18 @@ static bool InitCodeGenContext(JitTvmCodeGenContext* ctx, Builder* builder, MOT: return false; } + ctx->m_constCount = 0; + size_t allocSize = sizeof(Const) * MOT_JIT_MAX_CONST; + ctx->m_constValues = (Const*)MOT::MemSessionAlloc(allocSize); + if (ctx->m_constValues == nullptr) { + MOT_REPORT_ERROR(MOT_ERROR_OOM, + "JIT Compile", + "Failed to allocate %u bytes for constant array in code-generation context", + allocSize); + DestroyCodeGenContext(ctx); + return false; + } + return true; } @@ -130,9 +144,29 @@ static void DestroyCodeGenContext(JitTvmCodeGenContext* ctx) for (uint32_t i = 0; i < ctx->m_subQueryCount; ++i) { DestroyTableInfo(&ctx->m_subQueryTableInfo[i]); } + if (ctx->m_constValues != nullptr) { + MOT::MemSessionFree(ctx->m_constValues); + } } } +extern int AllocateConstId(JitTvmCodeGenContext* ctx, int type, Datum value, bool isNull) +{ + int res = -1; + if (ctx->m_constCount == MOT_JIT_MAX_CONST) { + MOT_REPORT_ERROR(MOT_ERROR_RESOURCE_LIMIT, + "JIT Compile", + "Cannot allocate constant identifier, reached limit of %u", + ctx->m_constCount); + } else { + res = ctx->m_constCount++; + ctx->m_constValues[res].consttype = type; + ctx->m_constValues[res].constvalue = value; + ctx->m_constValues[res].constisnull = isNull; + } + return res; +} + static JitContext* FinalizeCodegen(JitTvmCodeGenContext* ctx, int max_arg, JitCommandType command_type) { // do minimal verification and wrap up @@ -155,6 +189,16 @@ static JitContext* FinalizeCodegen(JitTvmCodeGenContext* ctx, int max_arg, JitCo return nullptr; } + // prepare global constant array + JitDatumArray datumArray = {}; + if (ctx->m_constCount > 0) { + if (!PrepareDatumArray(ctx->m_constValues, ctx->m_constCount, &datumArray)) { + MOT_LOG_ERROR("Failed to generate jitted code for query: Failed to prepare constant datum array"); + delete ctx->m_jittedQuery; + return nullptr; + } + } + // that's it, we are ready JitContext* jit_context = AllocJitContext(JIT_CONTEXT_GLOBAL); if (jit_context == nullptr) { @@ -178,6 +222,8 @@ static JitContext* FinalizeCodegen(JitTvmCodeGenContext* ctx, int max_arg, JitCo } jit_context->m_commandType = command_type; jit_context->m_subQueryCount = 0; + jit_context->m_constDatums.m_datumCount = datumArray.m_datumCount; + jit_context->m_constDatums.m_datums = datumArray.m_datums; return jit_context; }