diff --git a/MoneyMap/Services/TransactionAICategorizer.cs b/MoneyMap/Services/TransactionAICategorizer.cs index 04dec26..7cf03a3 100644 --- a/MoneyMap/Services/TransactionAICategorizer.cs +++ b/MoneyMap/Services/TransactionAICategorizer.cs @@ -38,26 +38,10 @@ public class TransactionAICategorizer : ITransactionAICategorizer public async Task ProposeCategorizationAsync(Transaction transaction, string? model = null) { - var provider = _config["AI:CategorizationProvider"] ?? "OpenAI"; + var selectedModel = model ?? _config["AI:ReceiptParsingModel"] ?? "gpt-4o-mini"; var prompt = await BuildPromptAsync(transaction); - AICategorizationResponse? response; - - if (provider.Equals("LlamaCpp", StringComparison.OrdinalIgnoreCase)) - { - _logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model ?? "default"); - response = await CallLlamaCppAsync(prompt, model); - } - else - { - var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY"); - if (string.IsNullOrWhiteSpace(apiKey)) - { - _logger.LogWarning("OpenAI API key not configured"); - return null; - } - response = await CallOpenAIAsync(apiKey, prompt); - } + var response = await CallModelAsync(prompt, selectedModel); if (response == null) return null; @@ -79,17 +63,21 @@ public class TransactionAICategorizer : ITransactionAICategorizer { var proposals = new List(); - // Pre-fetch existing categories once to avoid concurrent DbContext access + // Pre-fetch existing categories and all rules once to avoid concurrent DbContext access var existingCategories = await _db.CategoryMappings .Select(m => m.Category) .Distinct() .OrderBy(c => c) .ToListAsync(); + var allRules = await _db.CategoryMappings + .Include(m => m.Merchant) + .ToListAsync(); + // Process transactions sequentially to avoid DbContext concurrency issues foreach (var transaction in transactions) { - var result = await ProposeCategorizationWithCategoriesAsync(transaction, existingCategories, model); + var result = await ProposeCategorizationWithCategoriesAsync(transaction, existingCategories, allRules, model); if (result != null) proposals.Add(result); } @@ -100,28 +88,21 @@ public class TransactionAICategorizer : ITransactionAICategorizer private async Task ProposeCategorizationWithCategoriesAsync( Transaction transaction, List existingCategories, + List allRules, string? model = null) { - var provider = _config["AI:CategorizationProvider"] ?? "OpenAI"; - var prompt = BuildPromptWithCategories(transaction, existingCategories); + var selectedModel = model ?? _config["AI:ReceiptParsingModel"] ?? "gpt-4o-mini"; - AICategorizationResponse? response; + // Find rules whose pattern matches this transaction name + var matchingRules = allRules + .Where(r => transaction.Name.Contains(r.Pattern, StringComparison.OrdinalIgnoreCase)) + .OrderByDescending(r => r.Priority) + .ThenByDescending(r => r.Pattern.Length) // Prefer more specific patterns + .ToList(); - if (provider.Equals("LlamaCpp", StringComparison.OrdinalIgnoreCase)) - { - _logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model ?? "default"); - response = await CallLlamaCppAsync(prompt, model); - } - else - { - var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY"); - if (string.IsNullOrWhiteSpace(apiKey)) - { - _logger.LogWarning("OpenAI API key not configured"); - return null; - } - response = await CallOpenAIAsync(apiKey, prompt); - } + var prompt = BuildPromptWithCategoriesAndRules(transaction, existingCategories, matchingRules); + + var response = await CallModelAsync(prompt, selectedModel); if (response == null) return null; @@ -161,27 +142,39 @@ public class TransactionAICategorizer : ITransactionAICategorizer transaction.MerchantId = merchant.Id; } - // Create category mapping rule if requested + bool ruleCreated = false; + bool ruleUpdated = false; + + // Create or update category mapping rule if requested if (createRule && !string.IsNullOrWhiteSpace(proposal.Pattern)) { - // Check if rule already exists var existingRule = await _db.CategoryMappings .FirstOrDefaultAsync(m => m.Pattern == proposal.Pattern); if (existingRule == null) { - var merchantId = transaction.MerchantId; var newMapping = new CategoryMapping { Category = proposal.Category, Pattern = proposal.Pattern, - MerchantId = merchantId, + MerchantId = transaction.MerchantId, Priority = proposal.Priority, Confidence = proposal.Confidence, CreatedBy = "AI", CreatedAt = DateTime.UtcNow }; _db.CategoryMappings.Add(newMapping); + ruleCreated = true; + } + else if (existingRule.Category != proposal.Category) + { + existingRule.Category = proposal.Category; + existingRule.MerchantId = transaction.MerchantId; + existingRule.Priority = proposal.Priority; + existingRule.Confidence = proposal.Confidence; + existingRule.CreatedBy = "AI"; + existingRule.CreatedAt = DateTime.UtcNow; + ruleUpdated = true; } } @@ -190,7 +183,8 @@ public class TransactionAICategorizer : ITransactionAICategorizer return new ApplyProposalResult { Success = true, - RuleCreated = createRule && !string.IsNullOrWhiteSpace(proposal.Pattern) + RuleCreated = ruleCreated, + RuleUpdated = ruleUpdated }; } @@ -203,10 +197,26 @@ public class TransactionAICategorizer : ITransactionAICategorizer .OrderBy(c => c) .ToListAsync(); - return BuildPromptWithCategories(transaction, existingCategories); + // Load all rules and find matches in memory (pattern-in-name is hard to express in SQL) + var allRules = await _db.CategoryMappings + .Include(m => m.Merchant) + .ToListAsync(); + + var matchingRules = allRules + .Where(r => transaction.Name.Contains(r.Pattern, StringComparison.OrdinalIgnoreCase)) + .OrderByDescending(r => r.Priority) + .ThenByDescending(r => r.Pattern.Length) + .ToList(); + + return BuildPromptWithCategoriesAndRules(transaction, existingCategories, matchingRules); } private string BuildPromptWithCategories(Transaction transaction, List existingCategories) + { + return BuildPromptWithCategoriesAndRules(transaction, existingCategories, new List()); + } + + private string BuildPromptWithCategoriesAndRules(Transaction transaction, List existingCategories, List matchingRules) { var categoryList = existingCategories.Any() ? string.Join(", ", existingCategories) @@ -243,6 +253,22 @@ public class TransactionAICategorizer : ITransactionAICategorizer if (transaction.IsTransfer) sb.AppendLine($"- Transfer to: {transaction.TransferToAccount?.DisplayLabel ?? "Unknown"}"); + // Include matching rules so the AI respects existing mappings + if (matchingRules.Any()) + { + sb.AppendLine(); + sb.AppendLine("EXISTING RULES that match this transaction (you MUST use these categories unless clearly wrong):"); + foreach (var rule in matchingRules) + { + var createdBy = rule.CreatedBy ?? "Unknown"; + var merchantName = rule.Merchant?.Name; + sb.Append($" - Pattern \"{rule.Pattern}\" → Category \"{rule.Category}\""); + if (!string.IsNullOrWhiteSpace(merchantName)) + sb.Append($", Merchant \"{merchantName}\""); + sb.AppendLine($" (created by {createdBy})"); + } + } + sb.AppendLine(); sb.AppendLine($"Existing categories in this system: {categoryList}"); sb.AppendLine(); @@ -250,27 +276,49 @@ public class TransactionAICategorizer : ITransactionAICategorizer sb.AppendLine("{"); sb.AppendLine(" \"category\": \"Category name\","); sb.AppendLine(" \"canonical_merchant\": \"Clean merchant name (e.g., 'Walmart' from 'WAL-MART #1234')\","); - sb.AppendLine(" \"pattern\": \"Pattern to match future transactions (e.g., 'WALMART' or 'SUBWAY')\","); + sb.AppendLine(" \"pattern\": \"EXACT substring from the transaction Name that identifies this merchant\","); sb.AppendLine(" \"priority\": 0,"); sb.AppendLine(" \"confidence\": 0.85,"); sb.AppendLine(" \"reasoning\": \"Brief explanation\""); sb.AppendLine("}"); sb.AppendLine(); sb.AppendLine("Guidelines:"); + sb.AppendLine("- If an existing rule matches this transaction, you MUST use that rule's category and merchant. Only deviate if the existing rule is clearly incorrect."); sb.AppendLine("- Prefer using existing categories when appropriate"); + sb.AppendLine("- CRITICAL: The pattern MUST be a substring that actually appears in the transaction Name field above. It is used for case-insensitive contains matching. Do NOT invent or clean up the pattern. Extract the shortest distinctive substring from the Name that would identify this merchant. For example, if the Name is 'DEBIT PURCHASE -VISA Kindle Unltd*0M6888', use 'Kindle Unltd' NOT 'Kindle Unlimited'. If the Name is 'WAL-MART #1234 SPRINGFIELD', use 'WAL-MART' NOT 'WALMART'."); sb.AppendLine("- confidence: Your certainty in this categorization (0.0-1.0). Use ~0.9+ for obvious matches like 'WALMART' -> Groceries. Use ~0.7-0.8 for likely matches. Use ~0.5-0.6 for uncertain/ambiguous transactions."); sb.AppendLine("- Return ONLY valid JSON, no additional text."); return sb.ToString(); } - private async Task CallOpenAIAsync(string apiKey, string prompt) + private async Task CallModelAsync(string prompt, string model) + { + if (model.StartsWith("llamacpp:", StringComparison.OrdinalIgnoreCase)) + { + _logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model); + return await CallLlamaCppAsync(prompt, model); + } + + // Default to OpenAI + var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + if (string.IsNullOrWhiteSpace(apiKey)) + { + _logger.LogWarning("OpenAI API key not configured"); + return null; + } + + _logger.LogInformation("Using OpenAI for transaction categorization with model {Model}", model); + return await CallOpenAIAsync(apiKey, prompt, model); + } + + private async Task CallOpenAIAsync(string apiKey, string prompt, string model = "gpt-4o-mini") { try { var requestBody = new { - model = "gpt-4o-mini", + model = model, messages = new[] { new { role = "system", content = "You are a financial transaction categorization expert. Always respond with valid JSON only." }, @@ -298,7 +346,7 @@ public class TransactionAICategorizer : ITransactionAICategorizer if (apiResponse?.Choices == null || apiResponse.Choices.Length == 0) return null; - var content = apiResponse.Choices[0].Message?.Content; + var content = OpenAIToolUseHelper.CleanJsonResponse(apiResponse.Choices[0].Message?.Content); if (string.IsNullOrWhiteSpace(content)) return null; @@ -414,5 +462,6 @@ public class ApplyProposalResult { public bool Success { get; set; } public bool RuleCreated { get; set; } + public bool RuleUpdated { get; set; } public string? ErrorMessage { get; set; } }