/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.chat.prompt;

import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.Charset;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.antlr.runtime.Token;
import org.antlr.runtime.TokenStream;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplateActions;
import org.springframework.ai.chat.prompt.PromptTemplateMessageActions;
import org.springframework.ai.chat.prompt.TemplateFormat;
import org.springframework.ai.model.Media;
import org.springframework.core.io.Resource;
import org.springframework.util.StreamUtils;
import org.stringtemplate.v4.ST;

public class PromptTemplate
implements PromptTemplateActions,
PromptTemplateMessageActions {
    protected String template;
    protected TemplateFormat templateFormat = TemplateFormat.ST;
    private ST st;
    private Map<String, Object> dynamicModel = new HashMap<String, Object>();

    public PromptTemplate(Resource resource) {
        try (InputStream inputStream = resource.getInputStream();){
            this.template = StreamUtils.copyToString((InputStream)inputStream, (Charset)Charset.defaultCharset());
        }
        catch (IOException ex) {
            throw new RuntimeException("Failed to read resource", ex);
        }
        try {
            this.st = new ST(this.template, '{', '}');
        }
        catch (Exception ex) {
            throw new IllegalArgumentException("The template string is not valid.", ex);
        }
    }

    public PromptTemplate(String template) {
        this.template = template;
        try {
            this.st = new ST(this.template, '{', '}');
        }
        catch (Exception ex) {
            throw new IllegalArgumentException("The template string is not valid.", ex);
        }
    }

    public PromptTemplate(String template, Map<String, Object> model) {
        this.template = template;
        try {
            this.st = new ST(this.template, '{', '}');
            for (Map.Entry<String, Object> entry : model.entrySet()) {
                this.add(entry.getKey(), entry.getValue());
            }
        }
        catch (Exception ex) {
            throw new IllegalArgumentException("The template string is not valid.", ex);
        }
    }

    public PromptTemplate(Resource resource, Map<String, Object> model) {
        try (InputStream inputStream = resource.getInputStream();){
            this.template = StreamUtils.copyToString((InputStream)inputStream, (Charset)Charset.defaultCharset());
        }
        catch (IOException ex) {
            throw new RuntimeException("Failed to read resource", ex);
        }
        try {
            this.st = new ST(this.template, '{', '}');
            for (Map.Entry<String, Object> entry : model.entrySet()) {
                this.add(entry.getKey(), entry.getValue());
            }
        }
        catch (Exception ex) {
            throw new IllegalArgumentException("The template string is not valid.", ex);
        }
    }

    public void add(String name, Object value) {
        this.st.add(name, value);
        this.dynamicModel.put(name, value);
    }

    public String getTemplate() {
        return this.template;
    }

    public TemplateFormat getTemplateFormat() {
        return this.templateFormat;
    }

    @Override
    public String render() {
        this.validate(this.dynamicModel);
        return this.st.render();
    }

    @Override
    public String render(Map<String, Object> model) {
        this.validate(model);
        for (Map.Entry<String, Object> entry : model.entrySet()) {
            if (this.st.getAttribute(entry.getKey()) != null) {
                this.st.remove(entry.getKey());
            }
            if (entry.getValue() instanceof Resource) {
                this.st.add(entry.getKey(), (Object)this.renderResource((Resource)entry.getValue()));
                continue;
            }
            this.st.add(entry.getKey(), entry.getValue());
        }
        return this.st.render();
    }

    private String renderResource(Resource resource) {
        try {
            return resource.getContentAsString(Charset.defaultCharset());
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Message createMessage() {
        return new UserMessage(this.render());
    }

    @Override
    public Message createMessage(List<Media> mediaList) {
        return new UserMessage(this.render(), mediaList);
    }

    @Override
    public Message createMessage(Map<String, Object> model) {
        return new UserMessage(this.render(model));
    }

    @Override
    public Prompt create() {
        return new Prompt(this.render(new HashMap<String, Object>()));
    }

    @Override
    public Prompt create(ChatOptions modelOptions) {
        return new Prompt(this.render(new HashMap<String, Object>()), modelOptions);
    }

    @Override
    public Prompt create(Map<String, Object> model) {
        return new Prompt(this.render(model));
    }

    @Override
    public Prompt create(Map<String, Object> model, ChatOptions modelOptions) {
        return new Prompt(this.render(model), modelOptions);
    }

    public Set<String> getInputVariables() {
        TokenStream tokens = this.st.impl.tokens;
        HashSet<String> inputVariables = new HashSet<String>();
        boolean isInsideList = false;
        for (int i = 0; i < tokens.size(); ++i) {
            Token token = tokens.get(i);
            if (token.getType() == 23 && i + 1 < tokens.size() && tokens.get(i + 1).getType() == 25) {
                if (i + 2 >= tokens.size() || tokens.get(i + 2).getType() != 13) continue;
                inputVariables.add(tokens.get(i + 1).getText());
                isInsideList = true;
                continue;
            }
            if (token.getType() == 24) {
                isInsideList = false;
                continue;
            }
            if (isInsideList || token.getType() != 25) continue;
            inputVariables.add(token.getText());
        }
        return inputVariables;
    }

    private Set<String> getModelKeys(Map<String, Object> model) {
        HashSet<String> dynamicVariableNames = new HashSet<String>(this.dynamicModel.keySet());
        HashSet<String> modelVariables = new HashSet<String>(model.keySet());
        modelVariables.addAll(dynamicVariableNames);
        return modelVariables;
    }

    protected void validate(Map<String, Object> model) {
        Set<String> templateTokens = this.getInputVariables();
        Set<String> modelKeys = this.getModelKeys(model);
        if (!modelKeys.containsAll(templateTokens)) {
            templateTokens.removeAll(modelKeys);
            throw new IllegalStateException("Not all template variables were replaced. Missing variable names are " + String.valueOf(templateTokens));
        }
    }
}

