From 0a68a3f4c1afd3283f810d2e69b6058de7e30bf9 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 16 Sep 2015 20:21:57 +0800 Subject: [PATCH] tidb-server: address comment --- Makefile | 2 +- README.md | 4 +-- tidb-server/main.go | 12 +++---- tidb-server/server/conn.go | 52 +++++++++++++++---------------- tidb-server/server/conn_stmt.go | 24 +++++++------- tidb-server/server/driver.go | 4 +-- tidb-server/server/driver_tidb.go | 19 ++++++----- tidb-server/server/server.go | 11 +++---- tidb-server/server/server_test.go | 16 +++++----- util/hack/hack.go | 17 ++++++++-- util/hack/hack_test.go | 13 ++++++++ 11 files changed, 98 insertions(+), 76 deletions(-) diff --git a/Makefile b/Makefile index 1ca5a20ad8..179d876af5 100644 --- a/Makefile +++ b/Makefile @@ -82,4 +82,4 @@ interpreter: @cd interpreter && $(GO) build -ldflags '$(LDFLAGS)' server: - @cd tidb-server && $(GO) build + @cd tidb-server && $(GO) build -ldflags '$(LDFLAGS)' diff --git a/README.md b/README.md index da616771cb..8b8d1de031 100644 --- a/README.md +++ b/README.md @@ -62,8 +62,8 @@ See [USAGE.md](./docs/USAGE.md) for detailed instructions to use TiDB as library make tidb-server cd tidb-server && ./tidb-server ``` -The default server address is `127.0.0.1:4000`. -After you started tidb-server, you can use official mysql client to connect to tidb. +The default server address is `:4000`. +After you started tidb-server, you can use official mysql client to connect to TiDB. ``` mysql -h 127.0.0.1 -P 4000 -D test ``` diff --git a/tidb-server/main.go b/tidb-server/main.go index 72a10e2b49..a74a45af64 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -24,6 +24,7 @@ import ( "github.com/ngaut/log" "github.com/pingcap/tidb" "github.com/pingcap/tidb/tidb-server/server" + "github.com/pingcap/tidb/util/printer" ) var ( @@ -40,7 +41,7 @@ var ( ) func main() { - fmt.Printf("Git Commit Hash:%s\nUTC Build Time :%s\n", githash, buildstamp) + printer.PrintTiDBInfo() runtime.GOMAXPROCS(runtime.NumCPU()) flag.Parse() @@ -55,17 +56,16 @@ func main() { log.SetLevelByString(cfg.LogLevel) store, err := tidb.NewStore(fmt.Sprintf("%s://%s", *store, *storePath)) if err != nil { - log.Error(err.Error()) - return + log.Fatal(err) } server.CreateTiDBTestDatabase(store) - var svr *server.Server + var driver server.IDriver driver = server.NewTiDBDriver(store) + var svr *server.Server svr, err = server.NewServer(cfg, driver) if err != nil { - log.Error(err.Error()) - return + log.Fatal(err) } sc := make(chan os.Signal, 1) diff --git a/tidb-server/server/conn.go b/tidb-server/server/conn.go index 3a17153034..1094c0c291 100644 --- a/tidb-server/server/conn.go +++ b/tidb-server/server/conn.go @@ -93,39 +93,39 @@ func (cc *clientConn) Close() error { func (cc *clientConn) writeInitialHandshake() error { data := make([]byte, 4, 128) - //min version 10 + // min version 10 data = append(data, 10) - //server version[00] + // server version[00] data = append(data, mysql.ServerVersion...) data = append(data, 0) - //connection id + // connection id data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24)) - //auth-plugin-data-part-1 + // auth-plugin-data-part-1 data = append(data, cc.salt[0:8]...) - //filter [00] + // filter [00] data = append(data, 0) - //capability flag lower 2 bytes, using default capability here + // capability flag lower 2 bytes, using default capability here data = append(data, byte(defaultCapability), byte(defaultCapability>>8)) - //charset, utf-8 default + // charset, utf-8 default data = append(data, uint8(mysql.DefaultCollationID)) //status data = append(data, dumpUint16(mysql.ServerStatusAutocommit)...) - //below 13 byte may not be used - //capability flag upper 2 bytes, using default capability here + // below 13 byte may not be used + // capability flag upper 2 bytes, using default capability here data = append(data, byte(defaultCapability>>16), byte(defaultCapability>>24)) //filter [0x15], for wireshark dump, value is 0x15 data = append(data, 0x15) - //reserved 10 [00] + // reserved 10 [00] data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0) - //auth-plugin-data-part-2 + // auth-plugin-data-part-2 data = append(data, cc.salt[8:]...) - //filter [00] + // filter [00] data = append(data, 0) err := cc.writePacket(data) if err != nil { - return err + return errors.Trace(err) } - return cc.flush() + return errors.Trace(cc.flush()) } func (cc *clientConn) readPacket() ([]byte, error) { @@ -167,26 +167,25 @@ func calcPassword(scramble, password []byte) []byte { func (cc *clientConn) readHandshakeResponse() error { data, err := cc.readPacket() - if err != nil { return errors.Trace(err) } pos := 0 - //capability + // capability cc.capability = binary.LittleEndian.Uint32(data[:4]) pos += 4 - //skip max packet size + // skip max packet size pos += 4 - //charset, skip, if you want to use another charset, use set names + // charset, skip, if you want to use another charset, use set names cc.collation = data[pos] pos++ - //skip reserved 23[00] + // skip reserved 23[00] pos += 23 - //user name + // user name cc.user = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) pos += len(cc.user) + 1 - //auth length and auth + // auth length and auth authLen := int(data[pos]) pos++ auth := data[pos : pos+authLen] @@ -200,8 +199,8 @@ func (cc *clientConn) readHandshakeResponse() error { if len(data[pos:]) == 0 { return nil } - - cc.dbname = string(data[pos : pos+bytes.IndexByte(data[pos:], 0)]) + idx := bytes.IndexByte(data[pos:], 0) + cc.dbname = string(data[pos : pos+idx]) } return nil @@ -224,6 +223,7 @@ func (cc *clientConn) Run() { data, err := cc.readPacket() if err != nil { if errors2.ErrorNotEqual(err, io.EOF) { + // TODO: another type of error is also normal, check it. log.Info(err) } return @@ -376,18 +376,18 @@ func (cc *clientConn) handleFieldList(sql string) (err error) { parts := strings.Split(sql, "\x00") columns, err := cc.ctx.FieldList(parts[0]) if err != nil { - return + return errors.Trace(err) } data := make([]byte, 4, 1024) for _, v := range columns { data = data[0:4] data = append(data, v.Dump(cc.alloc)...) if err := cc.writePacket(data); err != nil { - return err + return errors.Trace(err) } } if err := cc.writeEOF(); err != nil { - return err + return errors.Trace(err) } return errors.Trace(cc.flush()) } diff --git a/tidb-server/server/conn_stmt.go b/tidb-server/server/conn_stmt.go index 7d6232d343..dc21e218f6 100644 --- a/tidb-server/server/conn_stmt.go +++ b/tidb-server/server/conn_stmt.go @@ -19,6 +19,7 @@ import ( "math" "strconv" + "github.com/juju/errors" mysql "github.com/pingcap/tidb/mysqldef" "github.com/pingcap/tidb/util/hack" ) @@ -192,11 +193,11 @@ func parseStmtArgs(args []interface{}, boundParams [][]byte, nullBitmap, paramTy err = mysql.ErrMalformPacket return } - + valU16 := binary.LittleEndian.Uint16(paramValues[pos : pos+2]) if isUnsigned { - args[i] = uint64(binary.LittleEndian.Uint16(paramValues[pos : pos+2])) + args[i] = uint64(valU16) } else { - args[i] = int64((binary.LittleEndian.Uint16(paramValues[pos : pos+2]))) + args[i] = int64(valU16) } pos += 2 continue @@ -206,11 +207,11 @@ func parseStmtArgs(args []interface{}, boundParams [][]byte, nullBitmap, paramTy err = mysql.ErrMalformPacket return } - + valU32 := binary.LittleEndian.Uint32(paramValues[pos : pos+4]) if isUnsigned { - args[i] = uint64(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) + args[i] = uint64(valU32) } else { - args[i] = int64(binary.LittleEndian.Uint32(paramValues[pos : pos+4])) + args[i] = int64(valU32) } pos += 4 continue @@ -220,11 +221,11 @@ func parseStmtArgs(args []interface{}, boundParams [][]byte, nullBitmap, paramTy err = mysql.ErrMalformPacket return } - + valU64 := binary.LittleEndian.Uint64(paramValues[pos : pos+8]) if isUnsigned { - args[i] = binary.LittleEndian.Uint64(paramValues[pos : pos+8]) + args[i] = valU64 } else { - args[i] = int64(binary.LittleEndian.Uint64(paramValues[pos : pos+8])) + args[i] = int64(valU64) } pos += 8 continue @@ -268,11 +269,10 @@ func parseStmtArgs(args []interface{}, boundParams [][]byte, nullBitmap, paramTy if !isNull { args[i] = hack.String(v) - continue } else { args[i] = nil - continue } + continue default: err = fmt.Errorf("Stmt Unknown FieldType %d", tp) return @@ -289,7 +289,7 @@ func (cc *clientConn) handleStmtClose(data []byte) (err error) { stmtID := int(binary.LittleEndian.Uint32(data[0:4])) stmt := cc.ctx.GetStatement(stmtID) if stmt != nil { - stmt.Close() + return errors.Trace(stmt.Close()) } return } diff --git a/tidb-server/server/driver.go b/tidb-server/server/driver.go index d572e5d09a..e481b8e47e 100644 --- a/tidb-server/server/driver.go +++ b/tidb-server/server/driver.go @@ -42,7 +42,7 @@ type IContext interface { // Prepare prepares a statement. Prepare(sql string) (statement IStatement, columns, params []*ColumnInfo, err error) - // GetStatement get IStatement by statement ID. + // GetStatement gets IStatement by statement ID. GetStatement(stmtID int) IStatement // FieldList returns columns of a table. @@ -69,7 +69,7 @@ type IStatement interface { // BoundParams returns bound parameters. BoundParams() [][]byte - // Reset remove all bound parameters. + // Reset removes all bound parameters. Reset() // Close closes the statement. diff --git a/tidb-server/server/driver_tidb.go b/tidb-server/server/driver_tidb.go index 80acdd0b4f..611f7bfe6b 100644 --- a/tidb-server/server/driver_tidb.go +++ b/tidb-server/server/driver_tidb.go @@ -24,12 +24,12 @@ import ( "github.com/pingcap/tidb/util/errors2" ) -// TiDBDriver implements IDriver +// TiDBDriver implements IDriver. type TiDBDriver struct { store kv.Storage } -// NewTiDBDriver creates a new TiDBDriver +// NewTiDBDriver creates a new TiDBDriver. func NewTiDBDriver(store kv.Storage) *TiDBDriver { driver := &TiDBDriver{ store: store, @@ -37,7 +37,7 @@ func NewTiDBDriver(store kv.Storage) *TiDBDriver { return driver } -// TiDBContext implements IContext +// TiDBContext implements IContext. type TiDBContext struct { session tidb.Session currentDB string @@ -45,7 +45,7 @@ type TiDBContext struct { stmts map[int]*TiDBStatement } -// TiDBStatement implements IStatement +// TiDBStatement implements IStatement. type TiDBStatement struct { id uint32 numParams int @@ -85,17 +85,17 @@ func (ts *TiDBStatement) AppendParam(paramID int, data []byte) error { return nil } -// NumParams implements IStatement NumParams method. +// NumParams implements IStatement NumParams method. func (ts *TiDBStatement) NumParams() int { return ts.numParams } -// BoundParams implements IStatement BoundParams method. +// BoundParams implements IStatement BoundParams method. func (ts *TiDBStatement) BoundParams() [][]byte { return ts.boundParams } -// Reset implements IStatement Reset method. +// Reset implements IStatement Reset method. func (ts *TiDBStatement) Reset() { for i := range ts.boundParams { ts.boundParams[i] = nil @@ -113,7 +113,7 @@ func (ts *TiDBStatement) Close() error { return nil } -// OpenCtx implements IDriver +// OpenCtx implements IDriver. func (qd *TiDBDriver) OpenCtx(capability uint32, collation uint8, dbname string) (IContext, error) { session, _ := tidb.CreateSession(qd.store) session.SetClientCapability(capability) @@ -136,7 +136,7 @@ func (tc *TiDBContext) Status() uint16 { return tc.session.Status() } -// LastInsertID implements IContext Status method. +// LastInsertID implements IContext LastInsertID method. func (tc *TiDBContext) LastInsertID() uint64 { return tc.session.LastInsertID() } @@ -282,6 +282,5 @@ func CreateTiDBTestDatabase(store kv.Storage) { log.Fatal(err) } tc.Execute("CREATE DATABASE IF NOT EXISTS test") - tc.Execute("CREATE DATABASE IF NOT EXISTS gotest") tc.Close() } diff --git a/tidb-server/server/server.go b/tidb-server/server/server.go index bef52cb53b..e4aa008016 100644 --- a/tidb-server/server/server.go +++ b/tidb-server/server/server.go @@ -22,7 +22,7 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - "github.com/pingcap/tidb/mysqldef" + mysql "github.com/pingcap/tidb/mysqldef" "github.com/pingcap/tidb/util/arena" ) @@ -55,8 +55,8 @@ func (s *Server) newConn(conn net.Conn) (cc *clientConn, err error) { pkg: newPacketIO(conn), server: s, connectionID: atomic.AddUint32(&baseConnID, 1), - collation: mysqldef.DefaultCollationID, - charset: mysqldef.DefaultCharset, + collation: mysql.DefaultCollationID, + charset: mysql.DefaultCharset, alloc: arena.NewAllocator(32 * 1024), } cc.salt = make([]byte, 20) @@ -74,12 +74,11 @@ func (s *Server) skipAuth() bool { } func (s *Server) cfgGetPwd(user string) string { - return s.cfg.Password //TODO support multiple users + return s.cfg.Password // TODO: support multiple users } // NewServer creates a new Server. func NewServer(cfg *Config, driver IDriver) (*Server, error) { - log.Warningf("%#v", cfg) s := &Server{ cfg: cfg, driver: driver, @@ -140,8 +139,6 @@ func (s *Server) onConn(c net.Conn) { return } - const key = "connections" - defer func() { log.Infof("close %s", conn) }() diff --git a/tidb-server/server/server_test.go b/tidb-server/server/server_test.go index c004c5ec30..2f7dbe8713 100644 --- a/tidb-server/server/server_test.go +++ b/tidb-server/server/server_test.go @@ -69,7 +69,7 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) func (dbt *DBTest) mustQueryRows(query string, args ...interface{}) { rows := dbt.mustQuery(query, args...) - dbt.Assert(rows.Next(), Equals, true) + dbt.Assert(rows.Next(), IsTrue) rows.Close() } @@ -82,7 +82,7 @@ func runTestRegression(c *C) { // Test for unexpected data var out bool rows := dbt.mustQuery("SELECT * FROM test") - dbt.Assert(rows.Next(), Equals, false, Commentf("unexpected data in empty table")) + dbt.Assert(rows.Next(), IsFalse, Commentf("unexpected data in empty table")) // Create Data res := dbt.mustExec("INSERT INTO test VALUES (1)") @@ -98,8 +98,8 @@ func runTestRegression(c *C) { rows = dbt.mustQuery("SELECT val FROM test") if rows.Next() { rows.Scan(&out) - dbt.Check(out, Equals, true) - dbt.Check(rows.Next(), Equals, false, Commentf("unexpected data")) + dbt.Check(out, IsTrue) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) } else { dbt.Error("no data") } @@ -115,8 +115,8 @@ func runTestRegression(c *C) { rows = dbt.mustQuery("SELECT val FROM test") if rows.Next() { rows.Scan(&out) - dbt.Check(out, Equals, false) - dbt.Check(rows.Next(), Equals, false, Commentf("unexpected data")) + dbt.Check(out, IsFalse) + dbt.Check(rows.Next(), IsFalse, Commentf("unexpected data")) } else { dbt.Error("no data") } @@ -173,7 +173,7 @@ func runTestSpecialType(t *C) { dbt.mustExec("create table test (a decimal(10, 5), b datetime, c time)") dbt.mustExec("insert test values (1.4, '2012-12-21 12:12:12', '4:23:34')") rows := dbt.mustQuery("select * from test where a > ?", 0) - t.Assert(rows.Next(), Equals, true) + t.Assert(rows.Next(), IsTrue) var outA float64 var outB, outC string err := rows.Scan(&outA, &outB, &outC) @@ -189,7 +189,7 @@ func runTestPreparedString(t *C) { dbt.mustExec("create table test (a char(10), b char(10))") dbt.mustExec("insert test values (?, ?)", "abcdeabcde", "abcde") rows := dbt.mustQuery("select * from test where 1 = ?", 1) - t.Assert(rows.Next(), Equals, true) + t.Assert(rows.Next(), IsTrue) var outA, outB string err := rows.Scan(&outA, &outB) t.Assert(err, IsNil) diff --git a/util/hack/hack.go b/util/hack/hack.go index 1e3945f248..d792c9e5b3 100644 --- a/util/hack/hack.go +++ b/util/hack/hack.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package hack import ( @@ -5,7 +18,7 @@ import ( "unsafe" ) -// String change slice to string without copy. +// String converts slice to string without copy. // Use at your own risk. func String(b []byte) (s string) { if len(b) == 0 { @@ -18,7 +31,7 @@ func String(b []byte) (s string) { return } -// Slice change string to slice without copy. +// Slice converts string to slice without copy. // Use at your own risk. func Slice(s string) (b []byte) { pbytes := (*reflect.SliceHeader)(unsafe.Pointer(&b)) diff --git a/util/hack/hack_test.go b/util/hack/hack_test.go index 7b11b0b87b..8a3a2a9301 100644 --- a/util/hack/hack_test.go +++ b/util/hack/hack_test.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package hack import (