"""Contains the functionality to answer a question using generative
models."""
import streamlit as st
from dataclasses import dataclass
from typing import List, Dict, Union
import langchain
from langchain import llms
from langchain.schema import Document
@dataclass
class LLMSpec:
    """Class to store the information of a language model.

    Attributes:
        model_name (str): The name of the language model.
        model_type (str): The class or method used to load the language model
            (e.g. "openai-chat" or "huggingface-pipeline").
        max_tokens (int): The token limit of the model. For the registered
            models the values match their context-window sizes.
    """
    model_name: str
    model_type: str
    max_tokens: int
# Registry of the language models that can be used to generate answers.
GENERATIVE_MODELS = [
    LLMSpec(model_name="gpt2",
            model_type="huggingface-pipeline",
            max_tokens=1024),
    LLMSpec(model_name="gpt-3.5-turbo",
            model_type="openai-chat",
            max_tokens=4096),
    LLMSpec(model_name="gpt-3.5-turbo-16k",
            model_type="openai-chat",
            max_tokens=16384),
    LLMSpec(model_name="gpt-4",
            model_type="openai-chat",
            max_tokens=8192),
]
# Names of the registered models, in registry order.
GENERATIVE_MODEL_NAMES = [spec.model_name for spec in GENERATIVE_MODELS]
def get_model_spec(model_name: str) -> LLMSpec:
    """Returns the language model specification.

    Args:
        model_name (str): The name of the language model.

    Returns:
        LLMSpec: The first entry of ``GENERATIVE_MODELS`` whose ``model_name``
            equals ``model_name``.

    Raises:
        ValueError: If the language model is not available.
    """
    for model_spec in GENERATIVE_MODELS:
        if model_spec.model_name == model_name:
            return model_spec
    # List only the names: the full repr of every LLMSpec would be noisy in a
    # user-facing error message.
    available_models = [spec.model_name for spec in GENERATIVE_MODELS]
    raise ValueError(f"Model '{model_name}' not available. Available "
                     f"models are: {available_models}")
def load_model(model_name: str,
               temperature: float = 0.7,
               max_length: int = 1024,
               ) -> llms.base.BaseLLM:
    """Loads the language model.

    Args:
        model_name (str): The language model name. Must be the name of one of
            the models in ``GENERATIVE_MODELS``.
        temperature (float, optional): The temperature used to generate the
            answer. The higher the temperature, the more "creative" the answer
            will be. Defaults to 0.7.
        max_length (int, optional): The maximum length of the generated answer.
            Defaults to 1024.

    Returns:
        llms.base.BaseLLM: The language model.

    Raises:
        ValueError: If the model name is not registered, or if its model type
            has no loader.
    """
    model_spec = get_model_spec(model_name)
    if model_spec.model_type == "openai-chat":
        # The OpenAI chat API caps the answer length with ``max_tokens``; it
        # does not accept a ``max_length`` request argument.
        return llms.OpenAIChat(  # type: ignore
            model_name=model_spec.model_name,
            model_kwargs={"temperature": temperature,
                          "max_tokens": max_length}
        )
    if model_spec.model_type == "huggingface-pipeline":
        return llms.HuggingFacePipeline.from_model_id(  # type: ignore
            model_id=model_spec.model_name,
            task="text-generation",
            model_kwargs={"temperature": temperature,
                          "max_length": max_length}
        )
    available_models = [model.model_name for model in GENERATIVE_MODELS]
    raise ValueError(f"Model type '{model_spec.model_type}' not available. "
                     f"Available models are: {available_models}")
def _escape_template_braces(text: str) -> str:
    """Doubles ``{`` and ``}`` so ``text`` is literal inside an f-string template."""
    return text.replace("{", "{{").replace("}", "}}")


def _get_generative_prompt_template(retrieved_documents: List[Document],
                                    ) -> langchain.PromptTemplate:
    """Returns the template used to generate the answer.

    The page content of each retrieved document is placed before the
    question, in reverse order of ``retrieved_documents``, so the first
    document ends up closest to the question. Each document is followed by a
    blank line. The template ends with a ``Question: {question}`` line and an
    ``Answer:`` cue.

    Args:
        retrieved_documents (List[Document]): The documents used as context.

    Returns:
        langchain.PromptTemplate: The template used to generate the answer,
            with ``question`` as its only input variable.
    """
    # Document text is embedded in an f-string-style template, so literal
    # braces (code snippets, JSON, ...) must be escaped. Otherwise they would
    # be parsed as template variables and break validation or formatting.
    context = "".join(
        _escape_template_braces(document.page_content) + "\n\n"
        for document in reversed(retrieved_documents)
    )
    template_text = context + "Question: {question}\n\n" + "Answer:"
    template = langchain.PromptTemplate(template=template_text,
                                        input_variables=["question"])
    return template
@st.cache_data
def get_generative_answer(question: str,
                          relevant_documents: List[Document],
                          model_name: str,
                          temperature: float,
                          max_length: int) -> str:
    """Returns the answer to the question as a string.

    The result is cached with Streamlit's ``st.cache_data``. Repeated calls
    with the same arguments reuse the stored answer instead of querying the
    model again.

    Args:
        question (str): The question asked by the user.
        relevant_documents (List[Document]): The list of relevant documents,
            used as context in the prompt.
        model_name (str): The name of the language model.
        temperature (float): The temperature used to generate the answer.
        max_length (int): The maximum length of the generated answer.

    Returns:
        str: The answer to the question.

    Raises:
        ValueError: If ``model_name`` cannot be loaded (raised by
            ``load_model``).
    """
    model = load_model(model_name=model_name,
                       temperature=temperature,
                       max_length=max_length)
    template = _get_generative_prompt_template(relevant_documents)
    prompt = template.format(question=question)
    answer = model.generate(prompts=[prompt])
    # One prompt was sent; take the text of its first generation.
    return answer.generations[0][0].text