using Microsoft.EntityFrameworkCore;
using MoneyMap.Data;
using System.Globalization;
using System.Text;
using System.Text.Json;

namespace MoneyMap.Services.AITools
{
    public interface IAIToolExecutor
    {
        /// <summary>
        /// Execute a single tool call and return the result as JSON.
        /// </summary>
        Task<AIToolResult> ExecuteAsync(AIToolCall toolCall);

        /// <summary>
        /// Pre-fetch all relevant context as a text block for providers that don't support tool use (Ollama).
        /// </summary>
        Task<string> GetEnrichedContextAsync(DateTime? receiptDate = null, decimal? total = null, string? merchantHint = null);
    }

    /// <summary>
    /// Answers AI tool calls (category, transaction and merchant searches) with read-only queries
    /// against <see cref="MoneyMapContext"/>, and builds an equivalent plain-text context block
    /// for providers without tool support.
    /// </summary>
    public class AIToolExecutor : IAIToolExecutor
    {
        private readonly MoneyMapContext _db;
        private readonly ILogger<AIToolExecutor> _logger;

        // Hard cap on the number of rows/groups any tool returns, whatever the model asks for.
        private const int MaxResults = 20;

        public AIToolExecutor(MoneyMapContext db, ILogger<AIToolExecutor> logger)
        {
            _db = db;
            _logger = logger;
        }

        /// <summary>
        /// Dispatches <paramref name="toolCall"/> to the handler named by <c>toolCall.Name</c>.
        /// </summary>
        /// <param name="toolCall">The tool invocation requested by the model.</param>
        /// <returns>
        /// A result whose <c>Content</c> is a JSON document. Unknown tool names and exceptions thrown
        /// by a handler do not propagate; they produce a JSON <c>{"error": ...}</c> payload with
        /// <c>IsError</c> set to <c>true</c>.
        /// </returns>
        public async Task<AIToolResult> ExecuteAsync(AIToolCall toolCall)
        {
            _logger.LogInformation("Executing AI tool: {ToolName} with args: {Args}",
                toolCall.Name, JsonSerializer.Serialize(toolCall.Arguments));

            try
            {
                string? result = toolCall.Name switch
                {
                    "search_categories" => await SearchCategoriesAsync(toolCall),
                    "search_transactions" => await SearchTransactionsAsync(toolCall),
                    "search_merchants" => await SearchMerchantsAsync(toolCall),
                    _ => null
                };

                if (result is null)
                {
                    // Serialize rather than interpolate: the name comes from model output and may
                    // contain quotes/backslashes that would otherwise produce invalid JSON.
                    _logger.LogWarning("Unknown AI tool requested: {ToolName}", toolCall.Name);
                    return new AIToolResult
                    {
                        ToolCallId = toolCall.Id,
                        Content = JsonSerializer.Serialize(new { error = $"Unknown tool: {toolCall.Name}" }),
                        IsError = true
                    };
                }

                _logger.LogInformation("Tool {ToolName} returned {Length} chars", toolCall.Name, result.Length);

                return new AIToolResult
                {
                    ToolCallId = toolCall.Id,
                    Content = result
                };
            }
            catch (Exception ex)
            {
                _logger.LogError(ex, "Error executing tool {ToolName}", toolCall.Name);
                return new AIToolResult
                {
                    ToolCallId = toolCall.Id,
                    Content = JsonSerializer.Serialize(new { error = ex.Message }),
                    IsError = true
                };
            }
        }

        /// <summary>
        /// Builds a plain-text block containing all categories, merchants matching
        /// <paramref name="merchantHint"/> (if given), and unmapped candidate transactions near
        /// <paramref name="receiptDate"/>/<paramref name="total"/> (if either is given).
        /// </summary>
        /// <param name="receiptDate">Receipt date; narrows candidates to [-1 day, +7 days].</param>
        /// <param name="total">Receipt total; narrows candidates to amounts within ±10% (either sign).</param>
        /// <param name="merchantHint">Fragment matched against merchant names.</param>
        /// <returns>The context text, delimited by "=== DATABASE CONTEXT ... ===" header/footer lines.</returns>
        public async Task<string> GetEnrichedContextAsync(DateTime? receiptDate, decimal? total, string? merchantHint)
        {
            var sb = new StringBuilder();
            sb.AppendLine("=== DATABASE CONTEXT (use this to match categories and transactions) ===");
            sb.AppendLine();

            await AppendCategoriesAsync(sb);

            if (!string.IsNullOrWhiteSpace(merchantHint))
                await AppendMatchingMerchantsAsync(sb, merchantHint);

            if (receiptDate.HasValue || total.HasValue)
                await AppendCandidateTransactionsAsync(sb, receiptDate, total);

            sb.AppendLine("=== END DATABASE CONTEXT ===");
            return sb.ToString();
        }

