/* * Copyright (c) 2018 MariaDB Corporation Ab * * Use of this software is governed by the Business Source License included * in the LICENSE.TXT file and at www.mariadb.com/bsl11. * * Change Date: 2024-10-14 * * On the date above, in accordance with the Business Source License, use * of this software will be governed by version 2 or later of the General * Public License. */ #pragma once #include #include #include #include #include #include #include #include #include namespace maxsql { /** * @class LEncInt * * @c LEncInt is a thin wrapper around a MySQL length encoded integer * that makes it simple to extract length encoded integers from packets. */ class LEncInt { public: /** * Constructor * * @param pData Pointer to the beginning of an length encoded integer. */ LEncInt(uint8_t* pData) { m_value = leint_value(pData); } /** * Constructor * * @param pData Pointer to a pointer to the beginning of an length * encoded integer. After the call, the pointer will be advanced * to point at the byte following the length encoded integer. */ LEncInt(uint8_t** ppData) { size_t nBytes = leint_bytes(*ppData); m_value = leint_value(*ppData); *ppData += nBytes; } /** * @return The value of the length encoded integer. */ uint64_t value() const { return m_value; } /** * @return The value of the length encoded integer. */ operator uint64_t() const { return value(); } /** * Write the integer to an @c std::ostream. * * @param out The stream. * * @return The stream provided as argument. */ std::ostream& print(std::ostream& out) const { out << m_value; return out; } private: uint64_t m_value; }; /** * Stream the integer to an @c std::ostream. * * @param out A stream. * @param i A length encoded integer. * * @return The stream provided as argument. 
*/ inline std::ostream& operator<<(std::ostream& out, const LEncInt& i) { return i.print(out); } /** * @class LEncString * * @c LEncString is a thin wrapper around a MySQL length encoded string * that makes it simpler to use length encoded strings in conjunction with * @c char* and @c std::string strings. */ class LEncString { public: /** * @class iterator * * A _random access iterator_ to a @c LEncString. */ class iterator : public std::iterator { public: iterator(char* pS) : m_pS(pS) { } iterator& operator++() { mxb_assert(m_pS); ++m_pS; return *this; } iterator operator++(int) { iterator rv(*this); ++(*this); return rv; } iterator operator+(ptrdiff_t n) { mxb_assert(m_pS); iterator rv = m_pS; rv += n; return rv; } iterator& operator+=(ptrdiff_t n) { mxb_assert(m_pS); m_pS += n; return *this; } iterator& operator-=(ptrdiff_t n) { mxb_assert(m_pS); m_pS -= n; return *this; } ptrdiff_t operator-(const iterator& rhs) const { mxb_assert(m_pS); mxb_assert(rhs.m_pS); return m_pS - rhs.m_pS; } bool operator==(const iterator& rhs) const { return m_pS == rhs.m_pS; } bool operator!=(const iterator& rhs) const { return !(*this == rhs); } bool operator<(const iterator& rhs) const { return m_pS < rhs.m_pS; } bool operator<=(const iterator& rhs) const { return m_pS < rhs.m_pS; } bool operator>(const iterator& rhs) const { return m_pS > rhs.m_pS; } bool operator>=(const iterator& rhs) const { return m_pS > rhs.m_pS; } reference operator*() { mxb_assert(m_pS); return *m_pS; } reference operator[](ptrdiff_t i) { mxb_assert(m_pS); return m_pS[i]; } private: char* m_pS; }; /** * Constructor * * @param pData Pointer to the beginning of a length encoded string */ LEncString(uint8_t* pData) { // NULL is sent as 0xfb. 
See https://dev.mysql.com/doc/internals/en/com-query-response.html if (*pData != 0xfb) { m_pString = lestr_consume(&pData, &m_length); } else { m_pString = NULL; m_length = 0; } } /** * Constructor * * @param ppData Pointer to a pointer to the beginning of a length * encoded string. After the call, the pointer will point * one past the end of the length encoded string. */ LEncString(uint8_t** ppData) { // NULL is sent as 0xfb. See https://dev.mysql.com/doc/internals/en/com-query-response.html if (**ppData != 0xfb) { m_pString = lestr_consume(ppData, &m_length); } else { m_pString = NULL; m_length = 0; ++(*ppData); } } /** * Returns an iterator to the beginning of the string. * * @return A random access iterator. */ iterator begin() { return iterator(m_pString); } /** * Returns an iterator one past the end of the string. * * @return A random access iterator. */ iterator end() { return iterator(m_pString + m_length); } /** * @return The length of the string. */ size_t length() const { return m_length; } /** * @return True if the string is empty, false otherwise. */ bool empty() const { return m_length == 0; } /** * Compare for equality. * * @param s The string to compare with. * * @return True, if the strings are equal. */ bool eq(const LEncString& s) const { return m_length == s.m_length ? (memcmp(m_pString, s.m_pString, m_length) == 0) : false; } /** * Compare for equality. * * @param s The string to compare with. * * @return True, if the strings are equal. */ bool eq(const char* zString) const { size_t length = strlen(zString); return m_length == length ? (memcmp(m_pString, zString, m_length) == 0) : false; } /** * Compare for equality. * * @param s The string to compare with. * * @return True, if the strings are equal. */ bool eq(const std::string& s) const { return m_length == s.length() ? (memcmp(m_pString, s.data(), m_length) == 0) : false; } /** * Convert a @c LEncString to the equivalent @c std::string. 
     *
     * @return An @c std::string, or the literal text "NULL" when the value
     *         is an SQL NULL (see is_null()).
     */
    std::string to_string() const
    {
        if (m_pString)
        {
            return std::string(m_pString, m_length);
        }
        else
        {
            return std::string("NULL");
        }
    }

    /**
     * Print the string to a @c ostream.
     *
     * Exactly length() raw bytes are written. The constructors set the
     * length of an SQL NULL to 0, so a NULL prints as nothing (unlike
     * to_string(), which yields "NULL").
     *
     * @param o The @c ostream to print the string to.
     *
     * @return The stream provided as parameter.
     */
    std::ostream& print(std::ostream& o) const
    {
        o.write(m_pString, m_length);
        return o;
    }

    /**
     * Is NULL
     *
     * @return True, if the string represents a NULL value, i.e. the encoded
     *         value started with the 0xfb NULL marker.
     */
    bool is_null() const
    {
        return m_pString == NULL;
    }

private:
    // First character of the string, or NULL for an SQL NULL. Every use in
    // this class pairs it with m_length instead of relying on a terminator.
    char* m_pString;
    // NOTE(review): text was lost from this point in the extracted file. The
    // declaration of m_length (used above), the end of class LEncString and
    // what followed it are missing. The missing text apparently included
    // ComPacket and ComResponse, which are used below but defined nowhere in
    // this file (unless another header provides them), and the head of a
    // stream operator that prints a ComResponse type via a "type_names"
    // table. The stray comment opener on the next line is not closed until
    // the ComRequest doc comment further down, so ComEOF and ComOK below are
    // currently commented out. Restore this region from version control.
    /* type_names = { "Ok", "Err", "Eof", "LocalInfile", "Data" };
    auto ind = static_cast(type);
    return os << ((ind < type_names.size()) ? type_names[ind] : "UNKNOWN");
}

// Wraps an EOF packet (checked with is_eof()) and decodes its payload: the
// warning count followed by the status flags, each a 16-bit little-endian
// value.
class ComEOF : public ComResponse
{
public:
    explicit ComEOF(const ComResponse& response)
        : ComResponse(response)
    {
        mxb_assert(is_eof());
        extract_payload();
    }

    uint16_t warnings() const
    {
        return m_warnings;
    }

    uint16_t status() const
    {
        return m_status;
    }

    bool more_results_exist()
    {
        return m_status & SERVER_MORE_RESULTS_EXIST;
    }

private:
    void extract_payload()
    {
        auto pData = data();
        m_warnings = *pData++;
        m_warnings += (*pData++ << 8);
        m_status = *pData++;
        m_status += (*pData++ << 8);
    }

private:
    uint16_t m_warnings;
    uint16_t m_status;
};

// Wraps an OK packet (checked with is_ok()) and decodes its payload: affected
// rows and last insert id as length encoded integers, then the status flags
// and the warning count as 16-bit little-endian values.
class ComOK : public ComResponse
{
public:
    explicit ComOK(const ComResponse& response)
        : ComResponse(response)
    {
        mxb_assert(is_ok());
        extract_payload();
    }

    uint64_t affected_rows() const
    {
        return m_affected_rows;
    }

    uint64_t last_insert_id() const
    {
        return m_last_insert_id;
    }

    uint16_t warnings() const
    {
        return m_warnings;
    }

    uint16_t status() const
    {
        return m_status;
    }

    bool more_results_exist()
    {
        return m_status & SERVER_MORE_RESULTS_EXIST;
    }

private:
    void extract_payload()
    {
        auto pData = data();
        m_affected_rows = LEncInt(&pData).value();
        m_last_insert_id = LEncInt(&pData).value();
        m_status = *pData++;
        m_status += (*pData++ << 8);
        m_warnings = *pData++;
        m_warnings += (*pData++ << 8);
    }

private:
    uint64_t m_affected_rows;
    uint64_t
m_last_insert_id; uint16_t m_status; uint16_t m_warnings; }; /** * @class ComRequest * * Base-class of all request packet classes. */ class ComRequest : public ComPacket { public: explicit ComRequest(const ComPacket& com_packet) : ComPacket(com_packet) , m_command(*payload()) { } uint8_t* data() { return payload() + 1; } uint8_t command() const { return m_command; } bool server_will_respond() const { return m_command != MXS_COM_STMT_SEND_LONG_DATA // what? && m_command != MXS_COM_QUIT && m_command != MXS_COM_STMT_CLOSE; } private: uint8_t m_command; }; /** * @class CQRColumnDef * * The column definition of the response of a @c ComQuery. * * @attention The name should not be used as such, but always using the * typedef @c ComQueryResponse::ColumnDef. */ class CQRColumnDef : public ComPacket { public: CQRColumnDef(const ComPacket& com_packet) : ComPacket(com_packet) , m_pData(payload()) , m_catalog(&m_pData) , m_schema(&m_pData) , m_table(&m_pData) , m_org_table(&m_pData) , m_name(&m_pData) , m_org_name(&m_pData) , m_length_fixed_fields(&m_pData) { m_character_set = *reinterpret_cast(m_pData); m_pData += 2; m_column_length = *reinterpret_cast(m_pData); m_pData += 4; m_type = static_cast(*m_pData); m_pData += 1; m_flags = *reinterpret_cast(m_pData); m_pData += 2; m_decimals = *m_pData; m_pData += 1; } const LEncString& catalog() const { return m_catalog; } const LEncString& schema() const { return m_schema; } const LEncString& table() const { return m_table; } const LEncString& org_table() const { return m_org_table; } const LEncString& name() const { return m_name; } const LEncString& org_name() const { return m_org_name; } enum_field_types type() const { return m_type; } std::string to_string() const { std::stringstream ss; ss << "\nCatalog : " << m_catalog << "\nSchema : " << m_schema << "\nTable : " << m_table << "\nOrg table : " << m_org_table << "\nName : " << m_name << "\nOrd name : " << m_org_name << "\nCharacer set : " << m_character_set << "\nColumn length: 
" << m_column_length << "\nType : " << (uint16_t)m_type << "\nFlags : " << m_flags << "\nDecimals : " << (uint16_t)m_decimals; return ss.str(); } private: uint8_t* m_pData; LEncString m_catalog; LEncString m_schema; LEncString m_table; LEncString m_org_table; LEncString m_name; LEncString m_org_name; LEncInt m_length_fixed_fields; uint16_t m_character_set; uint32_t m_column_length; enum_field_types m_type; uint16_t m_flags; uint8_t m_decimals; }; /** * @class CQRResultsetValue * * An instance of this class represents a value in a resultset row. As this * currently is for the purpose of the masking filter, it effectively is useful * for accessing NULL and string values. * * @attention The name should not be used as such, but instead either * @c ComQueryResponse::TextResultsetRow::Value or * @c ComQueryResponse::TextResultsetRow::Value. */ class CQRResultsetValue { public: CQRResultsetValue() : m_type(MYSQL_TYPE_NULL) , m_pData(NULL) { } CQRResultsetValue(enum_field_types type, uint8_t* pData) : m_type(type) , m_pData(pData) { } LEncString as_string() { mxb_assert(is_string()); return LEncString(m_pData); } bool is_null() const { return m_type == MYSQL_TYPE_NULL; } bool is_string() const { return is_string(m_type); } static bool is_string(enum_field_types type) { switch (type) { case MYSQL_TYPE_BLOB: case MYSQL_TYPE_LONG_BLOB: case MYSQL_TYPE_MEDIUM_BLOB: case MYSQL_TYPE_STRING: case MYSQL_TYPE_TINY_BLOB: case MYSQL_TYPE_VARCHAR: case MYSQL_TYPE_VAR_STRING: return true; // These, although returned as length-encoded strings, also in the case of // a binary resultset row, are not are not considered to be strings from the // perspective of masking. case MYSQL_TYPE_BIT: case MYSQL_TYPE_DECIMAL: case MYSQL_TYPE_ENUM: case MYSQL_TYPE_GEOMETRY: case MYSQL_TYPE_NEWDECIMAL: case MYSQL_TYPE_SET: return false; default: // Nothing else is considered to be strings even though, in the case of // a textual resultset, that's what they all are. 
return false; } } protected: enum_field_types m_type; private: uint8_t* m_pData; }; /** * @class CQRTextResultsetValue * * An instance of this class represents a value in a textual resultset row. * * @attention The name should not be used as such, but always using the * typedef @c ComQueryResponse::TextResultsetRow::Value. */ class CQRTextResultsetValue : public CQRResultsetValue { public: CQRTextResultsetValue(enum_field_types type, uint8_t* pData) : CQRResultsetValue(type, pData) { if (*pData == 0xfb) { m_type = MYSQL_TYPE_NULL; } } }; /** * @class CQRBinaryResultsetValue * * An instance of this class represents a value in a binary resultset row. * * @attention The name should not be used as such, but always using the * typedef @c ComQueryResponse::BinaryResultsetRow::Value. */ typedef CQRResultsetValue CQRBinaryResultsetValue; /** * @class CQRTextResultsetRowIterator * * An STL compatible iterator that iterates over the values in a textual resultset. * * @attention The name should not be used as such, but always using the * typedef @c ComQueryResponse::TextResultset::iterator. */ class CQRTextResultsetRowIterator : public std::iterator { public: typedef CQRTextResultsetValue Value; CQRTextResultsetRowIterator(uint8_t* pData, const std::vector& types) : m_pData(pData) , m_iTypes(types.begin()) { } CQRTextResultsetRowIterator(uint8_t* pData) : m_pData(pData) { } CQRTextResultsetRowIterator& operator++() { // In the textual protocol, every value is a length encoded string. 
LEncString s(&m_pData); ++m_iTypes; return *this; } CQRTextResultsetRowIterator operator++(int) { CQRTextResultsetRowIterator rv(*this); ++(*this); return rv; } bool operator==(const CQRTextResultsetRowIterator& rhs) const { return m_pData == rhs.m_pData; } bool operator!=(const CQRTextResultsetRowIterator& rhs) const { return !(*this == rhs); } CQRTextResultsetValue operator*() { return Value(*m_iTypes, m_pData); } private: uint8_t* m_pData; std::vector::const_iterator m_iTypes; }; /** * @class CQRBinaryResultsetRowIterator * * An STL compatible iterator that iterates over the values in a binary resultset. * * @attention The name should not be used as such, but always using the * typedef @c ComQueryResponse::BinaryResultset::iterator. */ class CQRBinaryResultsetRowIterator : public std::iterator { public: typedef CQRBinaryResultsetValue Value; /** * A bit_iterator is an iterator to bits in an array of bytes. * * Specifically, it is capable of iterating across the NULL bitmask of * a binary resultset. */ class bit_iterator { public: bit_iterator(uint8_t* pData = 0) : m_pData(pData) , m_mask(1 << 2)// The two first bits are not used. { } /** * @return True, if the current bit is on. That is, if the corresponding * column value is NULL. */ bool operator*() const { return (*m_pData & m_mask) ? true : false; } bit_iterator& operator++() { m_mask <<= 1; // Move to the next bit. if (m_mask == 0) { // We moved past the byte, so advance to next byte and the first bit of that. ++m_pData; m_mask = 1; } return *this; } bit_iterator operator++(int) { bit_iterator rv(*this); ++(*this); return rv; } private: uint8_t* m_pData; /*< Pointer to the NULL bitmap of a binary resultset row. */ uint8_t m_mask; /*< Mask representing the current bit of the current byte. 
*/ }; CQRBinaryResultsetRowIterator(uint8_t* pData, const std::vector& types) : m_pData(pData) , m_iTypes(types.begin()) , m_iNulls(pData + 1) { mxb_assert(*m_pData == 0); ++m_pData; // See https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html size_t nNull_bytes = (types.size() + 7 + 2) / 8; m_pData += nNull_bytes; } CQRBinaryResultsetRowIterator(uint8_t* pData) : m_pData(pData) { } CQRBinaryResultsetRowIterator& operator++() { // See https://dev.mysql.com/doc/internals/en/binary-protocol-value.html switch (*m_iTypes) { case MYSQL_TYPE_BIT: case MYSQL_TYPE_BLOB: case MYSQL_TYPE_DECIMAL: case MYSQL_TYPE_ENUM: case MYSQL_TYPE_GEOMETRY: case MYSQL_TYPE_LONG_BLOB: case MYSQL_TYPE_MEDIUM_BLOB: case MYSQL_TYPE_NEWDATE: case MYSQL_TYPE_NEWDECIMAL: case MYSQL_TYPE_SET: case MYSQL_TYPE_STRING: case MYSQL_TYPE_TINY_BLOB: case MYSQL_TYPE_VARCHAR: case MYSQL_TYPE_VAR_STRING: { LEncString s(&m_pData); // Advance m_pData to the byte following the string. } break; case MYSQL_TYPE_LONGLONG: m_pData += 8; break; case MYSQL_TYPE_LONG: case MYSQL_TYPE_INT24: m_pData += 4; break; case MYSQL_TYPE_SHORT: case MYSQL_TYPE_YEAR: m_pData += 2; break; case MYSQL_TYPE_TINY: m_pData += 1; break; case MYSQL_TYPE_DOUBLE: m_pData += 8; break; case MYSQL_TYPE_FLOAT: m_pData += 4; break; case MYSQL_TYPE_DATE: case MYSQL_TYPE_DATETIME: case MYSQL_TYPE_TIMESTAMP: { // A byte specifying the length, followed by that many bytes. // Either 0, 4, 7 or 11. uint8_t len = *m_pData++; m_pData += len; } break; case MYSQL_TYPE_TIME: { // A byte specifying the length, followed by that many bytes. // Either 0, 8 or 12. 
uint8_t len = *m_pData++; m_pData += len; } break; case MYSQL_TYPE_NULL: break; case MAX_NO_FIELD_TYPES: mxb_assert(!true); break; default: break; } ++m_iNulls; ++m_iTypes; return *this; } CQRBinaryResultsetRowIterator operator++(int) { CQRBinaryResultsetRowIterator rv(*this); ++(*this); return rv; } bool operator==(const CQRBinaryResultsetRowIterator& rhs) const { return m_pData == rhs.m_pData; } bool operator!=(const CQRBinaryResultsetRowIterator& rhs) const { return !(*this == rhs); } reference operator*() { if (*m_iNulls) { return Value(); } else { return Value(*m_iTypes, m_pData); } } private: uint8_t* m_pData; std::vector::const_iterator m_iTypes; bit_iterator m_iNulls; }; /** * @template CQRResultsetRow * * A template that when instantiated either represents a textual or a * binary resultset row. */ template class CQRResultsetRow : public ComPacket { public: typedef typename Iterator::Value Value; typedef Iterator iterator; CQRResultsetRow(GWBUF* pPacket, const std::vector& types) : ComPacket(pPacket) , m_types(types) { } CQRResultsetRow(const ComResponse& packet, const std::vector& types) : ComPacket(packet) , m_types(types) { } iterator begin() { return iterator(payload(), m_types); } iterator end() { uint8_t* pEnd = GWBUF_DATA(m_pPacket) + GWBUF_LENGTH(m_pPacket); return iterator(pEnd); } private: const std::vector& m_types; }; /** * @class CQRTextResultsetRow * * An instance of this class represents a textual resultset row. */ typedef CQRResultsetRow CQRTextResultsetRow; /** * @class CQRBinaryResultsetRow * * An instance of this class represents a binary resultset row. */ typedef CQRResultsetRow CQRBinaryResultsetRow; /** * @class ComQueryResponse * * An instance of this class represents the response to a @c ComQuery. 
*/ class ComQueryResponse : public ComPacket { public: typedef CQRColumnDef ColumnDef; typedef CQRTextResultsetRow TextResultsetRow; typedef CQRBinaryResultsetRow BinaryResultsetRow; ComQueryResponse(const ComPacket& com_packet) : ComPacket(com_packet) , m_nFields(payload()) { } uint64_t nFields() const { return m_nFields; } private: LEncInt m_nFields; }; }