package org.apache.flink.table.plan.rules.logical;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.plan.Contexts;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.AggregateExtractProjectRule;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.table.expressions.PlannerResolvedFieldReference;
import org.apache.flink.table.plan.logical.LogicalWindow;
import org.apache.flink.table.plan.logical.rel.LogicalTableAggregate;
import org.apache.flink.table.plan.logical.rel.LogicalWindowAggregate;
import org.apache.flink.table.plan.logical.rel.LogicalWindowTableAggregate;
import org.apache.flink.table.plan.logical.rel.TableAggregate;

/* loaded from: input_file:org/apache/flink/table/plan/rules/logical/ExtendedAggregateExtractProjectRule.class */
/**
 * Rule that puts a {@link Project} beneath an aggregate so that the aggregate's input carries only
 * the fields the aggregate actually reads, then rebuilds the aggregate on top of that projection
 * with its field references remapped.
 *
 * <p>Besides {@link LogicalAggregate}, this variant also handles Flink's
 * {@link LogicalWindowAggregate} (whose window time attribute is kept in the projection) and
 * {@link TableAggregate} nodes, which are rewritten through their corresponding {@link Aggregate}
 * and then re-wrapped as {@link LogicalTableAggregate} or {@link LogicalWindowTableAggregate}.
 */
public class ExtendedAggregateExtractProjectRule extends AggregateExtractProjectRule {

    /**
     * Shared rule instance. It matches a {@link SingleRel} with one input of any type;
     * {@link #matches} then restricts the top node to the supported aggregate kinds. Its
     * {@link RelBuilder} is configured with {@code withPruneInputOfAggregate(false)}.
     */
    public static final ExtendedAggregateExtractProjectRule INSTANCE =
            new ExtendedAggregateExtractProjectRule(
                    operand(SingleRel.class, operand(RelNode.class, any())),
                    RelBuilder.proto(
                            Contexts.of(
                                    RelFactories.DEFAULT_STRUCT,
                                    RelBuilder.Config.DEFAULT.withPruneInputOfAggregate(false))));

    /**
     * Creates the rule.
     *
     * @param relOptRuleOperand root operand describing the matched tree
     * @param relBuilderFactory factory for the {@link RelBuilder} used to build replacements
     */
    public ExtendedAggregateExtractProjectRule(RelOptRuleOperand relOptRuleOperand, RelBuilderFactory relBuilderFactory) {
        super(relOptRuleOperand, relBuilderFactory);
    }

    /**
     * Accepts the match only when the top node is a {@link LogicalWindowAggregate},
     * {@link LogicalAggregate} or {@link TableAggregate}.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final SingleRel top = (SingleRel) call.rel(0);
        return top instanceof LogicalWindowAggregate
                || top instanceof LogicalAggregate
                || top instanceof TableAggregate;
    }

    /**
     * Pushes the aggregate's input onto a fresh builder, rewrites the aggregate (or table
     * aggregate) on top of a pruning projection and registers the result with the planner.
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final RelNode top = call.rel(0);
        final RelNode input = call.rel(1);
        final RelBuilder relBuilder = call.builder().push(input);
        if (top instanceof Aggregate) {
            call.transformTo(performExtractForAggregate((Aggregate) top, input, relBuilder));
        } else if (top instanceof TableAggregate) {
            call.transformTo(performExtractForTableAggregate((TableAggregate) top, input, relBuilder));
        }
    }

    /** Adds the pruning projection on top of {@code relBuilder} and rebuilds {@code aggregate} over it. */
    private RelNode performExtractForAggregate(Aggregate aggregate, RelNode input, RelBuilder relBuilder) {
        final Mapping mapping = extractProjectsAndMapping(aggregate, input, relBuilder);
        return getNewAggregate(aggregate, relBuilder, mapping);
    }

    /**
     * Rewrites the aggregate that backs {@code tableAggregate} and wraps the result back into the
     * same kind of table aggregate.
     */
    private RelNode performExtractForTableAggregate(TableAggregate tableAggregate, RelNode input, RelBuilder relBuilder) {
        final RelNode newAggregate =
                performExtractForAggregate(tableAggregate.getCorrespondingAggregate(), input, relBuilder);
        if (tableAggregate instanceof LogicalTableAggregate) {
            return LogicalTableAggregate.create((Aggregate) newAggregate);
        }
        return LogicalWindowTableAggregate.create((LogicalWindowAggregate) newAggregate);
    }

