[feature](nereids) Support row policy (#13879)

This pr did two things:
1. 【new logical plan】add **LogicalCheckPolicy** before UnboundRelation in LogicalPlanBuilder.
2. 【new rule】turn **LogicalCheckPolicy** to LogicalFilter if row policy exist, otherwise remove it.
This commit is contained in:
xiaojunjie
2022-11-25 22:57:56 +08:00
committed by GitHub
parent d159a8d24b
commit 2ae7dae925
20 changed files with 663 additions and 16 deletions

View File

@ -266,10 +266,12 @@ PARTITIONED: 'PARTITIONED';
PARTITIONS: 'PARTITIONS';
PERCENTILE_CONT: 'PERCENTILE_CONT';
PERCENTLIT: 'PERCENT';
PERMISSIVE: 'PERMISSIVE';
PHYSICAL: 'PHYSICAL';
PIVOT: 'PIVOT';
PLACING: 'PLACING';
PLAN: 'PLAN';
POLICY: 'POLICY';
POSITION: 'POSITION';
PRECEDING: 'PRECEDING';
PRIMARY: 'PRIMARY';
@ -292,6 +294,7 @@ REPLACE: 'REPLACE';
RESET: 'RESET';
RESPECT: 'RESPECT';
RESTRICT: 'RESTRICT';
RESTRICTIVE: 'RESTRICTIVE';
REVOKE: 'REVOKE';
REWRITTEN: 'REWRITTEN';
RIGHT: 'RIGHT';

View File

@ -50,6 +50,11 @@ singleStatement
statement
: explain? cte? query #statementDefault
| CREATE ROW POLICY (IF NOT EXISTS)? name=identifier
ON table=multipartIdentifier
AS type=(RESTRICTIVE | PERMISSIVE)
TO user=identifier
USING LEFT_PAREN booleanExpression RIGHT_PAREN #createRowPolicy
;
explain
@ -336,8 +341,6 @@ number
| MINUS? (EXPONENT_VALUE | DECIMAL_VALUE) #decimalLiteral
;
// When `SQL_standard_keyword_behavior=true`, there are 2 kinds of keywords in Spark SQL.
// - Reserved keywords:
// Keywords that are reserved and can't be used as identifiers for table, view, column,
@ -764,6 +767,7 @@ nonReserved
| PIVOT
| PLACING
| PLAN
| POLICY
| POSITION
| PRECEDING
| PRIMARY

View File

@ -21,6 +21,7 @@ import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.analysis.BindFunction;
import org.apache.doris.nereids.rules.analysis.BindRelation;
import org.apache.doris.nereids.rules.analysis.BindSlotReference;
import org.apache.doris.nereids.rules.analysis.CheckPolicy;
import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
@ -51,6 +52,7 @@ public class AnalyzeRulesJob extends BatchRulesJob {
)),
bottomUpBatch(ImmutableList.of(
new BindRelation(),
new CheckPolicy(),
new UserAuthentication(),
new BindSlotReference(scope),
new BindFunction(),

View File

@ -81,11 +81,11 @@ public class RewriteBottomUpJob extends Job {
if (!copyInResult.isPresent()) {
continue;
}
CopyInResult result = copyInResult.get();
boolean groupChanged = result.correspondingExpression.getOwnerGroup() != group;
if (result.generateNewExpression || groupChanged) {
pushJob(new RewriteBottomUpJob(result.correspondingExpression.getOwnerGroup(),
rules, context, !groupChanged));
Group correspondingGroup = copyInResult.get().correspondingExpression.getOwnerGroup();
if (copyInResult.get().generateNewExpression
|| correspondingGroup != group
|| logicalExpression.getOwnerGroup() == null) {
pushJob(new RewriteBottomUpJob(correspondingGroup, rules, context, false));
return;
}
}

View File

@ -82,14 +82,14 @@ public class RewriteTopDownJob extends Job {
if (!copyInResult.isPresent()) {
continue;
}
CopyInResult result = copyInResult.get();
boolean groupChanged = result.correspondingExpression.getOwnerGroup() != group;
if (result.generateNewExpression || groupChanged) {
Group correspondingGroup = copyInResult.get().correspondingExpression.getOwnerGroup();
if (copyInResult.get().generateNewExpression
|| correspondingGroup != group
|| logicalExpression.getOwnerGroup() == null) {
// new group-expr replaced the origin group-expr in `group`,
// run this rule against this `group` again.
context.setRewritten(true);
pushJob(new RewriteTopDownJob(result.correspondingExpression.getOwnerGroup(),
rules, context));
pushJob(new RewriteTopDownJob(correspondingGroup, rules, context));
return;
}
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.DorisParser.ArithmeticUnaryContext;
import org.apache.doris.nereids.DorisParser.BooleanLiteralContext;
import org.apache.doris.nereids.DorisParser.ColumnReferenceContext;
import org.apache.doris.nereids.DorisParser.ComparisonContext;
import org.apache.doris.nereids.DorisParser.CreateRowPolicyContext;
import org.apache.doris.nereids.DorisParser.CteContext;
import org.apache.doris.nereids.DorisParser.DecimalLiteralContext;
import org.apache.doris.nereids.DorisParser.DereferenceContext;
@ -139,10 +140,13 @@ import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.CreatePolicyCommand;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand.ExplainLevel;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@ -155,6 +159,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.policy.PolicyTypeEnum;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
@ -266,6 +271,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
});
}
@Override
public Command visitCreateRowPolicy(CreateRowPolicyContext ctx) {
// Only wherePredicate is needed at present
return new CreatePolicyCommand(PolicyTypeEnum.ROW, getExpression(ctx.booleanExpression()));
}
@Override
public LogicalPlan visitQuery(QueryContext ctx) {
return ParserUtils.withOrigin(ctx, () -> {
@ -315,10 +326,15 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
});
}
private LogicalPlan withCheckPolicy(LogicalPlan plan) {
return new LogicalCheckPolicy(plan);
}
@Override
public LogicalPlan visitTableName(TableNameContext ctx) {
List<String> tableId = visitMultipartIdentifier(ctx.multipartIdentifier());
return withTableAlias(new UnboundRelation(tableId), ctx.tableAlias());
LogicalPlan checkedRelation = withCheckPolicy(new UnboundRelation(tableId));
return withTableAlias(checkedRelation, ctx.tableAlias());
}
@Override

View File

@ -68,6 +68,8 @@ public enum RuleType {
ADJUST_NULLABLE_FOR_AGGREGATE_SLOT(RuleTypeClass.REWRITE),
ADJUST_NULLABLE_FOR_REPEAT_SLOT(RuleTypeClass.REWRITE),
CHECK_ROW_POLICY(RuleTypeClass.REWRITE),
// check analysis rule
CHECK_ANALYSIS(RuleTypeClass.CHECK),

View File

@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
/**
* CheckPolicy.
*/
public class CheckPolicy implements AnalysisRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.CHECK_ROW_POLICY.build(
logicalCheckPolicy(logicalSubQueryAlias()).then(checkPolicy -> checkPolicy.child())
),
RuleType.CHECK_ROW_POLICY.build(
logicalCheckPolicy(logicalRelation()).thenApply(ctx -> {
LogicalCheckPolicy<LogicalRelation> checkPolicy = ctx.root;
LogicalRelation relation = checkPolicy.child();
Optional<Expression> filter = checkPolicy.getFilter(relation, ctx.connectContext);
if (!filter.isPresent()) {
return relation;
}
return new LogicalFilter(filter.get(), relation);
})
)
);
}
}

