/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.filter;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;
import org.antlr.v4.runtime.ANTLRErrorListener;
import org.antlr.v4.runtime.ANTLRErrorStrategy;
import org.antlr.v4.runtime.BailErrorStrategy;
import org.antlr.v4.runtime.BaseErrorListener;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.antlr.v4.runtime.RecognitionException;
import org.antlr.v4.runtime.Recognizer;
import org.antlr.v4.runtime.TokenSource;
import org.antlr.v4.runtime.TokenStream;
import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.antlr.v4.runtime.tree.ParseTree;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersBaseVisitor;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersLexer;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser;
import org.springframework.core.NestedExceptionUtils;
import org.springframework.util.Assert;

/**
 * Parses textual filter expressions (for example {@code year >= 2020}) into
 * {@link Filter.Expression} trees using the ANTLR-generated {@link FiltersParser}.
 *
 * <p>Successfully parsed expressions are memoized in an unbounded in-memory cache keyed by the
 * (WHERE-prefixed) expression text; call {@link #clearCache()} when many distinct expressions are
 * parsed. Each call to {@link #parse(String)} collects syntax errors with its own listener, so
 * error messages from concurrent parses do not interfere with each other.
 */
public class FilterExpressionTextParser {

    private static final String WHERE_PREFIX = "WHERE";

    /** Strategy installed on every ANTLR parser created by {@link #parse(String)}; may be null. */
    private final ANTLRErrorStrategy errorHandler;

    /** Parsed expressions keyed by the WHERE-prefixed input text. No eviction. */
    private final Map<String, Filter.Expression> cache = new ConcurrentHashMap<String, Filter.Expression>();

    /**
     * Creates a parser that aborts on the first syntax error using a {@link BailErrorStrategy}.
     */
    public FilterExpressionTextParser() {
        this(new BailErrorStrategy());
    }

    /**
     * Creates a parser that installs the given ANTLR error strategy.
     *
     * @param handler strategy set on every underlying ANTLR parser; when {@code null} the
     *     generated parser keeps its own default strategy. The same instance is reused (and reset)
     *     for every call to {@link #parse(String)}, so a stateful strategy shared across threads
     *     may still interfere with concurrent parses. Only {@link ParseCancellationException}s
     *     thrown by the strategy are translated into {@link FilterExpressionParseException}.
     */
    public FilterExpressionTextParser(ANTLRErrorStrategy handler) {
        this.errorHandler = handler;
    }

    /**
     * Parses a text filter expression into a {@link Filter.Expression}.
     *
     * <p>The leading {@code WHERE} keyword is optional: it is prepended when the text does not
     * already start with it as a whole word, because the grammar's entry rule ({@code where})
     * expects it.
     *
     * @param textFilterExpression the expression text; must contain non-whitespace characters
     * @return the parsed (possibly cached) expression
     * @throws IllegalArgumentException if {@code textFilterExpression} is null or blank
     * @throws FilterExpressionParseException if the error strategy cancels the parse, e.g. on the
     *     first syntax error with the default {@link BailErrorStrategy}
     */
    public Filter.Expression parse(String textFilterExpression) {
        Assert.hasText(textFilterExpression, "Expression should not be empty!");
        if (!startsWithWhereKeyword(textFilterExpression)) {
            textFilterExpression = String.format("%s %s", WHERE_PREFIX, textFilterExpression);
        }
        // Single lookup: containsKey() followed by get() could race with clearCache() and yield null.
        Filter.Expression cached = this.cache.get(textFilterExpression);
        if (cached != null) {
            return cached;
        }
        FiltersLexer lexer = new FiltersLexer(CharStreams.fromString(textFilterExpression));
        FiltersParser parser = new FiltersParser(new CommonTokenStream(lexer));
        // A fresh listener per call: a shared listener's message list would be cleared and
        // appended to by concurrent parses.
        DescriptiveErrorListener errorListener = new DescriptiveErrorListener();
        parser.removeErrorListeners();
        parser.addErrorListener(errorListener);
        if (this.errorHandler != null) {
            parser.setErrorHandler(this.errorHandler);
            // The strategy instance is reused across parses; drop any error-recovery state left
            // over from a previous (failed) parse.
            this.errorHandler.reset(parser);
        }
        FilterExpressionVisitor visitor = new FilterExpressionVisitor();
        try {
            Filter.Operand operand = visitor.visit(parser.where());
            Filter.Expression filterExpression = visitor.castToExpression(operand);
            this.cache.putIfAbsent(textFilterExpression, filterExpression);
            return filterExpression;
        }
        catch (ParseCancellationException e) {
            Throwable rootCause = NestedExceptionUtils.getRootCause(e);
            throw new FilterExpressionParseException(buildErrorMessage(errorListener, textFilterExpression), rootCause);
        }
    }

