OffTopic

any_guardrail.guardrails.off_topic

OffTopic

Bases: HuggingFace[Any, Any, bool, dict[str, float], float]

Abstract base class for the Off Topic models.

For more information about either off-topic model implementation, see the model cards below:

- [govtech/stsb-roberta-base-off-topic](https://huggingface.co/govtech/stsb-roberta-base-off-topic)
- [govtech/jina-embeddings-v2-small-en-off-topic](https://huggingface.co/govtech/jina-embeddings-v2-small-en-off-topic)

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
class OffTopic(HuggingFace[Any, Any, bool, dict[str, float], float]):
    """Abstract base class for the Off Topic models.

    For more information about either off-topic model implementation, see the model cards below:

    - [govtech/stsb-roberta-base-off-topic model](https://huggingface.co/govtech/stsb-roberta-base-off-topic).
    - [govtech/jina-embeddings-v2-small-en-off-topic](https://huggingface.co/govtech/jina-embeddings-v2-small-en-off-topic).
    """

    SUPPORTED_MODELS: ClassVar = [
        "mozilla-ai/jina-embeddings-v2-small-en-off-topic",
        "mozilla-ai/stsb-roberta-base-off-topic",
    ]

    implementation: OffTopicJina | OffTopicStsb

    def __init__(self, model_id: str | None = None) -> None:
        """Off Topic model based on one of two implementations decided by model ID."""
        if model_id is None:
            model_id = self.SUPPORTED_MODELS[0]

        if model_id == self.SUPPORTED_MODELS[0]:
            self.implementation = OffTopicJina()
        elif model_id == self.SUPPORTED_MODELS[1]:
            self.implementation = OffTopicStsb()
        else:
            msg = f"Unsupported model_id: {model_id}"
            raise ValueError(msg)
        super().__init__(model_id)

    def validate(
        self, input_text: str, comparison_text: str | None = None
    ) -> GuardrailOutput[bool, dict[str, float], float]:
        """Compare two texts to see if they are relevant to each other.

        Args:
            input_text: the original text to compare against.
            comparison_text: the text to compare with the original.

        Returns:
            valid=False means off topic, valid=True means on topic. Probabilities for each label are also provided.

        """
        msg = "Must provide a text to compare to."
        if not comparison_text:
            raise ValueError(msg)
        model_inputs: Any = self.implementation._pre_processing(input_text, comparison_text)
        model_outputs: Any = self.implementation._inference(model_inputs)
        return self._post_processing(model_outputs)

    def _load_model(self) -> None:
        self.implementation._load_model()

    def _post_processing(
        self, model_outputs: GuardrailInferenceOutput[Any]
    ) -> GuardrailOutput[bool, dict[str, float], float]:
        return self.implementation._post_processing(model_outputs)
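A minimal usage sketch (not part of the source above): it assumes `OffTopic` is importable from the documented module path `any_guardrail.guardrails.off_topic` and that the returned `GuardrailOutput` exposes the `valid` flag described in `validate`.

```python
from any_guardrail.guardrails.off_topic import OffTopic

# With no model_id, the constructor falls back to SUPPORTED_MODELS[0],
# "mozilla-ai/jina-embeddings-v2-small-en-off-topic".
guardrail = OffTopic()

result = guardrail.validate(
    input_text="How do I reset my router?",
    comparison_text="Troubleshooting home networking equipment.",
)

# valid=True means on topic, valid=False means off topic.
print(result.valid)
```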
__init__(model_id=None)

Off Topic model based on one of two implementations, selected by model ID.

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
def __init__(self, model_id: str | None = None) -> None:
    """Off Topic model based on one of two implementations decided by model ID."""
    if model_id is None:
        model_id = self.SUPPORTED_MODELS[0]

    if model_id == self.SUPPORTED_MODELS[0]:
        self.implementation = OffTopicJina()
    elif model_id == self.SUPPORTED_MODELS[1]:
        self.implementation = OffTopicStsb()
    else:
        msg = f"Unsupported model_id: {model_id}"
        raise ValueError(msg)
    super().__init__(model_id)
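The constructor dispatches on `model_id`, so explicit selection is a matter of passing one of the two `SUPPORTED_MODELS` entries; anything else raises `ValueError`. A short sketch, assuming the same import path as above:

```python
from any_guardrail.guardrails.off_topic import OffTopic

# Explicitly select the STSB RoBERTa implementation.
guardrail = OffTopic(model_id="mozilla-ai/stsb-roberta-base-off-topic")

# Unsupported IDs are rejected at construction time.
try:
    OffTopic(model_id="not/a-supported-model")
except ValueError as err:
    print(err)  # Unsupported model_id: not/a-supported-model
```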
validate(input_text, comparison_text=None)

Compare two texts to see if they are relevant to each other.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_text` | `str` | The original text to compare against. | *required* |
| `comparison_text` | `str \| None` | The text to compare with the original. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `GuardrailOutput[bool, dict[str, float], float]` | valid=False means off topic, valid=True means on topic. Probabilities for each label are also provided. |

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
def validate(
    self, input_text: str, comparison_text: str | None = None
) -> GuardrailOutput[bool, dict[str, float], float]:
    """Compare two texts to see if they are relevant to each other.

    Args:
        input_text: the original text to compare against.
        comparison_text: the text to compare with the original.

    Returns:
        valid=False means off topic, valid=True means on topic. Probabilities for each label are also provided.

    """
    msg = "Must provide a text to compare to."
    if not comparison_text:
        raise ValueError(msg)
    model_inputs: Any = self.implementation._pre_processing(input_text, comparison_text)
    model_outputs: Any = self.implementation._inference(model_inputs)
    return self._post_processing(model_outputs)
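Note that `comparison_text` defaults to `None` only in the signature; at call time it is required, and an empty or missing value raises `ValueError`. A sketch of both paths, again assuming `result.valid` is the on/off-topic flag:

```python
from any_guardrail.guardrails.off_topic import OffTopic

guardrail = OffTopic()

# Omitting comparison_text (or passing an empty string) raises ValueError.
try:
    guardrail.validate(input_text="Tell me about your refund policy.")
except ValueError as err:
    print(err)  # Must provide a text to compare to.

result = guardrail.validate(
    input_text="Tell me about your refund policy.",
    comparison_text="Customer support for billing and refunds.",
)
if not result.valid:
    print("Input is off topic; rejecting.")
```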