View File

@ -47,6 +47,7 @@ public enum PlanType {
LOGICAL_ASSERT_NUM_ROWS,
LOGICAL_HAVING,
LOGICAL_MULTI_JOIN,
LOGICAL_CHECK_POLICY,
GROUP_PLAN,
// physical plan

View File

@ -0,0 +1,30 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.algebra;
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.analyzer.Relation;
import org.apache.doris.nereids.exceptions.AnalysisException;
/** CatalogRelation */
public interface CatalogRelation extends Relation {
Table getTable();
Database getDatabase() throws AnalysisException;
}

View File

@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.commands;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.policy.PolicyTypeEnum;
import java.util.Optional;
/**
* Create policy command.
*/
public class CreatePolicyCommand implements Command {
private PolicyTypeEnum type;
private final Optional<Expression> wherePredicate;
public CreatePolicyCommand(PolicyTypeEnum type, Expression expr) {
this.type = type;
this.wherePredicate = Optional.of(expr);
}
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitCreatePolicyCommand(this, context);
}
public Optional<Expression> getWherePredicate() {
return wherePredicate;
}
}

View File

@ -0,0 +1,178 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.commands.CreatePolicyCommand;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.policy.PolicyMgr;
import org.apache.doris.policy.RowPolicy;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
/**
* Logical Check Policy
*/
public class LogicalCheckPolicy<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYPE> {
public LogicalCheckPolicy(CHILD_TYPE child) {
super(PlanType.LOGICAL_CHECK_POLICY, child);
}
public LogicalCheckPolicy(Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, CHILD_TYPE child) {
super(PlanType.LOGICAL_CHECK_POLICY, groupExpression, logicalProperties, child);
}
@Override
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
return visitor.visitLogicalCheckPolicy(this, context);
}
@Override
public List<? extends Expression> getExpressions() {
return ImmutableList.of();
}
@Override
public List<Slot> computeOutput() {
return child().getOutput();
}
@Override
public String toString() {
return Utils.toSqlString("LogicalCheckPolicy",
"child", child()
);
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
LogicalCheckPolicy that = (LogicalCheckPolicy) o;
return child().equals(that.child());
}
@Override
public int hashCode() {
return child().hashCode();
}
@Override
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
return new LogicalCheckPolicy<>(groupExpression, Optional.of(getLogicalProperties()), child());
}
@Override
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
return new LogicalCheckPolicy<>(Optional.empty(), logicalProperties, child());
}
@Override
public Plan withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1);
return new LogicalCheckPolicy<>(children.get(0));
}
/**
* get wherePredicate of policy for logicalRelation.
*
* @param logicalRelation include tableName and dbName
* @param connectContext include information about user and policy
*/
public Optional<Expression> getFilter(LogicalRelation logicalRelation, ConnectContext connectContext) {
if (!(logicalRelation instanceof CatalogRelation)) {
return Optional.empty();
}
PolicyMgr policyMgr = connectContext.getEnv().getPolicyMgr();
UserIdentity currentUserIdentity = connectContext.getCurrentUserIdentity();
String user = connectContext.getQualifiedUser();
if (currentUserIdentity.isRootUser() || currentUserIdentity.isAdminUser()) {
return Optional.empty();
}
if (!policyMgr.existPolicy(user)) {
return Optional.empty();
}
CatalogRelation catalogRelation = (CatalogRelation) logicalRelation;
long dbId = catalogRelation.getDatabase().getId();
long tableId = catalogRelation.getTable().getId();
List<RowPolicy> policies = policyMgr.getMatchRowPolicy(dbId, tableId, currentUserIdentity);
if (policies.isEmpty()) {
return Optional.empty();
}
return Optional.ofNullable(mergeRowPolicy(policies));
}
private Expression mergeRowPolicy(List<RowPolicy> policies) {
List<Expression> orList = new ArrayList<>();
List<Expression> andList = new ArrayList<>();
for (RowPolicy policy : policies) {
String sql = policy.getOriginStmt();
NereidsParser nereidsParser = new NereidsParser();
CreatePolicyCommand command = (CreatePolicyCommand) nereidsParser.parseSingle(sql);
Optional<Expression> wherePredicate = command.getWherePredicate();
if (!wherePredicate.isPresent()) {
throw new AnalysisException("Invaild row policy [" + policy.getPolicyName() + "], " + sql);
}
switch (policy.getFilterType()) {
case PERMISSIVE:
orList.add(wherePredicate.get());
break;
case RESTRICTIVE:
andList.add(wherePredicate.get());
break;
default:
throw new IllegalStateException("Invalid operator");
}
}
if (!andList.isEmpty() && !orList.isEmpty()) {
return new And(ExpressionUtils.and(andList), ExpressionUtils.or(orList));
} else if (andList.isEmpty()) {
return ExpressionUtils.or(orList);
} else if (orList.isEmpty()) {
return ExpressionUtils.and(andList);
} else {
return null;
}
}
}