    /**
     * Pushes a {@link Project} of the input fields used by {@code aggregate} onto
     * {@code relBuilder}.
     *
     * @return mapping from original input field index to its position in the new projection
     */
    private Mapping extractProjectsAndMapping(Aggregate aggregate, RelNode input, RelBuilder relBuilder) {
        final ImmutableBitSet usedFields = getInputFieldUsed(aggregate, input).build();
        final Mapping mapping = Mappings.create(
                MappingType.INVERSE_SURJECTION,
                aggregate.getInput().getRowType().getFieldCount(),
                usedFields.cardinality());
        // Used fields are projected in ascending index order, so the k-th used field maps to k.
        int target = 0;
        for (int source : usedFields) {
            mapping.set(source, target++);
        }

        final List<Integer> projectedOrdinals = usedFields.toList();
        if (input instanceof Project) {
            relBuilder.project(relBuilder.fields(projectedOrdinals));
        } else {
            // force = true: create the Project even if it is an identity projection.
            relBuilder.project(relBuilder.fields(projectedOrdinals), Collections.emptyList(), true);
        }
        return mapping;
    }

    /**
     * Collects the input field indices referenced by {@code aggregate}: group keys, aggregate-call
     * arguments, filter arguments and, for window aggregates, the window time attribute.
     */
    private ImmutableBitSet.Builder getInputFieldUsed(Aggregate aggregate, RelNode input) {
        final ImmutableBitSet.Builder used = aggregate.getGroupSet().rebuild();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            for (int arg : aggCall.getArgList()) {
                used.set(arg);
            }
            // A negative filterArg is treated as "no FILTER clause" throughout this rule.
            if (aggCall.filterArg >= 0) {
                used.set(aggCall.filterArg);
            }
        }
        if (aggregate instanceof LogicalWindowAggregate) {
            final LogicalWindow window = ((LogicalWindowAggregate) aggregate).getWindow();
            used.set(getWindowTimeFieldIndex(window, input));
        }
        return used;
    }

    /**
     * Builds the replacement aggregate on top of the projection currently on {@code relBuilder},
     * remapping group sets and aggregate calls through {@code mapping}.
     */
    private RelNode getNewAggregate(Aggregate aggregate, RelBuilder relBuilder, Mapping mapping) {
        final ImmutableBitSet newGroupSet = Mappings.apply(mapping, aggregate.getGroupSet());
        final List<ImmutableBitSet> newGroupSets = aggregate.getGroupSets().stream()
                .map(groupSet -> Mappings.apply(mapping, groupSet))
                .collect(Collectors.toList());
        final List<RelBuilder.AggCall> newAggCalls = getNewAggCallList(aggregate, relBuilder, mapping);
        final RelBuilder.GroupKey groupKey = relBuilder.groupKey(newGroupSet, newGroupSets);

        if (!(aggregate instanceof LogicalWindowAggregate)) {
            relBuilder.aggregate(groupKey, newAggCalls);
            return relBuilder.build();
        }
        if (newGroupSet.isEmpty() && newAggCalls.isEmpty()) {
            // Keep the original window aggregate. NOTE(review): presumably because the builder
            // would not yield an Aggregate here (the cast below would fail) — confirm. The
            // projection pushed earlier stays unused on this builder.
            return aggregate;
        }
        relBuilder.aggregate(groupKey, newAggCalls);
        final LogicalWindowAggregate windowAggregate = (LogicalWindowAggregate) aggregate;
        return LogicalWindowAggregate.create(
                windowAggregate.getWindow(),
                windowAggregate.getNamedProperties(),
                (Aggregate) relBuilder.build());
    }

    /**
     * Looks up the index of the window's time attribute among {@code input}'s field names.
     *
     * @throws IllegalStateException if the time attribute is not a field of {@code input}
     *     (previously surfaced as an opaque ArrayIndexOutOfBoundsException from a -1 index)
     */
    private int getWindowTimeFieldIndex(LogicalWindow window, RelNode input) {
        final String timeFieldName = ((PlannerResolvedFieldReference) window.timeAttribute()).name();
        final List<String> fieldNames = input.getRowType().getFieldNames();
        final int index = fieldNames.indexOf(timeFieldName);
        if (index < 0) {
            throw new IllegalStateException(
                    "Window time attribute '" + timeFieldName + "' not found in input fields " + fieldNames + ".");
        }
        return index;
    }

    /**
     * Re-creates every aggregate call of {@code aggregate} with argument, filter and collation
     * references resolved against the new projection.
     */
    private List<RelBuilder.AggCall> getNewAggCallList(Aggregate aggregate, RelBuilder relBuilder, Mapping mapping) {
        final List<RelBuilder.AggCall> newAggCalls = new ArrayList<>(aggregate.getAggCallList().size());
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            final List<Integer> newArgs = Mappings.apply2(mapping, aggCall.getArgList());
            newAggCalls.add(
                    relBuilder.aggregateCall(aggCall.getAggregation(), relBuilder.fields(newArgs))
                            .distinct(aggCall.isDistinct())
                            .filter(aggCall.filterArg < 0
                                    ? null
                                    : relBuilder.field(Mappings.apply(mapping, aggCall.filterArg)))
                            .approximate(aggCall.isApproximate())
                            .sort(relBuilder.fields(aggCall.collation))
                            .as(aggCall.name));
        }
        return newAggCalls;
    }
}
