Skip to content

OffTopic

any_guardrail.guardrails.off_topic

OffTopic

Bases: ThreeStageGuardrail[Any, Any, bool, dict[str, float], float]

Off-topic guardrail that delegates to one of two model implementations, chosen by model ID.

For more information about either off-topic model implementation, see the model cards below:

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
class OffTopic(ThreeStageGuardrail[Any, Any, bool, dict[str, float], float]):
    """Off-topic guardrail that delegates to one of two model implementations.

    The implementation is chosen from ``model_id`` in ``__init__``:
    ``SUPPORTED_MODELS[0]`` selects ``OffTopicJina`` and ``SUPPORTED_MODELS[1]`` selects
    ``OffTopicStsb``. Every pipeline stage is forwarded to that implementation.

    For more information about either off-topic model implementation, see the model cards below:

    - [govtech/stsb-roberta-base-off-topic model](https://huggingface.co/govtech/stsb-roberta-base-off-topic).
    - [govtech/jina-embeddings-v2-small-en-off-topic](https://huggingface.co/govtech/jina-embeddings-v2-small-en-off-topic).
    """

    SUPPORTED_MODELS: ClassVar = [
        "mozilla-ai/jina-embeddings-v2-small-en-off-topic",
        "mozilla-ai/stsb-roberta-base-off-topic",
    ]

    # The concrete model wrapper that every stage method delegates to.
    implementation: OffTopicJina | OffTopicStsb

    def __init__(
        self,
        model_id: str | None = None,
        provider: StandardProvider | None = None,  # Reserved for future extensibility
    ) -> None:
        """Off Topic model based on one of two implementations decided by model ID.

        Args:
            model_id: one of ``SUPPORTED_MODELS``. When omitted, the value returned by
                ``default(model_id, SUPPORTED_MODELS)`` is used.
            provider: passed through to the selected implementation; reserved for future extensibility.

        Raises:
            ValueError: if the resolved model ID is not one of ``SUPPORTED_MODELS``.

        """
        self.model_id = default(model_id, self.SUPPORTED_MODELS)
        self.provider = provider  # Reserved for future extensibility
        if self.model_id == self.SUPPORTED_MODELS[0]:
            self.implementation = OffTopicJina(provider=provider)
        elif self.model_id == self.SUPPORTED_MODELS[1]:
            self.implementation = OffTopicStsb(provider=provider)
        else:
            # Previously any unrecognised ID silently fell through to the STSB model while
            # self.model_id kept the unsupported value; fail loudly instead.
            msg = f"Unsupported model_id {self.model_id!r}; expected one of {self.SUPPORTED_MODELS}."
            raise ValueError(msg)

    def validate(  # type: ignore[override]
        self, input_text: str, comparison_text: str | None = None
    ) -> GuardrailOutput[bool, dict[str, float], float]:
        """Compare two texts to see if they are relevant to each other.

        Args:
            input_text: the original text you want to compare against.
            comparison_text: the text you want to compare to.

        Returns:
            valid=False means off topic, valid=True  means on topic. Will also provide probabilities of each.

        Raises:
            ValueError: if ``comparison_text`` is ``None`` or empty.

        """
        if not comparison_text:
            msg = "Must provide a text to compare to."
            raise ValueError(msg)
        # Route every stage through this wrapper's own stage methods (which delegate to the
        # implementation) so that all three stages are dispatched consistently.
        model_inputs: Any = self._pre_processing(input_text, comparison_text)
        model_outputs: Any = self._inference(model_inputs)
        return self._post_processing(model_outputs)

    def _pre_processing(self, *args: Any, **kwargs: Any) -> GuardrailPreprocessOutput[Any]:
        """Forward pre-processing to the selected implementation."""
        return self.implementation._pre_processing(*args, **kwargs)

    def _inference(self, model_inputs: GuardrailPreprocessOutput[Any]) -> GuardrailInferenceOutput[Any]:
        """Forward inference to the selected implementation."""
        return self.implementation._inference(model_inputs)

    def _post_processing(
        self, model_outputs: GuardrailInferenceOutput[Any]
    ) -> GuardrailOutput[bool, dict[str, float], float]:
        """Forward post-processing to the selected implementation."""
        return self.implementation._post_processing(model_outputs)
__init__(model_id=None, provider=None)

Off Topic model based on one of two implementations decided by model ID.

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
def __init__(
    self,
    model_id: str | None = None,
    provider: StandardProvider | None = None,  # Reserved for future extensibility
) -> None:
    """Off Topic model based on one of two implementations decided by model ID.

    Args:
        model_id: one of ``SUPPORTED_MODELS``. When omitted, the value returned by
            ``default(model_id, SUPPORTED_MODELS)`` is used.
        provider: passed through to the selected implementation; reserved for future extensibility.

    Raises:
        ValueError: if the resolved model ID is not one of ``SUPPORTED_MODELS``.

    """
    self.model_id = default(model_id, self.SUPPORTED_MODELS)
    self.provider = provider  # Reserved for future extensibility
    if self.model_id == self.SUPPORTED_MODELS[0]:
        self.implementation = OffTopicJina(provider=provider)
    elif self.model_id == self.SUPPORTED_MODELS[1]:
        self.implementation = OffTopicStsb(provider=provider)
    else:
        # Previously any unrecognised ID silently fell through to the STSB model while
        # self.model_id kept the unsupported value; fail loudly instead.
        msg = f"Unsupported model_id {self.model_id!r}; expected one of {self.SUPPORTED_MODELS}."
        raise ValueError(msg)
validate(input_text, comparison_text=None)

Compare two texts to see if they are relevant to each other.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_text` | `str` | The original text you want to compare against. | *required* |
| `comparison_text` | `str \| None` | The text you want to compare to. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `GuardrailOutput[bool, dict[str, float], float]` | `valid=False` means off topic and `valid=True` means on topic. The output also provides the probability of each. |

Source code in src/any_guardrail/guardrails/off_topic/off_topic.py
def validate(  # type: ignore[override]
    self, input_text: str, comparison_text: str | None = None
) -> GuardrailOutput[bool, dict[str, float], float]:
    """Compare two texts to see if they are relevant to each other.

    Args:
        input_text: the original text you want to compare against.
        comparison_text: the text you want to compare to.

    Returns:
        valid=False means off topic, valid=True  means on topic. Will also provide probabilities of each.

    Raises:
        ValueError: if ``comparison_text`` is ``None`` or empty.

    """
    if not comparison_text:
        msg = "Must provide a text to compare to."
        raise ValueError(msg)
    # Route every stage through this wrapper's own stage methods (which delegate to the
    # implementation) so that all three stages are dispatched consistently.
    model_inputs: Any = self._pre_processing(input_text, comparison_text)
    model_outputs: Any = self._inference(model_inputs)
    return self._post_processing(model_outputs)