Skip to content

Commit 77abb68

Browse files
committed
Feat: api-based profanity filter
1 parent 19c6aeb commit 77abb68

File tree

5 files changed

+156
-2
lines changed

5 files changed

+156
-2
lines changed

teapot/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ def trn_api_key():
1717
return os.getenv('TRN_API_KEY')
1818

1919

20+
def profanity_filter():
21+
return os.getenv("PROFANITY_FILTER", "none")
22+
2023
def bot_prefix():
2124
return eval(os.getenv('BOT_PREFIX', "['/teapot ', '/tp']"))
2225

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import re
2+
import teapot
3+
from profanity_check import predict_prob
4+
import discord
5+
from teapot.tools.teascout_client import TeaScout
6+
7+
# Pre-initialize TeaScout client and model name if using external API
8+
client = None
9+
_profanity_model_name = None
10+
cfg = teapot.config.profanity_filter()
11+
if isinstance(cfg, str) and cfg.startswith("http"):
12+
# Expected format: https://api_key:model_name@example.com
13+
m = re.match(r'^(https?://)', cfg)
14+
scheme = m.group(1) if m else ''
15+
left, host = cfg.split('@', 1)
16+
# remove scheme from left if present
17+
if left.startswith('http://') or left.startswith('https://'):
18+
left = left.split('://', 1)[1]
19+
# left should now be 'api_key:model_name'
20+
if ':' in left:
21+
api_key, _profanity_model_name = left.split(':', 1)
22+
else:
23+
api_key = left
24+
_profanity_model_name = None
25+
base_url = scheme + host
26+
client = TeaScout(base_url, api_key)
27+
28+
29+
async def on_message_handler(bot, message):
30+
punctuations = '!()-[]{};:\'"\\,<>./?@#$%^&*_~'
31+
msg = ""
32+
for char in message.content.lower():
33+
if char not in punctuations:
34+
msg = msg + char
35+
if teapot.config.profanity_filter() == "none":
36+
return
37+
elif teapot.config.profanity_filter() == "local":
38+
prob = predict_prob([msg])
39+
if prob[0] >= 0.8:
40+
em = discord.Embed(title=f"AI Analysis Results", color=0xC54B4F)
41+
em.add_field(name='PROFANITY DETECTED! ', value=str(prob[0]))
42+
await message.channel.send(embed=em)
43+
else:
44+
# Use parsed model name if available, otherwise fall back to default
45+
model_name = _profanity_model_name or "default"
46+
model = client.model(model_name).text(msg)
47+
result = model.inference()
48+
score = result.get('score', 0)
49+
if score >= 0.8:
50+
em = discord.Embed(title=f"AI Analysis Results", color=0xC54B4F)
51+
em.add_field(name='Hate Speech detected! ', value=str(score))
52+
em.set_footer(text="This is a experimental feature powered by TeaScout. Feedback is welcome!")
53+
await message.channel.send(embed=em)

teapot/events.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
22
import teapot
33
import discord
4-
from profanity_check import predict_prob
54

65

76
def __init__(bot):

teapot/setup.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,19 @@ def __init__():
4545
input_lavalink_port = input("Lavalink Port: ")
4646
input_lavalink_password = input("Lavalink Password: ")
4747

48-
input_osu_api_key = input("osu!api Key")
48+
input_osu_api_key = input("osu!api Key:")
49+
50+
input_profanity_filter = input("Enable profanity filter? [Y/n] ")
51+
if input_profanity_filter.lower() == "y" or input_profanity_filter.lower() == "yes":
52+
# ask for which implementation did user want, either local or api
53+
input_profanity_impl = input("Use local or api implementation? [local/api] ")
54+
if input_profanity_impl.lower() == "api":
55+
input_profanity_impl = input("Profanity API (should be in format like https://api_key:model_name@example.com): ")
56+
else:
57+
input_profanity_impl = "local"
58+
else:
59+
input_profanity_impl = "none"
60+
4961

5062
try:
5163
config = f"""CONFIG_VERSION={teapot.config_version()}
@@ -65,6 +77,7 @@ def __init__():
6577
LAVALINK_PASSWORD={input_lavalink_password}
6678
6779
OSU_API_KEY={input_osu_api_key}
80+
PROFANITY_FILTER={input_profanity_impl}
6881
"""
6982
open('./.env', 'w').write(config)
7083
print("\n[*] Successfully created .env file!")

teapot/tools/teascout_client.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import requests
2+
from typing import Optional, Dict, Any, List
3+
4+
class TeaScout:
5+
def __init__(self, url: str, key: Optional[str] = None):
6+
"""
7+
Initialize the TeaScout client.
8+
9+
Args:
10+
url (str): The base URL of the TeaScout server (e.g., "http://localhost:5000").
11+
key (str, optional): API Key if required by the server.
12+
"""
13+
self.base_url = url.rstrip('/')
14+
self.api_key = key
15+
self.session = requests.Session()
16+
if self.api_key:
17+
self.session.headers.update({'Authorization': f'Bearer {self.api_key}'})
18+
19+
def list_models(self) -> Dict[str, str]:
20+
"""
21+
List available models from the server.
22+
23+
Returns:
24+
dict: A dictionary of model names and their descriptions.
25+
"""
26+
try:
27+
response = self.session.get(f"{self.base_url}/models")
28+
response.raise_for_status()
29+
return response.json()
30+
except requests.RequestException as e:
31+
raise ConnectionError(f"Failed to fetch models: {e}")
32+
33+
def model(self, model_name: str) -> 'ModelContext':
34+
"""
35+
Select a model to work with.
36+
37+
Args:
38+
model_name (str): The ID of the model to use.
39+
40+
Returns:
41+
ModelContext: A context object for building the request.
42+
"""
43+
return ModelContext(self, model_name)
44+
45+
class ModelContext:
46+
def __init__(self, client: TeaScout, model_name: str):
47+
self.client = client
48+
self.model_name = model_name
49+
self._text_content = None
50+
51+
def text(self, content: str) -> 'ModelContext':
52+
"""
53+
Set the text content for inference.
54+
55+
Args:
56+
content (str): The text to analyze.
57+
58+
Returns:
59+
ModelContext: Returns self for chaining.
60+
"""
61+
self._text_content = content
62+
return self
63+
64+
def inference(self) -> Dict[str, Any]:
65+
"""
66+
Execute the inference request.
67+
68+
Returns:
69+
dict: The inference result from the server.
70+
71+
Raises:
72+
ValueError: If text content is not set.
73+
ConnectionError: If the request fails.
74+
"""
75+
if self._text_content is None:
76+
raise ValueError("Text content must be set using .text() before calling .inference()")
77+
78+
url = f"{self.client.base_url}/inference/{self.model_name}"
79+
payload = {"text": self._text_content}
80+
81+
try:
82+
response = self.client.session.post(url, json=payload)
83+
response.raise_for_status()
84+
return response.json()
85+
except requests.RequestException as e:
86+
raise ConnectionError(f"Inference request failed: {e}")

0 commit comments

Comments
 (0)