# Source code for ask_youtube_playlists.question_answering.retriever

"""Contains the functionality used to retrieve the most relevant documents
for a given question."""
import pathlib

from typing import List, NamedTuple

import numpy as np
import yaml

from langchain.schema import Document

from ask_youtube_playlists.data_processing import (
    get_embedding_model,
    get_documents_from_directory,
    load_embeddings,
)


class DocumentInfo(NamedTuple):
    """Immutable record pairing a retrieved document with its ranking data.

    Attributes:
        document: The retrieved document (a langchain ``Document``).
        score: Relevance score of the document with respect to the question.
            Results are ranked in descending order of this value, so a
            higher score means a more relevant document. Intended to lie in
            the range [0, 1].
        playlist_name: Name of the playlist the document belongs to.
    """

    document: Document
    score: float
    playlist_name: str
class Retriever:
    """Class to retrieve the most relevant documents for a given question.

    A retriever is bound to one directory that holds a YAML configuration
    file, a ``chunked_data`` sub-directory with the documents and an
    ``embeddings`` sub-directory with the matching document embeddings.

    Attributes:
        retriever_directory: Directory the retriever was loaded from.
        embedding_model_name: Value of ``model_name`` in the config file.
        max_chunk_size: Value of ``max_chunk_size`` in the config file.
        min_overlap_size: Value of ``min_overlap_size`` in the config file.
        embedding_model: Model returned by ``get_embedding_model``; used to
            embed the question.
        documents: Documents loaded from ``chunked_data``; iterated as one
            group of documents per video.
        video_embeddings: Embeddings loaded from ``embeddings``; iterated in
            lockstep with ``documents``.
    """

    def __init__(self,
                 retriever_directory: pathlib.Path,
                 config_filename: str = "hyperparams.yaml"):
        """Loads the configuration, embedding model, documents and embeddings.

        Args:
            retriever_directory (pathlib.Path): Directory that contains the
                config file and the ``chunked_data`` / ``embeddings``
                sub-directories.
            config_filename (str): Name of the YAML config file located in
                ``retriever_directory``.
        """
        self.retriever_directory = retriever_directory

        # Placeholders; the real values are filled in by `_load_config`.
        self.embedding_model_name = ""
        self.max_chunk_size = None
        self.min_overlap_size = None
        self._load_config(config_filename)

        self.embedding_model = get_embedding_model(self.embedding_model_name)

        chunked_data_directory = retriever_directory / "chunked_data"
        self.documents = get_documents_from_directory(chunked_data_directory)

        embedding_directory = retriever_directory / "embeddings"
        self.video_embeddings = load_embeddings(embedding_directory)

    @property
    def total_number_of_documents(self) -> int:
        """Returns the total number of documents."""
        return sum(len(video) for video in self.documents)

    def _load_config(self, filename: str = "hyperparams.yaml"):
        """Loads the hyperparameters from a YAML file.

        Args:
            filename (str): The name of the YAML file.

        Raises:
            KeyError: If ``model_name``, ``max_chunk_size`` or
                ``min_overlap_size`` is missing from the file.
        """
        path_to_config = self.retriever_directory / filename
        # Explicit encoding so parsing does not depend on the locale.
        with open(path_to_config, "r", encoding="utf-8") as file:
            config = yaml.safe_load(file)

        self.embedding_model_name = config["model_name"]
        self.max_chunk_size = config["max_chunk_size"]
        self.min_overlap_size = config["min_overlap_size"]

    @staticmethod
    def cosine_distance(question_embedding: np.ndarray,
                        document_embedding: np.ndarray) -> float:
        """Calculates the cosine distance between two vectors.

        The cosine distance is ``1 - cosine_similarity``: 0 for vectors that
        point in the same direction, 1 for orthogonal vectors and 2 for
        opposite vectors.

        Args:
            question_embedding (np.ndarray): The embedding of the question.
            document_embedding (np.ndarray): The embedding of the document.

        Returns:
            float: The cosine distance between the two vectors. If either
            vector has zero norm the direction is undefined and 1.0 (the
            distance of orthogonal vectors) is returned.
        """
        question_embedding_norm = np.linalg.norm(question_embedding)
        document_embedding_norm = np.linalg.norm(document_embedding)
        denominator = question_embedding_norm * document_embedding_norm

        # Avoid a division by zero (and the resulting NaN score, which would
        # make the ranking order unreliable).
        if denominator == 0:
            return 1.0

        similarity = np.dot(question_embedding,
                            document_embedding) / denominator
        return float(1.0 - similarity)

    def retrieve_from_playlist(self, question: str,
                               n_documents: int) -> List["DocumentInfo"]:
        """Retrieves the most relevant documents with their relevance score.

        The score of a document is ``1 - cosine_distance``, i.e. the cosine
        similarity between the question embedding and the document
        embedding, so a higher score means a more relevant document.

        Args:
            question (str): The question posed by the user.
            n_documents (int): The number of documents to retrieve. Values
                larger than the number of available documents are capped;
                negative values are treated as 0.

        Returns:
            list: At most ``n_documents`` ``DocumentInfo`` tuples, sorted in
            descending order by score.

        Raises:
            ValueError: If the ``index`` metadata of a document does not
                match its position within its video.
        """
        n_documents = max(0, min(n_documents, self.total_number_of_documents))
        # The playlist is named after the parent of the retriever directory.
        playlist_name = self.retriever_directory.parent.name

        question_embedding = self.embedding_model.embed_query(question)
        question_embedding = np.array(question_embedding)  # type: ignore

        document_infos = []

        for video_documents, video_embedding in zip(self.documents,
                                                    self.video_embeddings):
            iterator = zip(video_documents, video_embedding)
            for i, (document, document_embedding) in enumerate(iterator):
                # Guards against documents and embeddings being misaligned.
                if document.metadata["index"] != i:
                    raise ValueError("The index of the document does not match"
                                     " its position in the list."
                                     f" Index: {document.metadata['index']},"
                                     f" Position: {i}")

                score = 1 - self.cosine_distance(
                    question_embedding, document_embedding  # type: ignore
                )
                document_info = DocumentInfo(document=document,
                                             score=score,
                                             playlist_name=playlist_name)
                document_infos.append(document_info)

        document_infos.sort(key=lambda x: x.score, reverse=True)
        return document_infos[:n_documents]

    @classmethod
    def retrieve(cls,
                 retrievers: List['Retriever'],
                 question: str,
                 n_documents: int) -> List["DocumentInfo"]:
        """Retrieves the most relevant documents with their score and the
        playlist they belong to.

        This function retrieves documents in two steps:

        1. Extracts the most relevant documents from each retriever.
        2. Ranks the retrieved documents from all retrievers and returns the
           most relevant ones, in addition to their score and the playlist
           they belong to.

        Args:
            retrievers (List[Retriever]): A list of retrievers.
            question (str): The question posed by the user.
            n_documents (int): The number of documents to retrieve. Negative
                values are treated as 0.

        Returns:
            list: A list of named tuples, each containing the document, its
            score and the playlist it belongs to. The list is sorted in
            descending order by relevance score.
        """
        n_documents = max(0, n_documents)

        document_infos = []
        for retriever in retrievers:
            document_infos.extend(retriever.retrieve_from_playlist(
                question, n_documents
            ))

        document_infos.sort(key=lambda x: x.score, reverse=True)
        return document_infos[:n_documents]