any_guardrail.guardrails.off_topic

OffTopic

Bases: HuggingFace

Abstract base class for the Off Topic models.

For more information about either off topic model implementation, please see the model cards below:

- govtech/stsb-roberta-base-off-topic: https://huggingface.co/govtech/stsb-roberta-base-off-topic
- govtech/jina-embeddings-v2-small-en-off-topic: https://huggingface.co/govtech/jina-embeddings-v2-small-en-off-topic

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
class OffTopic(HuggingFace):
    """Abstract base class for the Off Topic models.

    For more information about either off topic model implementation, please see the model cards below:

    - [govtech/stsb-roberta-base-off-topic model](https://huggingface.co/govtech/stsb-roberta-base-off-topic).
    - [govtech/jina-embeddings-v2-small-en-off-topic](https://huggingface.co/govtech/jina-embeddings-v2-small-en-off-topic).
    """

    SUPPORTED_MODELS: ClassVar = [
        "mozilla-ai/jina-embeddings-v2-small-en-off-topic",
        "mozilla-ai/stsb-roberta-base-off-topic",
    ]

    implementation: OffTopicJina | OffTopicStsb

    def __init__(self, model_id: str | None = None) -> None:
        """Off Topic model based on one of two implementations decided by model ID."""
        super().__init__(model_id)
        if self.model_id == self.SUPPORTED_MODELS[0]:
            self.implementation = OffTopicJina()
        elif self.model_id == self.SUPPORTED_MODELS[1]:
            self.implementation = OffTopicStsb()
        else:
            msg = f"Unsupported model_id: {self.model_id}"
            raise ValueError(msg)

    def validate(self, input_text: str, comparison_text: str | None = None) -> GuardrailOutput:
        """Compare two texts to see if they are relevant to each other.

        Args:
            input_text: the original text you want to compare against.
            comparison_text: the text you want to compare to.

        Returns:
        valid=False means off topic, valid=True means on topic. Probabilities for each outcome are also provided.

        """
        msg = "Must provide a text to compare to."
        if comparison_text:
            raise ValueError(msg)
        model_inputs = self.implementation._pre_processing(input_text, comparison_text)
        model_outputs = self.implementation._inference(model_inputs)
        return self._post_processing(model_outputs)

    def _load_model(self) -> None:
        self.implementation._load_model()

    def _post_processing(self, model_outputs: Any) -> GuardrailOutput:
        return self.implementation._post_processing(model_outputs)
__init__(model_id=None)

Off Topic model based on one of two implementations, determined by the model ID.

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
def __init__(self, model_id: str | None = None) -> None:
    """Off Topic model based on one of two implementations decided by model ID."""
    super().__init__(model_id)
    if self.model_id == self.SUPPORTED_MODELS[0]:
        self.implementation = OffTopicJina()
    elif self.model_id == self.SUPPORTED_MODELS[1]:
        self.implementation = OffTopicStsb()
    else:
        msg = f"Unsupported model_id: {self.model_id}"
        raise ValueError(msg)
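
Example (a minimal usage sketch; the import path below is assumed from the source location shown above, and the package may instead re-export OffTopic from a higher-level module):

# Assumed import path, based on src/any_guardrail/guardrails/off_topic/off_topic.py.
from any_guardrail.guardrails.off_topic.off_topic import OffTopic

# Passing a model_id outside SUPPORTED_MODELS raises ValueError.
guardrail = OffTopic(model_id="mozilla-ai/stsb-roberta-base-off-topic")
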
validate(input_text, comparison_text=None)

Compare two texts to see if they are relevant to each other.

Parameters:

Name              Type         Description                                      Default
input_text        str          the original text you want to compare against.  required
comparison_text   str | None   the text you want to compare to.                 None

Returns:

Type              Description
GuardrailOutput   valid=False means off topic, valid=True means on topic. Probabilities for each outcome are also provided.

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
def validate(self, input_text: str, comparison_text: str | None = None) -> GuardrailOutput:
    """Compare two texts to see if they are relevant to each other.

    Args:
        input_text: the original text you want to compare against.
        comparison_text: the text you want to compare to.

    Returns:
    valid=False means off topic, valid=True means on topic. Probabilities for each outcome are also provided.

    """
    msg = "Must provide a text to compare to."
    if comparison_text:
        raise ValueError(msg)
    model_inputs = self.implementation._pre_processing(input_text, comparison_text)
    model_outputs = self.implementation._inference(model_inputs)
    return self._post_processing(model_outputs)
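
Example (a sketch under the same assumptions as the constructor example above; it also assumes GuardrailOutput exposes the valid flag described in the Returns section):

# Both texts are required; omitting comparison_text raises ValueError.
result = guardrail.validate(
    input_text="How do I reset my account password?",
    comparison_text="What is a good recipe for lasagna?",
)
if not result.valid:
    print("The comparison text is off topic.")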