From 43c81baf70fdaf5e94790742a23bac3c706f3c21 Mon Sep 17 00:00:00 2001 From: Johan Wikman Date: Tue, 10 Jan 2017 15:18:51 +0200 Subject: [PATCH] Add utility classes for dealing with binary resultsets --- server/modules/filter/masking/mysql.hh | 325 ++++++++++++++++++++++++- 1 file changed, 312 insertions(+), 13 deletions(-) diff --git a/server/modules/filter/masking/mysql.hh b/server/modules/filter/masking/mysql.hh index 05200961f..98aaf94a9 100644 --- a/server/modules/filter/masking/mysql.hh +++ b/server/modules/filter/masking/mysql.hh @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -589,7 +590,7 @@ public: m_column_length = *reinterpret_cast(m_pI); m_pI += 4; - m_type = *m_pI; + m_type = static_cast(*m_pI); m_pI += 1; m_flags = *reinterpret_cast(m_pI); @@ -605,6 +606,7 @@ public: const LEncString& org_table() const { return m_org_table; } const LEncString& name() const { return m_name; } const LEncString& org_name() const { return m_org_name; } + enum_field_types type() const { return m_type; } std::string to_string() const { @@ -625,18 +627,18 @@ public: } private: - LEncString m_catalog; - LEncString m_schema; - LEncString m_table; - LEncString m_org_table; - LEncString m_name; - LEncString m_org_name; - LEncInt m_length_fixed_fields; - uint16_t m_character_set; - uint32_t m_column_length; - uint8_t m_type; - uint16_t m_flags; - uint8_t m_decimals; + LEncString m_catalog; + LEncString m_schema; + LEncString m_table; + LEncString m_org_table; + LEncString m_name; + LEncString m_org_name; + LEncInt m_length_fixed_fields; + uint16_t m_character_set; + uint32_t m_column_length; + enum_field_types m_type; + uint16_t m_flags; + uint8_t m_decimals; }; class ComQueryResponseRow : public ComPacket @@ -707,11 +709,308 @@ public: } }; +class ComQueryResponseBinaryRow : public ComPacket +{ +public: + /** + * An instance of Value represents a value in a binary resultset. + */ + class Value + { + public: + Value() + : m_type(MYSQL_TYPE_NULL) + , m_pData(NULL) + { + } + + Value(enum_field_types type, uint8_t* pData) + : m_type(type) + , m_pData(pData) + { + } + + enum_field_types type() const + { + return m_type; + } + + LEncString as_string() + { + ss_dassert(is_string(m_type)); + return LEncString(m_pData); + } + + bool is_string() const + { + return is_string(m_type); + } + + static bool is_string(enum_field_types type) + { + switch (type) + { + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_VAR_STRING: + return true; + + // These, although returned as length-encoded strings are not considered + // to be strings from the perspective of masking. + case MYSQL_TYPE_BIT: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_ENUM: + case MYSQL_TYPE_GEOMETRY: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_NEWDECIMAL: + case MYSQL_TYPE_SET: + case MYSQL_TYPE_TINY_BLOB: + return false; + + default: + return false; + } + } + + private: + enum_field_types m_type; + uint8_t* m_pData; + }; + + /** + * iterator is an iterator to values in a binary resultset. + */ + class iterator : public std::iterator + { + public: + /** + * A bit_iterator is an iterator to bits in an array of bytes. + * + * Specifically, it is capable of iterating across the NULL bitmask of + * a binary resultset. + */ + class bit_iterator + { + public: + bit_iterator(uint8_t* pData = 0) + : m_pData(pData) + , m_mask(1 << 2) // The two first bits are not used. + { + } + + /** + * @return True, if the current bit is on. That is, if the corresponding + * column value is NULL. + */ + bool operator * () const + { + return (*m_pData & m_mask) ? true : false; + } + + bit_iterator& operator ++ () + { + m_mask <<= 1; // Move to the next bit. + if (m_mask == 0) + { + // We moved past the byte, so advance to next byte and the first bit of that. + ++m_pData; + m_mask = 1; + } + + return *this; + } + + bit_iterator operator ++ (int) + { + bit_iterator rv(*this); + ++(*this); + return rv; + } + + private: + uint8_t* m_pData; /*< Pointer to the NULL bitmap of a binary resultset row. */ + uint8_t m_mask; /*< Mask representing the current bit of the current byte. */ + }; + + iterator(uint8_t* pData, const std::vector& types) + : m_pData(pData) + , m_iTypes(types.begin()) + , m_iNulls(pData + 1) + { + ss_dassert(*m_pData == 0); + ++m_pData; + + // See https://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html + size_t nNull_bytes = (types.size() + 7 + 2) / 8; + + m_pData += nNull_bytes; + } + + iterator(uint8_t* pData) + : m_pData(pData) + { + } + + iterator& operator++() + { + // See https://dev.mysql.com/doc/internals/en/binary-protocol-value.html + switch (*m_iTypes) + { + case MYSQL_TYPE_BIT: + case MYSQL_TYPE_BLOB: + case MYSQL_TYPE_DECIMAL: + case MYSQL_TYPE_ENUM: + case MYSQL_TYPE_GEOMETRY: + case MYSQL_TYPE_LONG_BLOB: + case MYSQL_TYPE_MEDIUM_BLOB: + case MYSQL_TYPE_NEWDATE: + case MYSQL_TYPE_NEWDECIMAL: + case MYSQL_TYPE_SET: + case MYSQL_TYPE_STRING: + case MYSQL_TYPE_TINY_BLOB: + case MYSQL_TYPE_VARCHAR: + case MYSQL_TYPE_VAR_STRING: + { + LEncString s(&m_pData); // Advance m_pData to the byte following the string. + } + break; + + case MYSQL_TYPE_LONGLONG: + m_pData += 8; + break; + + case MYSQL_TYPE_LONG: + case MYSQL_TYPE_INT24: + m_pData += 4; + break; + + case MYSQL_TYPE_SHORT: + case MYSQL_TYPE_YEAR: + m_pData += 2; + break; + + case MYSQL_TYPE_TINY: + m_pData += 1; + break; + + case MYSQL_TYPE_DOUBLE: + m_pData += 8; + break; + + case MYSQL_TYPE_FLOAT: + m_pData += 4; + break; + + case MYSQL_TYPE_DATE: + case MYSQL_TYPE_DATETIME: + case MYSQL_TYPE_TIMESTAMP: + { + // A byte specifying the length, followed by that many bytes. + // Either 0, 4, 7 or 11. + uint8_t len = *m_pData++; + m_pData += len; + } + break; + + case MYSQL_TYPE_TIME: + { + // A byte specifying the length, followed by that many bytes. + // Either 0, 8 or 12. + uint8_t len = *m_pData++; + m_pData += len; + } + break; + + case MYSQL_TYPE_NULL: + break; + + case MAX_NO_FIELD_TYPES: + ss_dassert(!true); + break; + } + + ++m_iNulls; + ++m_iTypes; + + return *this; + } + + iterator operator++(int) + { + iterator rv(*this); + ++(*this); + return rv; + } + + bool operator == (const iterator& rhs) const + { + return m_pData == rhs.m_pData; + } + + bool operator != (const iterator& rhs) const + { + return !(*this == rhs); + } + + reference operator*() + { + if (*m_iNulls) + { + return Value(); + } + else + { + return Value(*m_iTypes, m_pData); + } + } + + private: + uint8_t* m_pData; + std::vector::const_iterator m_iTypes; + bit_iterator m_iNulls; + }; + + ComQueryResponseBinaryRow(GWBUF* pPacket, + const std::vector& types) + : ComPacket(pPacket) + , m_types(types) + { + } + + ComQueryResponseBinaryRow(const ComResponse& packet, + const std::vector& types) + : ComPacket(packet) + , m_types(types) + { + } + + iterator begin() + { + return iterator(m_pI, m_types); + } + + iterator end() + { + uint8_t* pEnd = GWBUF_DATA(m_pPacket) + GWBUF_LENGTH(m_pPacket); + return iterator(pEnd); + } + +private: + const std::vector& m_types; +}; + class ComQueryResponse : public ComPacket { public: typedef ComQueryResponseColumnDef ColumnDef; typedef ComQueryResponseRow Row; + typedef ComQueryResponseBinaryRow BinaryRow; ComQueryResponse(GWBUF* pPacket) : ComPacket(pPacket)