    /**
     * Returns whether {@code text} begins with the {@code WHERE} keyword (case-insensitive) as a
     * whole word. A bare prefix match is not enough: in {@code whereabouts == 1} the leading
     * letters belong to an identifier, not to the keyword.
     */
    private static boolean startsWithWhereKeyword(String text) {
        int keywordLength = WHERE_PREFIX.length();
        if (!text.regionMatches(true, 0, WHERE_PREFIX, 0, keywordLength)) {
            return false;
        }
        if (text.length() == keywordLength) {
            return true;
        }
        char next = text.charAt(keywordLength);
        // A following letter/digit/underscore/dot would presumably be lexed as part of an
        // identifier rather than as the keyword, so the prefix must still be added.
        return !(Character.isLetterOrDigit(next) || next == '_' || next == '.');
    }

    /**
     * Builds the message for a cancelled parse from the collected syntax errors, falling back to
     * a generic message when none were reported (some bail-outs cancel without notifying
     * listeners).
     */
    private static String buildErrorMessage(DescriptiveErrorListener listener, String expression) {
        if (listener.errorMessages.isEmpty()) {
            return "Failed to parse filter expression: " + expression;
        }
        return String.join("; ", listener.errorMessages);
    }

    /**
     * Removes all memoized expressions.
     */
    public void clearCache() {
        this.cache.clear();
    }

    /**
     * Exposes the live cache map (package-private; presumably for tests).
     */
    Map<String, Filter.Expression> getCache() {
        return this.cache;
    }

    /**
     * ANTLR error listener that records every reported syntax error as a human-readable message.
     */
    public static class DescriptiveErrorListener
    extends BaseErrorListener {

        /**
         * Shared instance kept for backward compatibility. Its message list is shared by all users,
         * so it is not suitable for concurrent parses; {@link FilterExpressionTextParser} creates a
         * new listener for every parse instead.
         */
        public static final DescriptiveErrorListener INSTANCE = new DescriptiveErrorListener();

        /** Messages collected so far, in the order they were reported. */
        public final List<String> errorMessages = new CopyOnWriteArrayList<String>();

        @Override
        public void syntaxError(Recognizer<?, ?> recognizer, Object offendingSymbol, int line, int charPositionInLine, String msg, RecognitionException e) {
            String sourceName = recognizer.getInputStream().getSourceName();
            String errorMessage = String.format("Source: %s, Line: %s:%s, Error: %s", sourceName, line, charPositionInLine, msg);
            this.errorMessages.add(errorMessage);
        }
    }

