Source code for ask_youtube_playlists.question_answering.generative

"""Contains the functionality to answer a question using generative
models."""
import streamlit as st
from dataclasses import dataclass
from typing import List, Dict, Union

import langchain
from langchain import llms
from langchain.schema import Document


@dataclass
class LLMSpec:
    """Class to store the information of a language model.

    Attributes:
        model_name (str): The name of the language model.
        model_type (str): The class or method used to load the language
            model.
        max_tokens (int): The maximum number of tokens the model can
            process in a single prompt (its context window).
    """
    model_name: str
    model_type: str
    max_tokens: int
GENERATIVE_MODELS = [
    LLMSpec("gpt2", "huggingface-pipeline", max_tokens=1024),
    LLMSpec("gpt-3.5-turbo", "openai-chat", max_tokens=4096),
    LLMSpec("gpt-3.5-turbo-16k", "openai-chat", max_tokens=16384),
    LLMSpec("gpt-4", "openai-chat", max_tokens=8192),
]

GENERATIVE_MODEL_NAMES = [model_spec.model_name
                          for model_spec in GENERATIVE_MODELS]
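# Usage sketch (illustrative, not part of the module): the registry above is
# what a UI would typically consume, e.g. to populate a Streamlit selectbox
# with the available model names.
#
#   selected_model = st.selectbox("Generative model", GENERATIVE_MODEL_NAMES)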
def get_model_spec(model_name: str) -> LLMSpec:
    """Returns the language model specification.

    Args:
        model_name (str): The name of the language model.

    Returns:
        LLMSpec: The language model specification.

    Raises:
        ValueError: If the language model is not available.
    """
    for model_spec in GENERATIVE_MODELS:
        if model_spec.model_name == model_name:
            return model_spec
    raise ValueError(f"Model '{model_name}' not available. Available "
                     f"models are: {GENERATIVE_MODEL_NAMES}")
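# Example (illustrative): get_model_spec("gpt-4") returns the spec with
# max_tokens == 8192, while an unknown name such as "not-a-model" raises
# ValueError.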
def load_model(model_name: str,
               temperature: float = 0.7,
               max_length: int = 1024,
               ) -> llms.base.BaseLLM:
    """Loads the language model.

    Args:
        model_name (str): The language model name.
        temperature (float, optional): The temperature used to generate the
            answer. The higher the temperature, the more "creative" the
            answer will be. Defaults to 0.7.
        max_length (int, optional): The maximum length of the generated
            answer. Defaults to 1024.

    Returns:
        llms.base.BaseLLM: The language model.

    Raises:
        ValueError: If the model type is not supported.
    """
    model_spec = get_model_spec(model_name)
    if model_spec.model_type == "openai-chat":
        llm = llms.OpenAIChat(  # type: ignore
            model_name=model_spec.model_name,
            model_kwargs={"temperature": temperature,
                          "max_length": max_length}
        )
        return llm
    if model_spec.model_type == "huggingface-pipeline":
        llm = llms.HuggingFacePipeline.from_model_id(  # type: ignore
            model_id=model_spec.model_name,
            task="text-generation",
            model_kwargs={"temperature": temperature,
                          "max_length": max_length}
        )
        return llm
    available_models = [model.model_name for model in GENERATIVE_MODELS]
    raise ValueError(f"Model type '{model_spec.model_type}' not available. "
                     f"Available models are: {available_models}")
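# Usage sketch (illustrative): loading a chat model and calling it directly.
# For "openai-chat" models this assumes the OPENAI_API_KEY environment
# variable is set; "huggingface-pipeline" models are downloaded and run
# locally instead.
#
#   llm = load_model("gpt-3.5-turbo", temperature=0.2, max_length=256)
#   print(llm("Summarise what a transformer is in one sentence."))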
def _get_generative_prompt_template(retrieved_documents: List[Document],
                                    ) -> langchain.PromptTemplate:
    """Returns the template used to generate the answer.

    Args:
        retrieved_documents (List[Document]): The retrieved documents whose
            content is prepended to the question as context.

    Returns:
        langchain.PromptTemplate: The template used to generate the answer.
    """
    template_text = ""
    for document in reversed(retrieved_documents):
        template_text += f"{document.page_content}\n\n"
    template_text += "Question: {question}\n\n"
    template_text += "Answer:"
    template = langchain.PromptTemplate(template=template_text,
                                        input_variables=["question"])
    return template
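# Illustrative example of the prompt this template produces: given two
# retrieved documents with page_content "Doc A" and "Doc B" (in that order),
# template.format(question="Why?") yields:
#
#   Doc B
#
#   Doc A
#
#   Question: Why?
#
#   Answer:
#
# The documents are added in reverse order, presumably so that the first
# (most relevant) retrieved document ends up closest to the question.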
@st.cache_data
def get_generative_answer(question: str,
                          relevant_documents: List[Document],
                          model_name: str,
                          temperature: float,
                          max_length: int) -> str:
    """Returns the answer to the question as a string.

    Args:
        question (str): The question asked by the user.
        relevant_documents (List[Document]): The list of relevant documents.
        model_name (str): The name of the language model.
        temperature (float): The temperature used to generate the answer.
        max_length (int): The maximum length of the generated answer.

    Returns:
        str: The answer to the question.
    """
    model = load_model(model_name=model_name,
                       temperature=temperature,
                       max_length=max_length)
    template = _get_generative_prompt_template(relevant_documents)
    prompt = template.format(question=question)
    answer = model.generate(prompts=[prompt])
    return answer.generations[0][0].text
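# End-to-end usage sketch (illustrative; the document content is made up and
# an OpenAI API key is assumed for "openai-chat" models):
#
#   docs = [Document(page_content="The playlist covers attention heads and "
#                                 "positional encodings.")]
#   answer = get_generative_answer(
#       question="What topics does the playlist cover?",
#       relevant_documents=docs,
#       model_name="gpt-3.5-turbo",
#       temperature=0.7,
#       max_length=256,
#   )
#   print(answer)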