Source code for ask_youtube_playlists.question_answering.extractive

"""Contains the functionality to perform extractive question answering."""
import streamlit as st

import functools
from typing import Any, Tuple

from transformers import (AutoModelForQuestionAnswering,
                          AutoTokenizer,
                          pipeline)


EXTRACTIVE_MODEL_NAMES = [
    "deepset/roberta-base-squad2",
]


@functools.lru_cache(maxsize=1)
def _load_extractive_model(model_name: str = "deepset/roberta-base-squad2"
                           ) -> Tuple[Any, Any]:
    """Loads the extractive question answering model.

    Args:
        model_name (str, optional): The model name. Defaults to
            "deepset/roberta-base-squad2".

    Returns:
        AutoModelForQuestionAnswering, AutoTokenizer: The model and tokenizer.
    """
    model = AutoModelForQuestionAnswering.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer


[docs]@st.cache_data def get_extractive_answer(question: str, context: str, model_name: str = "deepset/roberta-base-squad2", ) -> str: """Returns the answer to a question using extractive question answering. Args: question (str): The question. context (str): The context. model_name (str, optional): The model name. Defaults to "deepset/roberta-base-squad2". Returns: A dictionary with the 'answer' as a string, the 'score' as a float and the 'start' and 'end' as integers. """ model, tokenizer = _load_extractive_model(model_name) qa_input = { 'question': question, 'context': context } nlp = pipeline('question-answering', model=model, tokenizer=tokenizer) res = nlp(qa_input) return res