View File

@ -17,14 +17,18 @@
package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.Table;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.PreAggStatus;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.Utils;
@ -40,7 +44,7 @@ import java.util.Optional;
/**
* Logical OlapScan.
*/
public class LogicalOlapScan extends LogicalRelation {
public class LogicalOlapScan extends LogicalRelation implements CatalogRelation {
private final long selectedIndexId;
private final ImmutableList<Long> selectedTabletId;
@ -93,6 +97,13 @@ public class LogicalOlapScan extends LogicalRelation {
return (OlapTable) table;
}
@Override
public Database getDatabase() throws AnalysisException {
Preconditions.checkArgument(!qualifier.isEmpty());
return Env.getCurrentInternalCatalog().getDbOrException(qualifier.get(0),
s -> new AnalysisException("Database [" + qualifier.get(0) + "] does not exist."));
}
@Override
public String toString() {
return Utils.toSqlString("LogicalOlapScan",

View File

@ -23,11 +23,13 @@ import org.apache.doris.nereids.analyzer.UnboundTVFRelation;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.Command;
import org.apache.doris.nereids.trees.plans.commands.CreatePolicyCommand;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
@ -85,6 +87,10 @@ public abstract class PlanVisitor<R, C> {
return visitCommand(explain, context);
}
public R visitCreatePolicyCommand(CreatePolicyCommand explain, C context) {
return visitCommand(explain, context);
}
// *******************************
// Logical plans
// *******************************
@ -137,6 +143,10 @@ public abstract class PlanVisitor<R, C> {
return visit(filter, context);
}
public R visitLogicalCheckPolicy(LogicalCheckPolicy<? extends Plan> checkPolicy, C context) {
return visit(checkPolicy, context);
}
public R visitLogicalOlapScan(LogicalOlapScan olapScan, C context) {
return visitLogicalRelation(olapScan, context);
}

View File

@ -128,6 +128,14 @@ public class ExpressionUtils {
return combine(And.class, Lists.newArrayList(expressions));
}
public static Optional<Expression> optionalOr(List<Expression> expressions) {
if (expressions.isEmpty()) {
return Optional.empty();
} else {
return Optional.of(ExpressionUtils.or(expressions));
}
}
public static Expression or(Expression... expressions) {
return combine(Or.class, Lists.newArrayList(expressions));
}

View File

@ -22,6 +22,7 @@ import org.apache.doris.analysis.CompoundPredicate;
import org.apache.doris.analysis.CreatePolicyStmt;
import org.apache.doris.analysis.DropPolicyStmt;
import org.apache.doris.analysis.ShowPolicyStmt;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.Env;
import org.apache.doris.common.AnalysisException;
import org.apache.doris.common.DdlException;
@ -284,6 +285,26 @@ public class PolicyMgr implements Writable {
}
}
/**
* Match all row policy and return them.
**/
public List<RowPolicy> getMatchRowPolicy(long dbId, long tableId, UserIdentity user) {
RowPolicy checkedPolicy = new RowPolicy();
checkedPolicy.setDbId(dbId);
checkedPolicy.setTableId(tableId);
checkedPolicy.setUser(user);
readLock();
try {
return getPoliciesByType(PolicyTypeEnum.ROW).stream()
.filter(p -> p.matchPolicy(checkedPolicy))
.filter(p -> !p.isInvalid())
.map(p -> (RowPolicy) p)
.collect(Collectors.toList());
} finally {
readUnlock();
}
}
/**
* Show policy through stmt.
**/

View File

@ -80,7 +80,9 @@ public class LimitClauseTest extends ParserTestBase {
parsePlan("select a from tbl")
.matchesFromRoot(
logicalProject(
unboundRelation()
logicalCheckPolicy(
unboundRelation()
)
)
);
}

View File

@ -0,0 +1,148 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.analysis.CreateUserStmt;
import org.apache.doris.analysis.GrantStmt;
import org.apache.doris.analysis.TablePattern;
import org.apache.doris.analysis.UserDesc;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.catalog.AccessPrivilege;
import org.apache.doris.catalog.AggregateType;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.catalog.PartitionInfo;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.FeConstants;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.util.PlanRewriter;
import org.apache.doris.system.SystemInfoService;
import org.apache.doris.thrift.TStorageType;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import java.util.List;
public class CheckRowPolicyTest extends TestWithFeService {
private static String dbName = "check_row_policy";
private static String fullDbName = "default_cluster:" + dbName;
private static String tableName = "table1";
private static String userName = "user1";
private static String policyName = "policy1";
private static OlapTable olapTable = new OlapTable(0L, tableName,
ImmutableList.<Column>of(new Column("k1", Type.INT, false, AggregateType.NONE, "0", ""),
new Column("k2", Type.INT, false, AggregateType.NONE, "0", "")),
KeysType.PRIMARY_KEYS, new PartitionInfo(), null);
@Override
protected void runBeforeAll() throws Exception {
FeConstants.runningUnitTest = true;
createDatabase(dbName);
useDatabase(dbName);
createTable("create table "
+ tableName
+ " (k1 int, k2 int) distributed by hash(k1) buckets 1"
+ " properties(\"replication_num\" = \"1\");");
Database db = Env.getCurrentInternalCatalog().getDbOrMetaException(fullDbName);
long tableId = db.getTableOrMetaException("table1").getId();
olapTable.setId(tableId);
olapTable.setIndexMeta(-1,
olapTable.getName(),
olapTable.getFullSchema(),
0, 0, (short) 0,
TStorageType.COLUMN,
KeysType.PRIMARY_KEYS);
// create user
UserIdentity user = new UserIdentity(userName, "%");
user.analyze(SystemInfoService.DEFAULT_CLUSTER);
CreateUserStmt createUserStmt = new CreateUserStmt(new UserDesc(user));
Env.getCurrentEnv().getAuth().createUser(createUserStmt);
List<AccessPrivilege> privileges = Lists.newArrayList(AccessPrivilege.ADMIN_PRIV);
TablePattern tablePattern = new TablePattern("*", "*", "*");
tablePattern.analyze(SystemInfoService.DEFAULT_CLUSTER);
GrantStmt grantStmt = new GrantStmt(user, null, tablePattern, privileges);
Env.getCurrentEnv().getAuth().grant(grantStmt);
}
@Test
public void checkUser() throws AnalysisException, org.apache.doris.common.AnalysisException {
LogicalRelation relation = new LogicalOlapScan(new RelationId(0), olapTable, Arrays.asList(fullDbName));
LogicalCheckPolicy<LogicalRelation> checkPolicy = new LogicalCheckPolicy<>(relation);
useUser("root");
Plan plan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy());
Assertions.assertEquals(plan, relation);
useUser("notFound");
plan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy());
Assertions.assertEquals(plan, relation);
}
@Test
public void checkNoPolicy() throws org.apache.doris.common.AnalysisException {
useUser(userName);
LogicalRelation relation = new LogicalOlapScan(new RelationId(0), olapTable, Arrays.asList(fullDbName));
LogicalCheckPolicy<LogicalRelation> checkPolicy = new LogicalCheckPolicy<>(relation);
Plan plan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy());
Assertions.assertEquals(plan, relation);
}
@Test
public void checkOnePolicy() throws Exception {
useUser(userName);
LogicalRelation relation = new LogicalOlapScan(new RelationId(0), olapTable, Arrays.asList(fullDbName));
LogicalCheckPolicy<LogicalRelation> checkPolicy = new LogicalCheckPolicy<>(relation);
connectContext.getSessionVariable().setEnableNereidsPlanner(true);
createPolicy("CREATE ROW POLICY "
+ policyName
+ " ON "
+ tableName
+ " AS PERMISSIVE TO "
+ userName
+ " USING (k1 = 1)");
Plan plan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy());
Assertions.assertTrue(plan instanceof LogicalFilter);
LogicalFilter filter = (LogicalFilter) plan;
Assertions.assertEquals(filter.child(), relation);
Assertions.assertTrue(filter.getPredicates() instanceof EqualTo);
Assertions.assertTrue(filter.getPredicates().toString().contains("k1 = 1"));
dropPolicy("DROP ROW POLICY "
+ policyName
+ " ON "
+ tableName);
}
}