        /// <summary>Appends every category with up to five patterns and up to three distinct merchant names.</summary>
        private async Task AppendCategoriesAsync(StringBuilder sb)
        {
            var categories = await _db.CategoryMappings
                .Include(cm => cm.Merchant)
                .OrderBy(cm => cm.Category)
                .ToListAsync();

            var grouped = categories.GroupBy(c => c.Category).ToList();
            sb.AppendLine($"EXISTING CATEGORIES ({grouped.Count} total):");

            foreach (var group in grouped)
            {
                var patterns = group.Select(c => c.Pattern).Take(5);
                var merchants = group
                    .Where(c => c.Merchant != null)
                    .Select(c => c.Merchant!.Name)
                    .Distinct()
                    .Take(3)
                    .ToList();

                sb.Append($" - {group.Key}: patterns=[{string.Join(", ", patterns)}]");
                if (merchants.Count > 0)
                    sb.Append($", merchants=[{string.Join(", ", merchants)}]");
                sb.AppendLine();
            }

            sb.AppendLine();
        }

        /// <summary>Appends up to ten merchants whose name contains <paramref name="merchantHint"/>; appends nothing if none match.</summary>
        private async Task AppendMatchingMerchantsAsync(StringBuilder sb, string merchantHint)
        {
            var matchingMerchants = await FindMerchantsAsync(merchantHint, 10);
            if (matchingMerchants.Count == 0)
                return;

            sb.AppendLine($"MATCHING MERCHANTS for \"{merchantHint}\":");
            foreach (var m in matchingMerchants)
                sb.AppendLine($" - {m.Name} ({m.TransactionCount} transactions, typical category: {m.TopCategory ?? "none"})");
            sb.AppendLine();
        }

        /// <summary>
        /// Appends up to ten transactions with no linked receipt that fall inside the date and/or
        /// amount window, oldest first; appends nothing if none match.
        /// </summary>
        private async Task AppendCandidateTransactionsAsync(StringBuilder sb, DateTime? receiptDate, decimal? total)
        {
            // "Unmapped" = no receipt row references this transaction.
            var txQuery = _db.Transactions
                .Include(t => t.Merchant)
                .Where(t => !_db.Receipts.Any(r => r.TransactionId == t.Id));

            if (receiptDate.HasValue)
            {
                var minDate = receiptDate.Value.AddDays(-1);
                var maxDate = receiptDate.Value.AddDays(7);
                txQuery = txQuery.Where(t => t.Date >= minDate && t.Date <= maxDate);
            }

            if (total.HasValue)
            {
                // Match ±10% of the receipt total against either sign of Amount.
                var absTotal = Math.Abs(total.Value);
                var minAmt = absTotal * 0.9m;
                var maxAmt = absTotal * 1.1m;
                txQuery = txQuery.Where(t =>
                    (t.Amount >= -maxAmt && t.Amount <= -minAmt) ||
                    (t.Amount >= minAmt && t.Amount <= maxAmt));
            }

            var transactions = await txQuery
                .OrderBy(t => t.Date)
                .Take(10)
                .ToListAsync();

            if (transactions.Count == 0)
                return;

            sb.AppendLine("CANDIDATE TRANSACTIONS (unmapped, matching date/amount):");
            foreach (var t in transactions)
            {
                var date = t.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture);
                sb.AppendLine($" - ID={t.Id}, Date={date}, Amount={t.Amount:C}, Name=\"{t.Name}\", " +
                              $"Merchant={t.Merchant?.Name ?? "none"}, Category={t.Category}");
            }
            sb.AppendLine();
        }

        /// <summary>
        /// Tool "search_categories": groups category mappings (optionally filtered by a "query"
        /// substring of the category name) and returns up to <see cref="MaxResults"/> categories,
        /// each with up to five patterns and five distinct merchant names.
        /// </summary>
        private async Task<string> SearchCategoriesAsync(AIToolCall toolCall)
        {
            var query = toolCall.GetString("query");

            var mappings = _db.CategoryMappings
                .Include(cm => cm.Merchant)
                .AsQueryable();

            if (!string.IsNullOrWhiteSpace(query))
                mappings = mappings.Where(cm => cm.Category.Contains(query));

            var results = await mappings
                .OrderBy(cm => cm.Category)
                .ToListAsync();

            var grouped = results
                .GroupBy(c => c.Category)
                .Take(MaxResults)
                .Select(g => new
                {
                    category = g.Key,
                    patterns = g.Select(c => c.Pattern).Take(5).ToList(),
                    merchants = g.Where(c => c.Merchant != null)
                        .Select(c => c.Merchant!.Name)
                        .Distinct()
                        .Take(5)
                        .ToList()
                })
                .ToList();

            return JsonSerializer.Serialize(new { categories = grouped });
        }

