diff --git a/bot.py b/bot.py index 4c5661f..6e1b1be 100644 --- a/bot.py +++ b/bot.py @@ -12,7 +12,7 @@ from dotenv import load_dotenv from utils.database import Database from utils.drama_tracker import DramaTracker -from utils.ollama_client import LLMClient +from utils.llm_client import LLMClient # Load .env load_dotenv() @@ -69,7 +69,7 @@ class BCSBot(commands.Bot): llm_base_url = os.getenv("LLM_BASE_URL", "http://athena.lan:11434") llm_model = os.getenv("LLM_MODEL", "Qwen3-VL-32B-Thinking-Q8_0") llm_api_key = os.getenv("LLM_API_KEY", "not-needed") - self.ollama = LLMClient(llm_base_url, llm_model, llm_api_key) + self.llm = LLMClient(llm_base_url, llm_model, llm_api_key) # Drama tracker sentiment = config.get("sentiment", {}) @@ -154,7 +154,7 @@ class BCSBot(commands.Bot): async def close(self): await self.db.close() - await self.ollama.close() + await self.llm.close() await super().close() diff --git a/cogs/chat.py b/cogs/chat.py index f24fb8c..a5abfad 100644 --- a/cogs/chat.py +++ b/cogs/chat.py @@ -70,11 +70,21 @@ class ChatCog(commands.Cog): {"role": "user", "content": f"{score_context}\n{message.author.display_name}: {content}"} ) - async with message.channel.typing(): - response = await self.bot.ollama.chat( - list(self._chat_history[ch_id]), - CHAT_PERSONALITY, - ) + typing_ctx = None + + async def start_typing(): + nonlocal typing_ctx + typing_ctx = message.channel.typing() + await typing_ctx.__aenter__() + + response = await self.bot.llm.chat( + list(self._chat_history[ch_id]), + CHAT_PERSONALITY, + on_first_token=start_typing, + ) + + if typing_ctx: + await typing_ctx.__aexit__(None, None, None) if response is None: response = "I'd roast you but my brain is offline. Try again later." diff --git a/cogs/commands.py b/cogs/commands.py index ec1af40..6ba3968 100644 --- a/cogs/commands.py +++ b/cogs/commands.py @@ -126,8 +126,8 @@ class CommandsCog(commands.Cog): inline=True, ) embed.add_field( - name="Ollama", - value=f"`{self.bot.ollama.model}` @ `{self.bot.ollama.host}`", + name="LLM", + value=f"`{self.bot.llm.model}` @ `{self.bot.llm.host}`", inline=False, ) @@ -301,7 +301,7 @@ class CommandsCog(commands.Cog): else "(no prior context)" ) - result = await self.bot.ollama.analyze_message(msg.content, context) + result = await self.bot.llm.analyze_message(msg.content, context) if result is None: embed = discord.Embed( title=f"Analysis: {msg.author.display_name}", @@ -359,7 +359,7 @@ class CommandsCog(commands.Cog): await interaction.response.defer(ephemeral=True) user_notes = self.bot.drama_tracker.get_user_notes(interaction.user.id) - raw, parsed = await self.bot.ollama.raw_analyze(message, user_notes=user_notes) + raw, parsed = await self.bot.llm.raw_analyze(message, user_notes=user_notes) embed = discord.Embed( title="BCS Test Analysis", color=discord.Color.blue() @@ -368,7 +368,7 @@ class CommandsCog(commands.Cog): name="Input Message", value=message[:1024], inline=False ) embed.add_field( - name="Raw Ollama Response", + name="Raw LLM Response", value=f"```json\n{raw[:1000]}\n```", inline=False, ) diff --git a/cogs/sentiment.py b/cogs/sentiment.py index bf97b86..5b532f2 100644 --- a/cogs/sentiment.py +++ b/cogs/sentiment.py @@ -82,7 +82,7 @@ class SentimentCog(commands.Cog): # Analyze the message context = self._get_context(message) user_notes = self.bot.drama_tracker.get_user_notes(message.author.id) - result = await self.bot.ollama.analyze_message( + result = await self.bot.llm.analyze_message( message.content, context, user_notes=user_notes ) diff --git a/utils/ollama_client.py b/utils/llm_client.py similarity index 61% rename from utils/ollama_client.py rename to utils/llm_client.py index 693c138..ea9cf59 100644 --- a/utils/ollama_client.py +++ b/utils/llm_client.py @@ -1,3 +1,4 @@ +import asyncio import json import logging from pathlib import Path @@ -96,8 +97,9 @@ class LLMClient: self._client = AsyncOpenAI( base_url=f"{self.host}/v1", api_key=api_key, - timeout=300.0, # 5 min — first request loads model into VRAM + timeout=600.0, # 10 min — first request loads model into VRAM ) + self._semaphore = asyncio.Semaphore(1) # serialize requests to avoid overloading async def close(self): await self._client.close() @@ -110,36 +112,37 @@ class LLMClient: user_content += f"=== NOTES ABOUT THIS USER (from prior analysis) ===\n{user_notes}\n\n" user_content += f"=== TARGET MESSAGE (analyze THIS message only) ===\n{message}" - try: - response = await self._client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": user_content}, - ], - tools=[ANALYSIS_TOOL], - tool_choice={"type": "function", "function": {"name": "report_analysis"}}, - temperature=0.1, - ) + async with self._semaphore: + try: + response = await self._client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ], + tools=[ANALYSIS_TOOL], + tool_choice={"type": "function", "function": {"name": "report_analysis"}}, + temperature=0.1, + ) - choice = response.choices[0] + choice = response.choices[0] - # Extract tool call arguments - if choice.message.tool_calls: - tool_call = choice.message.tool_calls[0] - args = json.loads(tool_call.function.arguments) - return self._validate_result(args) + # Extract tool call arguments + if choice.message.tool_calls: + tool_call = choice.message.tool_calls[0] + args = json.loads(tool_call.function.arguments) + return self._validate_result(args) - # Fallback: try parsing the message content as JSON - if choice.message.content: - return self._parse_content_fallback(choice.message.content) + # Fallback: try parsing the message content as JSON + if choice.message.content: + return self._parse_content_fallback(choice.message.content) - logger.warning("No tool call or content in LLM response.") - return None + logger.warning("No tool call or content in LLM response.") + return None - except Exception as e: - logger.error("LLM analysis error: %s", e) - return None + except Exception as e: + logger.error("LLM analysis error: %s", e) + return None def _validate_result(self, result: dict) -> dict: score = float(result.get("toxicity_score", 0.0)) @@ -196,24 +199,43 @@ class LLMClient: return None async def chat( - self, messages: list[dict[str, str]], system_prompt: str + self, messages: list[dict[str, str]], system_prompt: str, + on_first_token=None, ) -> str | None: - """Send a conversational chat request (no tools).""" - try: - response = await self._client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": system_prompt}, - *messages, - ], - temperature=0.8, - max_tokens=300, - ) - content = response.choices[0].message.content - return content.strip() if content else None - except Exception as e: - logger.error("LLM chat error: %s", e) - return None + """Send a conversational chat request (no tools). + + If *on_first_token* is an async callable it will be awaited once the + first content token arrives (useful for triggering the typing indicator + only after the model starts generating). + """ + async with self._semaphore: + try: + stream = await self._client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + *messages, + ], + temperature=0.8, + max_tokens=300, + stream=True, + ) + + chunks: list[str] = [] + notified = False + async for chunk in stream: + delta = chunk.choices[0].delta if chunk.choices else None + if delta and delta.content: + if not notified and on_first_token: + await on_first_token() + notified = True + chunks.append(delta.content) + + content = "".join(chunks).strip() + return content if content else None + except Exception as e: + logger.error("LLM chat error: %s", e) + return None async def raw_analyze(self, message: str, context: str = "", user_notes: str = "") -> tuple[str, dict | None]: """Return the raw LLM response string AND parsed result for /bcs-test (single LLM call).""" @@ -222,38 +244,39 @@ class LLMClient: user_content += f"=== NOTES ABOUT THIS USER (from prior analysis) ===\n{user_notes}\n\n" user_content += f"=== TARGET MESSAGE (analyze THIS message only) ===\n{message}" - try: - response = await self._client.chat.completions.create( - model=self.model, - messages=[ - {"role": "system", "content": SYSTEM_PROMPT}, - {"role": "user", "content": user_content}, - ], - tools=[ANALYSIS_TOOL], - tool_choice={"type": "function", "function": {"name": "report_analysis"}}, - temperature=0.1, - ) + async with self._semaphore: + try: + response = await self._client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": user_content}, + ], + tools=[ANALYSIS_TOOL], + tool_choice={"type": "function", "function": {"name": "report_analysis"}}, + temperature=0.1, + ) - choice = response.choices[0] - parts = [] - parsed = None + choice = response.choices[0] + parts = [] + parsed = None - if choice.message.content: - parts.append(f"Content: {choice.message.content}") + if choice.message.content: + parts.append(f"Content: {choice.message.content}") - if choice.message.tool_calls: - for tc in choice.message.tool_calls: - parts.append( - f"Tool call: {tc.function.name}({tc.function.arguments})" - ) - # Parse the first tool call - args = json.loads(choice.message.tool_calls[0].function.arguments) - parsed = self._validate_result(args) - elif choice.message.content: - parsed = self._parse_content_fallback(choice.message.content) + if choice.message.tool_calls: + for tc in choice.message.tool_calls: + parts.append( + f"Tool call: {tc.function.name}({tc.function.arguments})" + ) + # Parse the first tool call + args = json.loads(choice.message.tool_calls[0].function.arguments) + parsed = self._validate_result(args) + elif choice.message.content: + parsed = self._parse_content_fallback(choice.message.content) - raw = "\n".join(parts) or "(empty response)" - return raw, parsed + raw = "\n".join(parts) or "(empty response)" + return raw, parsed - except Exception as e: - return f"Error: {e}", None + except Exception as e: + return f"Error: {e}", None