Bases: StandardGuardrail
Wrapper class for Google ShieldGemma models.
For more information, see the [ShieldGemma model cards](https://huggingface.co/collections/google/shieldgemma-67d130ef8da6af884072a789).
Note: the image classifier is not supported.
Source code in src/any_guardrail/guardrails/shield_gemma/shield_gemma.py
| class ShieldGemma(StandardGuardrail):
"""Wrapper class for Google ShieldGemma models.
For more information, please visit the model cards: [Shield Gemma](https://huggingface.co/collections/google/shieldgemma-67d130ef8da6af884072a789).
Note we do not support the image classifier.
"""
# Hugging Face model IDs of the ShieldGemma text models this wrapper lists as
# supported. __init__ passes this list to `default` together with the caller's
# `model_id` when resolving which model to use.
SUPPORTED_MODELS: ClassVar = [
"google/shieldgemma-2b",
"google/shieldgemma-9b",
"google/shieldgemma-27b",
]
def __init__(
self,
policy: str,
threshold: float = DEFAULT_THRESHOLD,
model_id: str | None = None,
provider: StandardProvider | None = None,
) -> None:
"""Initialize the ShieldGemma guardrail.

Args:
policy: Safety policy text. Stored on the instance and substituted into
the prompt template as ``safety_policy`` during preprocessing.
threshold: Score cutoff. Post-processing marks an input as valid only
when its score is strictly below this value.
model_id: Model identifier, or None. Resolved through ``default``
together with ``SUPPORTED_MODELS``.
provider: Provider to use for tokenization and inference. When None, a
``HuggingFaceProvider`` is built with ``AutoModelForCausalLM`` and
``AutoTokenizer``.
"""
# `default` is defined elsewhere; presumably it falls back to one of
# SUPPORTED_MODELS when model_id is None -- TODO confirm against its definition.
self.model_id = default(model_id, self.SUPPORTED_MODELS)
self.policy = policy
# Prompt template; formatted with `user_prompt` and `safety_policy` in _pre_processing.
self.system_prompt = SYSTEM_PROMPT_SHIELD_GEMMA
self.threshold = threshold
# A caller-supplied provider is used as-is; otherwise a Hugging Face provider
# is constructed for a causal-LM model class and its tokenizer class.
if provider is not None:
self.provider = provider
else:
self.provider = HuggingFaceProvider(model_class=AutoModelForCausalLM, tokenizer_class=AutoTokenizer)
# NOTE(review): indentation is lost in this listing, so it is unclear whether
# load_model runs only on the else branch or for both branches -- confirm
# against the real source file before relying on either reading.
self.provider.load_model(self.model_id)
def _pre_processing(self, input_text: str) -> StandardPreprocessOutput:
    """Render the ShieldGemma prompt for ``input_text`` and tokenize it.

    Args:
        input_text: Text to evaluate. It fills the ``user_prompt`` field of
            the prompt template; the configured policy fills
            ``safety_policy``.

    Returns:
        A ``GuardrailPreprocessOutput`` whose ``data`` is the tokenizer
        output, requested with ``return_tensors="pt"``.
    """
    prompt = self.system_prompt.format(
        safety_policy=self.policy,
        user_prompt=input_text,
    )
    tokenizer = self.provider.tokenizer  # type: ignore[attr-defined]
    encoded = tokenizer(prompt, return_tensors="pt")
    return GuardrailPreprocessOutput(data=encoded)
def _inference(self, model_inputs: StandardPreprocessOutput) -> StandardInferenceOutput:
    """Delegate inference to the configured provider.

    Args:
        model_inputs: Preprocessed inputs (the type returned by
            ``_pre_processing``).

    Returns:
        Whatever ``self.provider.infer`` returns for ``model_inputs``.
    """
    provider = self.provider
    inference_result = provider.infer(model_inputs)
    return inference_result
def _post_processing(self, model_outputs: StandardInferenceOutput) -> BinaryScoreOutput:
    """Convert model logits into a score and a pass/fail verdict.

    The score is the softmax probability of the ``"Yes"`` token computed over
    only the ``"Yes"`` and ``"No"`` logits. Presumably ``"Yes"`` answers
    whether the text violates the policy; that depends on the prompt
    template, which is defined elsewhere.

    Args:
        model_outputs: Inference output whose ``data["logits"]`` can be
            indexed with three indices.

    Returns:
        A ``GuardrailOutput`` with ``score`` set to the ``"Yes"`` probability,
        ``valid`` True only when that score is strictly below
        ``self.threshold``, and ``explanation`` left as None.
    """
    all_logits = model_outputs.data["logits"]
    token_ids = self.provider.tokenizer.get_vocab()  # type: ignore[attr-defined]
    yes_id = token_ids["Yes"]
    no_id = token_ids["No"]
    # Index 0 on the first axis, last index on the second, then just the two
    # answer tokens (presumably batch / sequence / vocabulary axes).
    yes_no_logits = all_logits[0, -1, [yes_id, no_id]]
    # Position 0 of the two-way softmax corresponds to "Yes".
    yes_probability = softmax(yes_no_logits, dim=0)[0].item()
    is_valid = yes_probability < self.threshold
    return GuardrailOutput(valid=is_valid, explanation=None, score=yes_probability)
|
__init__(policy, threshold=DEFAULT_THRESHOLD, model_id=None, provider=None)
Initialize the ShieldGemma guardrail.
Source code in src/any_guardrail/guardrails/shield_gemma/shield_gemma.py
| def __init__(
self,
policy: str,
threshold: float = DEFAULT_THRESHOLD,
model_id: str | None = None,
provider: StandardProvider | None = None,
) -> None:
"""Initialize the ShieldGemma guardrail.

Args:
policy: Safety policy text, stored on the instance as ``self.policy``.
threshold: Score cutoff, stored on the instance as ``self.threshold``.
model_id: Model identifier, or None. Resolved through ``default``
together with ``SUPPORTED_MODELS``.
provider: Provider to use. When None, a ``HuggingFaceProvider`` is built
with ``AutoModelForCausalLM`` and ``AutoTokenizer``.
"""
# `default` is defined elsewhere; presumably it falls back to one of
# SUPPORTED_MODELS when model_id is None -- TODO confirm against its definition.
self.model_id = default(model_id, self.SUPPORTED_MODELS)
self.policy = policy
self.system_prompt = SYSTEM_PROMPT_SHIELD_GEMMA
self.threshold = threshold
# A caller-supplied provider is used as-is; otherwise a Hugging Face provider
# is constructed for a causal-LM model class and its tokenizer class.
if provider is not None:
self.provider = provider
else:
self.provider = HuggingFaceProvider(model_class=AutoModelForCausalLM, tokenizer_class=AutoTokenizer)
# NOTE(review): indentation is lost in this listing, so it is unclear whether
# load_model runs only on the else branch or for both branches -- confirm
# against the real source file before relying on either reading.
self.provider.load_model(self.model_id)
|