        /// <summary>
        /// Tool "search_transactions": searches transactions that have no linked receipt, newest first.
        /// Optional arguments: "merchant" (substring of transaction or merchant name), "minDate"/"maxDate"
        /// (inclusive, parsed with the invariant culture), "minAmount"/"maxAmount" (compared against the
        /// absolute amount), and "limit" (default 10, clamped to [1, <see cref="MaxResults"/>]).
        /// </summary>
        private async Task<string> SearchTransactionsAsync(AIToolCall toolCall)
        {
            var merchant = toolCall.GetString("merchant");
            var minDateStr = toolCall.GetString("minDate");
            var maxDateStr = toolCall.GetString("maxDate");
            var minAmount = toolCall.GetDecimal("minAmount");
            var maxAmount = toolCall.GetDecimal("maxAmount");

            // A zero/negative Take is never useful, and some providers (e.g. SQLite) treat a negative
            // LIMIT as "no limit", which would bypass MaxResults entirely.
            var limit = Math.Clamp(toolCall.GetInt("limit") ?? 10, 1, MaxResults);

            // No Include needed: the query is projected below, so navigation access is translated to SQL.
            var txQuery = _db.Transactions
                .Where(t => !_db.Receipts.Any(r => r.TransactionId == t.Id));

            if (!string.IsNullOrWhiteSpace(merchant))
            {
                txQuery = txQuery.Where(t => t.Name.Contains(merchant) ||
                    (t.Merchant != null && t.Merchant.Name.Contains(merchant)));
            }

            if (TryParseDate(minDateStr, out var minDate))
                txQuery = txQuery.Where(t => t.Date >= minDate);
            if (TryParseDate(maxDateStr, out var maxDate))
                txQuery = txQuery.Where(t => t.Date <= maxDate);

            // Amount bounds are magnitudes (expenses may be stored negative), matching the
            // Math.Abs handling used for the receipt total in GetEnrichedContextAsync.
            if (minAmount.HasValue)
            {
                var min = Math.Abs(minAmount.Value);
                txQuery = txQuery.Where(t => t.Amount <= -min || t.Amount >= min);
            }
            if (maxAmount.HasValue)
            {
                var max = Math.Abs(maxAmount.Value);
                txQuery = txQuery.Where(t => t.Amount >= -max && t.Amount <= max);
            }

            var rows = await txQuery
                .OrderByDescending(t => t.Date)
                .Take(limit)
                .Select(t => new
                {
                    t.Id,
                    t.Date,
                    t.Amount,
                    t.Name,
                    MerchantName = t.Merchant != null ? t.Merchant.Name : null,
                    t.Category
                })
                .ToListAsync();

            // Format in memory with the invariant culture so the date text never depends on server locale.
            var transactions = rows.Select(t => new
            {
                id = t.Id,
                date = t.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture),
                amount = t.Amount,
                name = t.Name,
                merchant = t.MerchantName,
                category = t.Category
            });

            return JsonSerializer.Serialize(new { transactions });
        }

        /// <summary>
        /// Tool "search_merchants": returns up to <see cref="MaxResults"/> merchants whose name contains
        /// the "query" argument (all merchants when omitted), most-used first.
        /// </summary>
        private async Task<string> SearchMerchantsAsync(AIToolCall toolCall)
        {
            var query = toolCall.GetString("query") ?? "";

            var merchants = (await FindMerchantsAsync(query, MaxResults))
                .Select(m => new
                {
                    name = m.Name,
                    transactionCount = m.TransactionCount,
                    topCategory = m.TopCategory
                });

            return JsonSerializer.Serialize(new { merchants });
        }

        /// <summary>
        /// Shared merchant lookup: merchants whose name contains <paramref name="nameFragment"/>, with
        /// their transaction count and most frequent non-empty category. Ordered by transaction count
        /// (descending) then name, so the row-limited result is deterministic.
        /// </summary>
        private async Task<List<MerchantSummary>> FindMerchantsAsync(string nameFragment, int take)
        {
            return await _db.Merchants
                .Where(m => m.Name.Contains(nameFragment))
                .OrderByDescending(m => m.Transactions.Count)
                .ThenBy(m => m.Name)
                .Select(m => new MerchantSummary(
                    m.Name,
                    m.Transactions.Count,
                    m.Transactions
                        .Where(t => t.Category != "")
                        .GroupBy(t => t.Category)
                        .OrderByDescending(g => g.Count())
                        .Select(g => g.Key)
                        .FirstOrDefault()))
                .Take(take)
                .ToListAsync();
        }

        /// <summary>
        /// Parses a model-supplied date with the invariant culture, so ISO-style input such as
        /// "2024-03-05" means the same thing regardless of the server's locale.
        /// </summary>
        private static bool TryParseDate(string? value, out DateTime date) =>
            DateTime.TryParse(value, CultureInfo.InvariantCulture, DateTimeStyles.None, out date);

        /// <summary>Projection row for merchant lookups.</summary>
        private sealed record MerchantSummary(string Name, int TransactionCount, string? TopCategory);
    }
}