diff --git a/executor/join_test.go b/executor/join_test.go index a930ae1af9..c62f1ffaeb 100644 --- a/executor/join_test.go +++ b/executor/join_test.go @@ -300,6 +300,17 @@ func (s *testSuite) TestJoinCast(c *C) { tk.MustExec("insert into t1 values(0), (9)") result = tk.MustQuery("select /*+ TIDB_INLJ(t) */ * from t left join t1 on t1.c1 = t.c1") result.Sort().Check(testkit.Rows("0.0 0.00", "2.0 ")) + + tk.MustExec("drop table if exists t") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t(c1 char(10))") + tk.MustExec("create table t1(c1 char(10))") + tk.MustExec("create table t2(c1 char(10))") + tk.MustExec("insert into t values('abd')") + tk.MustExec("insert into t1 values('abc')") + tk.MustExec("insert into t2 values('abc')") + result = tk.MustQuery("select * from (select * from t union all select * from t1) t1 join t2 on t1.c1 = t2.c1") + result.Sort().Check(testkit.Rows("abc abc")) } func (s *testSuite) TestUsing(c *C) { diff --git a/expression/builtin_string.go b/expression/builtin_string.go index 8005aaba98..f8d31e23fe 100644 --- a/expression/builtin_string.go +++ b/expression/builtin_string.go @@ -155,9 +155,9 @@ func reverseRunes(origin []rune) []rune { return origin } -// setBinFlagOrBinStr sets resTp to binary string if argTp is a binary string, +// SetBinFlagOrBinStr sets resTp to binary string if argTp is a binary string, // if not, sets the binary flag of resTp to true if argTp has binary flag. -func setBinFlagOrBinStr(argTp *types.FieldType, resTp *types.FieldType) { +func SetBinFlagOrBinStr(argTp *types.FieldType, resTp *types.FieldType) { if types.IsBinaryStr(argTp) { types.SetBinChsClnFlag(resTp) } else if mysql.HasBinaryFlag(argTp.Flag) { @@ -248,7 +248,7 @@ func (c *concatFunctionClass) getFunction(args []Expression, ctx context.Context } for i := range args { argType := args[i].GetType() - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) if argType.Flen < 0 { bf.tp.Flen = mysql.MaxBlobWidth @@ -300,7 +300,7 @@ func (c *concatWSFunctionClass) getFunction(args []Expression, ctx context.Conte for i := range args { argType := args[i].GetType() - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) // skip seperator param if i != 0 { @@ -375,7 +375,7 @@ func (c *leftFunctionClass) getFunction(args []Expression, ctx context.Context) } argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) if types.IsBinaryStr(argType) { sig := &builtinLeftBinarySig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil @@ -448,7 +448,7 @@ func (c *rightFunctionClass) getFunction(args []Expression, ctx context.Context) } argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) if types.IsBinaryStr(argType) { sig := &builtinRightBinarySig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil @@ -521,7 +521,7 @@ func (c *repeatFunctionClass) getFunction(args []Expression, ctx context.Context return nil, errors.Trace(err) } bf.tp.Flen = mysql.MaxBlobWidth - setBinFlagOrBinStr(args[0].GetType(), bf.tp) + SetBinFlagOrBinStr(args[0].GetType(), bf.tp) sig := &builtinRepeatSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil } @@ -570,7 +570,7 @@ func (c *lowerFunctionClass) getFunction(args []Expression, ctx context.Context) } argTp := args[0].GetType() bf.tp.Flen = argTp.Flen - setBinFlagOrBinStr(argTp, bf.tp) + SetBinFlagOrBinStr(argTp, bf.tp) sig := &builtinLowerSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil } @@ -699,7 +699,7 @@ func (c *upperFunctionClass) getFunction(args []Expression, ctx context.Context) } argTp := args[0].GetType() bf.tp.Flen = argTp.Flen - setBinFlagOrBinStr(argTp, bf.tp) + SetBinFlagOrBinStr(argTp, bf.tp) sig := &builtinUpperSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil } @@ -781,7 +781,7 @@ func (c *replaceFunctionClass) getFunction(args []Expression, ctx context.Contex } bf.tp.Flen = c.fixLength(args) for _, a := range args { - setBinFlagOrBinStr(a.GetType(), bf.tp) + SetBinFlagOrBinStr(a.GetType(), bf.tp) } sig := &builtinReplaceSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil @@ -898,7 +898,7 @@ func (c *substringFunctionClass) getFunction(args []Expression, ctx context.Cont argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) var sig builtinFunc switch { @@ -1065,7 +1065,7 @@ func (c *substringIndexFunctionClass) getFunction(args []Expression, ctx context } argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) sig := &builtinSubstringIndexSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil } @@ -1421,7 +1421,7 @@ func (c *trimFunctionClass) getFunction(args []Expression, ctx context.Context) } argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) sig := &builtinTrim1ArgSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil @@ -1431,7 +1431,7 @@ func (c *trimFunctionClass) getFunction(args []Expression, ctx context.Context) return nil, errors.Trace(err) } argType := args[0].GetType() - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) sig := &builtinTrim2ArgsSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil @@ -1442,7 +1442,7 @@ func (c *trimFunctionClass) getFunction(args []Expression, ctx context.Context) } argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) sig := &builtinTrim3ArgsSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil @@ -1552,7 +1552,7 @@ func (c *lTrimFunctionClass) getFunction(args []Expression, ctx context.Context) } argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) sig := &builtinLTrimSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil } @@ -1585,7 +1585,7 @@ func (c *rTrimFunctionClass) getFunction(args []Expression, ctx context.Context) } argType := args[0].GetType() bf.tp.Flen = argType.Flen - setBinFlagOrBinStr(argType, bf.tp) + SetBinFlagOrBinStr(argType, bf.tp) sig := &builtinRTrimSig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil } @@ -1651,8 +1651,8 @@ func (c *lpadFunctionClass) getFunction(args []Expression, ctx context.Context) return nil, errors.Trace(err) } bf.tp.Flen = getFlen4LpadAndRpad(bf.ctx.GetSessionVars().StmtCtx, args[1]) - setBinFlagOrBinStr(args[0].GetType(), bf.tp) - setBinFlagOrBinStr(args[2].GetType(), bf.tp) + SetBinFlagOrBinStr(args[0].GetType(), bf.tp) + SetBinFlagOrBinStr(args[2].GetType(), bf.tp) if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) { sig := &builtinLpadBinarySig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil @@ -1756,8 +1756,8 @@ func (c *rpadFunctionClass) getFunction(args []Expression, ctx context.Context) if err != nil { return nil, errors.Trace(err) } - setBinFlagOrBinStr(args[0].GetType(), bf.tp) - setBinFlagOrBinStr(args[2].GetType(), bf.tp) + SetBinFlagOrBinStr(args[0].GetType(), bf.tp) + SetBinFlagOrBinStr(args[2].GetType(), bf.tp) if types.IsBinaryStr(args[0].GetType()) || types.IsBinaryStr(args[2].GetType()) { sig := &builtinRpadBinarySig{baseStringBuiltinFunc{bf}} return sig.setSelf(sig), nil diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go index bbb68985bc..9a9df3962d 100644 --- a/plan/logical_plan_builder.go +++ b/plan/logical_plan_builder.go @@ -532,6 +532,9 @@ func joinFieldType(a, b *types.FieldType) *types.FieldType { resultTp.Decimal = mathutil.Max(a.Decimal, b.Decimal) // `Flen - Decimal` is the fraction before '.' resultTp.Flen = mathutil.Max(a.Flen-a.Decimal, b.Flen-b.Decimal) + resultTp.Decimal + resultTp.Charset = a.Charset + resultTp.Collate = a.Collate + expression.SetBinFlagOrBinStr(b, resultTp) return resultTp }