diff --git a/contrib/dblink/dblink.cpp b/contrib/dblink/dblink.cpp index 5fce08f26..195012503 100644 --- a/contrib/dblink/dblink.cpp +++ b/contrib/dblink/dblink.cpp @@ -55,11 +55,10 @@ #include "access/heapam.h" #include "commands/extension.h" #include "dblink.h" +#include "storage/ipc.h" PG_MODULE_MAGIC; - - /* * Internal declarations */ @@ -183,6 +182,26 @@ static bool UseODBCLinker(char* connstr); } \ } while (0) +static void DblinkQuitAndClean(int code, Datum arg) +{ + if (PCONN->linker != NULL) { + PCONN->linker->finish(); + PCONN->linker = NULL; + } + + HASH_SEQ_STATUS status; + remoteConnHashEnt* hentry = NULL; + + if (REMOTE_CONN_HASH) { + hash_seq_init(&status, REMOTE_CONN_HASH); + while ((hentry = (remoteConnHashEnt*)hash_seq_search(&status)) != NULL) { + hentry->rconn->linker->finish(); + } + hash_destroy(REMOTE_CONN_HASH); + REMOTE_CONN_HASH = NULL; + } +} + void set_extension_index(uint32 index) { dblink_index = index; @@ -198,6 +217,7 @@ void init_session_vars(void) psc->pconn = NULL; psc->remoteConnHash = NULL; + psc->needFree = TRUE; } dblink_session_context* get_session_context() @@ -208,6 +228,15 @@ dblink_session_context* get_session_context() return (dblink_session_context*)u_sess->attr.attr_common.extension_session_vars_array[dblink_index]; } +Linker::Linker() +{ + if (ENABLE_THREAD_POOL) { + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("dblink not support in thread pool"))); + } +} + PQLinker::PQLinker(char* connstr) { this->conn = NULL; @@ -541,7 +570,12 @@ void PQLinker::getNotify(ReturnSetInfo* rsinfo) ODBCLinker::ODBCLinker(char* connstr_or_name) { + this->msg = (char*)MemoryContextAlloc(SESS_GET_MEM_CXT_GROUP + (MEMORY_CONTEXT_COMMUNICATION), sizeof(char*) * MAX_ERR_MSG_LEN); + errno_t rc = strcpy_s(this->msg, MAX_ERR_MSG_LEN, "no error message"); + securec_check(rc, "\0", "\0"); SQLINTEGER error = 0; + error = SQLAllocHandle(SQL_HANDLE_ENV,SQL_NULL_HANDLE, &this->envHandle); if ((error != SQL_SUCCESS) && (error != SQL_SUCCESS_WITH_INFO)) { ereport(ERROR, @@ -568,7 +602,7 @@ ODBCLinker::ODBCLinker(char* connstr_or_name) /* atuo commit is the default value */ error = SQLConnect(this->connHandle, linfo.drivername, SQL_NTS, linfo.username, SQL_NTS, linfo.password, SQL_NTS); - errno_t rc = memset_s(connstr_or_name, len, 0, len); + rc = memset_s(connstr_or_name, len, 0, len); securec_check(rc, "\0", "\0"); if ((error != SQL_SUCCESS) && (error != SQL_SUCCESS_WITH_INFO)) { @@ -621,23 +655,17 @@ text* ODBCLinker::exec(char* conname, const char* sql, bool fail) } this->stmt = stmt; if ((error != SQL_SUCCESS) && (error != SQL_SUCCESS_WITH_INFO)) { - char* msg = this->errorMsg(); + SQLError(this->envHandle, this->connHandle, this->stmt, NULL, NULL, (SQLCHAR*)this->msg, MAX_ERR_MSG_LEN, NULL); ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), - errmsg("Error exec\n%s", msg))); + errmsg("Error exec\n%s", this->msg))); } - return cstring_to_text("OK"); } char* ODBCLinker::errorMsg() { - if (this->stmt == NULL) { - return NULL; - } - char* msg = (char*)palloc(sizeof(char) * MAX_ERR_MSG_LEN); - SQLError(this->envHandle, this->connHandle, this->stmt, NULL, NULL, (SQLCHAR*)msg, MAX_ERR_MSG_LEN, NULL); - return msg; + return this->msg; } int ODBCLinker::isBusy() @@ -681,12 +709,11 @@ int ODBCLinker::sendQuery(char *sql) } this->stmt = stmt; if ((error != SQL_SUCCESS) && (error != SQL_SUCCESS_WITH_INFO)) { - char* msg = this->errorMsg(); + SQLError(this->envHandle, this->connHandle, this->stmt, NULL, NULL, (SQLCHAR*)this->msg, MAX_ERR_MSG_LEN, NULL); ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), - errmsg("Error exec\n%s", msg))); + errmsg("Error exec\n%s", this->msg))); } - return 1; } @@ -705,15 +732,11 @@ char* ODBCLinker::close(char* conname, char* sql, bool fail) void ODBCLinker::getResult(char* conname, FunctionCallInfo fcinfo, char* sql, bool fail) { prepTuplestoreResult(fcinfo); - ReturnSetInfo* rsinfo = (ReturnSetInfo*)fcinfo->resultinfo; storeInfo sinfo; bool isFirst = true; SQLINTEGER error = 0; SQLSMALLINT nfields = 0; - /* prepTuplestoreResult must have been called previously */ - Assert(rsinfo->returnMode == SFRM_Materialize); - /* initialize storeInfo to empty */ (void)memset_s(&sinfo, sizeof(sinfo), 0, sizeof(sinfo)); sinfo.fcinfo = fcinfo; @@ -801,6 +824,11 @@ Datum dblink_connect(PG_FUNCTION_ARGS) DBLINK_INIT; + if (get_session_context()->needFree) { + on_proc_exit(DblinkQuitAndClean, 0); + get_session_context()->needFree = FALSE; + } + if (PG_NARGS() == 2) { conname_or_str = text_to_cstring(PG_GETARG_TEXT_PP(1)); conname = text_to_cstring(PG_GETARG_TEXT_PP(0)); @@ -830,7 +858,7 @@ Datum dblink_connect(PG_FUNCTION_ARGS) if (PCONN->linker) { PCONN->linker->finish(); } - PCONN->linker = olinker; + PCONN->linker = olinker; } } else { /* first check for valid foreign data server */ @@ -850,7 +878,7 @@ Datum dblink_connect(PG_FUNCTION_ARGS) if (PCONN->linker) { PCONN->linker->finish(); } - PCONN->linker = plinker; + PCONN->linker = plinker; } } PG_RETURN_TEXT_P(cstring_to_text("OK")); diff --git a/contrib/dblink/dblink.h b/contrib/dblink/dblink.h index aa7490772..cbc4b8e79 100644 --- a/contrib/dblink/dblink.h +++ b/contrib/dblink/dblink.h @@ -64,6 +64,7 @@ typedef struct LinkInfo { class Linker : public BaseObject { public: + Linker(); virtual void finish() = 0; virtual text* exec(char* conname, const char* sql, bool fail) = 0; virtual char* errorMsg() = 0; @@ -85,6 +86,7 @@ typedef struct remoteConn { typedef struct dblink_session_context { remoteConn* pconn; HTAB* remoteConnHash; + bool needFree; } dblink_session_context; /* @@ -121,12 +123,12 @@ public: void getNotify(ReturnSetInfo* rsinfo); }; - class ODBCLinker : public Linker { public: SQLHENV envHandle; /* Handle ODBC environment */ SQLHDBC connHandle; /* Handle connection */ SQLHSTMT stmt; /* Handle sql */ + char* msg; /* error message */ public: ODBCLinker(char* connstr_or_name); void finish();