View File

@ -148,7 +148,9 @@ public class RegisterCTETest extends TestWithFeService implements PatternMatchSu
logicalSubQueryAlias(
logicalProject(
logicalFilter(
unboundRelation()
logicalCheckPolicy(
unboundRelation()
)
)
)
)

View File

@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("test_nereids_row_policy") {
def dbName = context.config.getDbNameByFile(context.file)
def tableName = "nereids_row_policy"
def viewName = "view_" + tableName
def user='row_policy_user'
def tokens = context.config.jdbcUrl.split('/')
def url=tokens[0] + "//" + tokens[2] + "/" + dbName + "?"
def assertQueryResult = { size ->
def result1 = connect(user=user, password='123456', url=url) {
sql "set enable_nereids_planner = false"
sql "SELECT * FROM ${tableName}"
}
def result2 = connect(user=user, password='123456', url=url) {
sql "set enable_nereids_planner = true"
sql "set enable_fallback_to_original_planner = false"
sql "SELECT * FROM ${tableName}"
}
def result3 = connect(user=user, password='123456', url=url) {
sql "set enable_nereids_planner = true"
sql "set enable_fallback_to_original_planner = false"
sql "SELECT * FROM ${viewName}"
}
assertEquals(size, result1.size())
assertEquals(size, result2.size())
assertEquals(size, result3.size())
}
def createPolicy = { name, predicate, type ->
sql """
CREATE ROW POLICY ${name} ON ${dbName}.${tableName}
AS ${type} TO ${user} USING (${predicate})
"""
}
def dropPolciy = { name ->
sql """
DROP ROW POLICY ${name} ON ${dbName}.${tableName} FOR ${user}
"""
}
// create table
sql "DROP TABLE IF EXISTS ${tableName}"
sql """
CREATE TABLE ${tableName} (
`k` INT,
`v` INT
) DUPLICATE KEY (`k`) DISTRIBUTED BY HASH (`k`) BUCKETS 1
PROPERTIES ('replication_num' = '1')
"""
sql """
insert into ${tableName} values (1,1), (2,1), (1,3);
"""
// create view
sql """
create view ${viewName} as select * from ${tableName};
"""
// create user
sql "DROP USER IF EXISTS ${user}"
sql "CREATE USER ${user} IDENTIFIED BY '123456'"
sql "GRANT SELECT_PRIV ON internal.${dbName}.${tableName} TO ${user}"
// no policy
assertQueryResult 3
// (k = 1)
createPolicy"policy0", "k = 1", "RESTRICTIVE"
assertQueryResult 2
// (k = 1 and v = 1)
createPolicy"policy1", "v = 1", "RESTRICTIVE"
assertQueryResult 1
// (v = 1)
dropPolciy "policy0"
assertQueryResult 2
// (v = 1) and (k = 1)
createPolicy"policy2", "k = 1", "PERMISSIVE"
assertQueryResult 1
// (v = 1) and (k = 1 or k = 2)
createPolicy"policy3", "k = 2", "PERMISSIVE"
assertQueryResult 2
}