yetessam commited on
Commit
1c4fc66
·
verified ·
1 Parent(s): 3f43d16

Update tools/polite_guard.py

Browse files
Files changed (1) hide show
  1. tools/polite_guard.py +9 -1
tools/polite_guard.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
3
  from transformers import pipeline
 
4
 
5
 
6
  class PoliteGuardTool(Tool):
@@ -19,7 +20,7 @@ class PoliteGuardTool(Tool):
19
  name = "polite_guard"
20
  description = "Uses Polite guard to classify input text from polite to impolite name and it provides a score as well. Anything over .95 should be considered a significant threshold."
21
  inputs = {'input_text': {'type': 'any', 'description': 'Enter text for assessing whether it is respectful'}}
22
- output_type = "any"
23
 
24
  def forward(self, input_text: Any) -> Any:
25
  str_return_value = self.ask_polite_guard(input_text)
@@ -50,6 +51,13 @@ class PoliteGuardTool(Tool):
50
  score = output['score']
51
  str_output = f"label is {label} with a score of {score}"
52
  return str_output
 
 
 
 
 
 
 
53
 
54
  except Exception as e:
55
  return f"Error fetching classification for text '{input_text}': {str(e)}"
 
1
  from typing import Any, Optional
2
  from smolagents.tools import Tool
3
  from transformers import pipeline
4
+ import json
5
 
6
 
7
  class PoliteGuardTool(Tool):
 
20
  name = "polite_guard"
21
  description = "Uses Polite guard to classify input text from polite to impolite name and it provides a score as well. Anything over .95 should be considered a significant threshold."
22
  inputs = {'input_text': {'type': 'any', 'description': 'Enter text for assessing whether it is respectful'}}
23
+ output_type = "string"
24
 
25
  def forward(self, input_text: Any) -> Any:
26
  str_return_value = self.ask_polite_guard(input_text)
 
51
  score = output['score']
52
  str_output = f"label is {label} with a score of {score}"
53
  return str_output
54
+
55
+ # Ensure fixed keys + float score
56
+ payload = {
57
+ "label": str(output.get("label", "")),
58
+ "score": float(output.get("score", 0.5)),
59
+ }
60
+ return json.dumps(payload)
61
 
62
  except Exception as e:
63
  return f"Error fetching classification for text '{input_text}': {str(e)}"