    /**
     * Parse-tree visitor that converts a {@link FiltersParser} tree into {@link Filter} operands:
     * identifiers become {@link Filter.Key}, constants become {@link Filter.Value}, and boolean,
     * comparison and membership nodes become {@link Filter.Expression} or {@link Filter.Group}.
     */
    public static class FilterExpressionVisitor
    extends FiltersBaseVisitor<Filter.Operand> {

        /** Maps comparison operator text to its expression type. */
        private static final Map<String, Filter.ExpressionType> COMP_EXPRESSION_TYPE_MAP = Map.of(
                "==", Filter.ExpressionType.EQ,
                "!=", Filter.ExpressionType.NE,
                ">", Filter.ExpressionType.GT,
                ">=", Filter.ExpressionType.GTE,
                "<", Filter.ExpressionType.LT,
                "<=", Filter.ExpressionType.LTE);

        @Override
        public Filter.Operand visitWhere(FiltersParser.WhereContext ctx) {
            return this.visit(ctx.booleanExpression());
        }

        @Override
        public Filter.Operand visitIdentifier(FiltersParser.IdentifierContext ctx) {
            return new Filter.Key(ctx.getText());
        }

        @Override
        public Filter.Operand visitTextConstant(FiltersParser.TextConstantContext ctx) {
            String onceQuotedText = this.removeOuterQuotes(ctx.getText());
            return new Filter.Value(onceQuotedText);
        }

        /** Drops the first and last character (the surrounding quotes); inner text is kept verbatim. */
        private String removeOuterQuotes(String in) {
            return in.substring(1, in.length() - 1);
        }

        /**
         * Converts an integer literal to an {@link Integer}, widening to {@link Long} when the
         * literal does not fit in an {@code int} instead of failing with a raw
         * {@link NumberFormatException}.
         */
        @Override
        public Filter.Operand visitIntegerConstant(FiltersParser.IntegerConstantContext ctx) {
            String text = ctx.getText();
            try {
                return new Filter.Value(Integer.valueOf(text));
            }
            catch (NumberFormatException ex) {
                return new Filter.Value(Long.valueOf(text));
            }
        }

        @Override
        public Filter.Operand visitDecimalConstant(FiltersParser.DecimalConstantContext ctx) {
            return new Filter.Value(Double.valueOf(ctx.getText()));
        }

        @Override
        public Filter.Operand visitBooleanConstant(FiltersParser.BooleanConstantContext ctx) {
            return new Filter.Value(Boolean.valueOf(ctx.getText()));
        }

        /** Builds a {@link Filter.Value} holding a list of the array's unwrapped constant values. */
        @Override
        public Filter.Operand visitConstantArray(FiltersParser.ConstantArrayContext ctx) {
            List<Object> values = new ArrayList<Object>();
            ctx.constant().forEach(constantCtx -> values.add(((Filter.Value) this.visit(constantCtx)).value()));
            return new Filter.Value(values);
        }

        @Override
        public Filter.Operand visitInExpression(FiltersParser.InExpressionContext ctx) {
            return new Filter.Expression(Filter.ExpressionType.IN, this.visitIdentifier(ctx.identifier()), this.visitConstantArray(ctx.constantArray()));
        }

        @Override
        public Filter.Operand visitNinExpression(FiltersParser.NinExpressionContext ctx) {
            return new Filter.Expression(Filter.ExpressionType.NIN, this.visitIdentifier(ctx.identifier()), this.visitConstantArray(ctx.constantArray()));
        }

        @Override
        public Filter.Operand visitCompareExpression(FiltersParser.CompareExpressionContext ctx) {
            return new Filter.Expression(this.convertCompare(ctx.compare().getText()), this.visitIdentifier(ctx.identifier()), this.visit(ctx.constant()));
        }

        /**
         * Looks up the expression type for a comparison operator.
         *
         * @throws IllegalArgumentException if the operator is not one of the supported comparisons
         */
        private Filter.ExpressionType convertCompare(String compare) {
            Filter.ExpressionType type = COMP_EXPRESSION_TYPE_MAP.get(compare);
            if (type == null) {
                throw new IllegalArgumentException("Unknown compare operator: " + compare);
            }
            return type;
        }

        @Override
        public Filter.Operand visitAndExpression(FiltersParser.AndExpressionContext ctx) {
            return new Filter.Expression(Filter.ExpressionType.AND, this.visit(ctx.left), this.visit(ctx.right));
        }

        @Override
        public Filter.Operand visitOrExpression(FiltersParser.OrExpressionContext ctx) {
            return new Filter.Expression(Filter.ExpressionType.OR, this.visit(ctx.left), this.visit(ctx.right));
        }

        @Override
        public Filter.Operand visitGroupExpression(FiltersParser.GroupExpressionContext ctx) {
            return new Filter.Group(this.castToExpression(this.visit(ctx.booleanExpression())));
        }

        /** NOT is represented as a unary expression: the operand on the left, {@code null} on the right. */
        @Override
        public Filter.Operand visitNotExpression(FiltersParser.NotExpressionContext ctx) {
            return new Filter.Expression(Filter.ExpressionType.NOT, this.visit(ctx.booleanExpression()), null);
        }

        /**
         * Unwraps an operand into an expression: a {@link Filter.Group} yields its content and a
         * {@link Filter.Expression} is returned as-is.
         *
         * @throws IllegalArgumentException for any other operand (including {@code null})
         */
        public Filter.Expression castToExpression(Filter.Operand expression) {
            if (expression instanceof Filter.Group) {
                return ((Filter.Group) expression).content();
            }
            if (expression instanceof Filter.Expression) {
                return (Filter.Expression) expression;
            }
            throw new IllegalArgumentException("Invalid expression: " + String.valueOf(expression));
        }
    }

    /**
     * Thrown when a filter expression cannot be parsed; the message carries the collected syntax
     * errors and the cause is the root cause of the cancelled parse.
     */
    public static class FilterExpressionParseException
    extends RuntimeException {
        public FilterExpressionParseException(String message, Throwable cause) {
            super(message, cause);
        }
    }
}

