package org.shaded.apache.hadoop.hive.ql.parse.spark;

import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import org.shaded.apache.hadoop.hive.ql.exec.Operator;
import org.shaded.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.shaded.apache.hadoop.hive.ql.exec.Utilities;
import org.shaded.apache.hadoop.hive.ql.exec.spark.SparkUtilities;
import org.shaded.apache.hadoop.hive.ql.lib.Node;
import org.shaded.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.shaded.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.shaded.apache.hadoop.hive.ql.parse.SemanticException;
import org.shaded.apache.hadoop.hive.ql.plan.OperatorDesc;

/* loaded from: input_file:org/shaded/apache/hadoop/hive/ql/parse/spark/SplitOpTreeForDPP.class */
/**
 * Splits a dynamic-partition-pruning branch off the operator tree it hangs from.
 *
 * <p>Starting at a {@link SparkPartitionPruningSinkOperator}, {@link #process} walks up through
 * single-child ancestors until it reaches the first operator with more than one child (the
 * branching operator). If the sink reports {@code isWithMapjoin()}, it is registered unchanged.
 * Otherwise the process has three steps:
 * <ol>
 *   <li>The tree above the sink is cloned while only the pruning branch is attached to the
 *       branching operator.</li>
 *   <li>The pruning branch is removed from the original tree.</li>
 *   <li>The cloned sink is registered in the {@link GenSparkProcContext}.</li>
 * </ol>
 */
public class SplitOpTreeForDPP implements NodeProcessor {

    /**
     * Splits the pruning branch that ends at {@code node}.
     *
     * @param node the {@link SparkPartitionPruningSinkOperator} being visited
     * @param stack the walker's node stack (not used here)
     * @param nodeProcessorCtx the {@link GenSparkProcContext} that collects the results
     * @param objArr outputs of previously processed nodes (not used here)
     * @return always {@code null}
     * @throws SemanticException if no ancestor of the sink has more than one child
     * @throws IllegalArgumentException if the cloned tree does not contain exactly one
     *     {@link SparkPartitionPruningSinkOperator}
     */
    @Override
    public Object process(Node node, Stack<Node> stack, NodeProcessorCtx nodeProcessorCtx, Object... objArr) throws SemanticException {
        SparkPartitionPruningSinkOperator pruningSink = (SparkPartitionPruningSinkOperator) node;
        GenSparkProcContext context = (GenSparkProcContext) nodeProcessorCtx;

        // Sinks flagged as map-join sinks are registered as-is and the tree is not split.
        // Checked first: nothing computed below is needed on this path.
        if (pruningSink.isWithMapjoin()) {
            context.pruningSinkSet.add(pruningSink);
            return null;
        }

        // Walk up (always via the first parent) to the first ancestor with more than one child.
        // branchChild ends up as the operator directly below it on the sink's path.
        Operator<? extends OperatorDesc> branchingOp = pruningSink;
        Operator<? extends OperatorDesc> branchChild = null;
        while (branchingOp.getNumChild() <= 1) {
            if (branchingOp.getNumParent() == 0) {
                // Reached a root without finding a branch point: fail with a clear message
                // instead of an index/null-pointer error from getParentOperators().get(0).
                throw new SemanticException(
                        "Could not find a branching operator above pruning sink " + pruningSink);
            }
            branchChild = branchingOp;
            branchingOp = branchingOp.getParentOperators().get(0);
        }

        // ArrayList: the roots are accessed by index below.
        List<Operator<?>> roots = new ArrayList<>();
        collectRoots(roots, pruningSink);
        // NOTE(review): the result of this call is discarded. It is kept only to avoid changing
        // behavior in case it has side effects; remove it if it is a pure getter.
        pruningSink.getBranchingOp();

        // Temporarily attach only the pruning branch under branchingOp, so the clone taken from
        // the roots does not pick up the sibling branches (assumes cloneOperatorTree follows
        // child links). The original children are restored even if cloning fails, so the
        // caller's tree is never left truncated.
        List<Operator<? extends OperatorDesc>> savedChildren = branchingOp.getChildOperators();
        branchingOp.setChildOperators(Utilities.makeList(branchChild));
        List<Operator<?>> clonedRoots;
        try {
            clonedRoots = Utilities.cloneOperatorTree(context.parseContext.getConf(), roots);
        } finally {
            branchingOp.setChildOperators(savedChildren);
        }

        // Carry table metadata from each original root over to its clone (same index).
        for (int i = 0; i < roots.size(); i++) {
            TableScanOperator clonedScan = (TableScanOperator) clonedRoots.get(i);
            TableScanOperator originalScan = (TableScanOperator) roots.get(i);
            clonedScan.getConf().setTableMetadata(originalScan.getConf().getTableMetadata());
        }
        context.clonedPruningTableScanSet.addAll(clonedRoots);

        // Detach the pruning branch from the original tree; it now lives only in the clone.
        branchingOp.removeChild(branchChild);

        SparkPartitionPruningSinkOperator clonedSink = findClonedSink(clonedRoots);
        clonedSink.getConf().setTableScan(pruningSink.getConf().getTableScan());
        context.pruningSinkSet.add(clonedSink);
        return null;
    }

    /**
     * Returns the single {@link SparkPartitionPruningSinkOperator} reachable from
     * {@code clonedRoots}.
     *
     * @throws IllegalArgumentException if zero or several such operators are found
     */
    private static SparkPartitionPruningSinkOperator findClonedSink(List<Operator<?>> clonedRoots) {
        Set<Operator<?>> sinks = new HashSet<>();
        for (Operator<?> root : clonedRoots) {
            SparkUtilities.collectOp(sinks, root, SparkPartitionPruningSinkOperator.class);
        }
        // Template form: the message is only built when the check fails.
        Preconditions.checkArgument(sinks.size() == 1,
                "AssertionError: expected to only contain one SparkPartitionPruningSinkOperator, but found %s",
                sinks.size());
        return (SparkPartitionPruningSinkOperator) sinks.iterator().next();
    }

    /**
     * Appends to {@code roots} every parentless operator reached by walking up from
     * {@code operator}. The walk is depth-first and visits parents in list order. A root
     * reachable through several paths is appended once per path.
     */
    private void collectRoots(List<Operator<?>> roots, Operator<?> operator) {
        if (operator.getNumParent() == 0) {
            roots.add(operator);
            return;
        }
        for (Operator<?> parent : operator.getParentOperators()) {
            collectRoots(roots, parent);
        }
    }
}
