Improve: AI categorizer with rule matching and unified model routing
Categorizer now pre-fetches existing rules and includes matching rules in prompts so the AI respects established mappings. Unified model routing via CallModelAsync replaces separate provider branching. Improved pattern instructions require exact transaction name substrings. Add rule update support (RuleUpdated) when a pattern exists with a different category. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -38,26 +38,10 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
|
||||
public async Task<AICategoryProposal?> ProposeCategorizationAsync(Transaction transaction, string? model = null)
|
||||
{
|
||||
var provider = _config["AI:CategorizationProvider"] ?? "OpenAI";
|
||||
var selectedModel = model ?? _config["AI:ReceiptParsingModel"] ?? "gpt-4o-mini";
|
||||
var prompt = await BuildPromptAsync(transaction);
|
||||
|
||||
AICategorizationResponse? response;
|
||||
|
||||
if (provider.Equals("LlamaCpp", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
_logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model ?? "default");
|
||||
response = await CallLlamaCppAsync(prompt, model);
|
||||
}
|
||||
else
|
||||
{
|
||||
var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
|
||||
if (string.IsNullOrWhiteSpace(apiKey))
|
||||
{
|
||||
_logger.LogWarning("OpenAI API key not configured");
|
||||
return null;
|
||||
}
|
||||
response = await CallOpenAIAsync(apiKey, prompt);
|
||||
}
|
||||
var response = await CallModelAsync(prompt, selectedModel);
|
||||
|
||||
if (response == null)
|
||||
return null;
|
||||
@@ -79,17 +63,21 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
{
|
||||
var proposals = new List<AICategoryProposal>();
|
||||
|
||||
// Pre-fetch existing categories once to avoid concurrent DbContext access
|
||||
// Pre-fetch existing categories and all rules once to avoid concurrent DbContext access
|
||||
var existingCategories = await _db.CategoryMappings
|
||||
.Select(m => m.Category)
|
||||
.Distinct()
|
||||
.OrderBy(c => c)
|
||||
.ToListAsync();
|
||||
|
||||
var allRules = await _db.CategoryMappings
|
||||
.Include(m => m.Merchant)
|
||||
.ToListAsync();
|
||||
|
||||
// Process transactions sequentially to avoid DbContext concurrency issues
|
||||
foreach (var transaction in transactions)
|
||||
{
|
||||
var result = await ProposeCategorizationWithCategoriesAsync(transaction, existingCategories, model);
|
||||
var result = await ProposeCategorizationWithCategoriesAsync(transaction, existingCategories, allRules, model);
|
||||
if (result != null)
|
||||
proposals.Add(result);
|
||||
}
|
||||
@@ -100,28 +88,21 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
private async Task<AICategoryProposal?> ProposeCategorizationWithCategoriesAsync(
|
||||
Transaction transaction,
|
||||
List<string> existingCategories,
|
||||
List<CategoryMapping> allRules,
|
||||
string? model = null)
|
||||
{
|
||||
var provider = _config["AI:CategorizationProvider"] ?? "OpenAI";
|
||||
var prompt = BuildPromptWithCategories(transaction, existingCategories);
|
||||
var selectedModel = model ?? _config["AI:ReceiptParsingModel"] ?? "gpt-4o-mini";
|
||||
|
||||
AICategorizationResponse? response;
|
||||
// Find rules whose pattern matches this transaction name
|
||||
var matchingRules = allRules
|
||||
.Where(r => transaction.Name.Contains(r.Pattern, StringComparison.OrdinalIgnoreCase))
|
||||
.OrderByDescending(r => r.Priority)
|
||||
.ThenByDescending(r => r.Pattern.Length) // Prefer more specific patterns
|
||||
.ToList();
|
||||
|
||||
if (provider.Equals("LlamaCpp", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
_logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model ?? "default");
|
||||
response = await CallLlamaCppAsync(prompt, model);
|
||||
}
|
||||
else
|
||||
{
|
||||
var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
|
||||
if (string.IsNullOrWhiteSpace(apiKey))
|
||||
{
|
||||
_logger.LogWarning("OpenAI API key not configured");
|
||||
return null;
|
||||
}
|
||||
response = await CallOpenAIAsync(apiKey, prompt);
|
||||
}
|
||||
var prompt = BuildPromptWithCategoriesAndRules(transaction, existingCategories, matchingRules);
|
||||
|
||||
var response = await CallModelAsync(prompt, selectedModel);
|
||||
|
||||
if (response == null)
|
||||
return null;
|
||||
@@ -161,27 +142,39 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
transaction.MerchantId = merchant.Id;
|
||||
}
|
||||
|
||||
// Create category mapping rule if requested
|
||||
bool ruleCreated = false;
|
||||
bool ruleUpdated = false;
|
||||
|
||||
// Create or update category mapping rule if requested
|
||||
if (createRule && !string.IsNullOrWhiteSpace(proposal.Pattern))
|
||||
{
|
||||
// Check if rule already exists
|
||||
var existingRule = await _db.CategoryMappings
|
||||
.FirstOrDefaultAsync(m => m.Pattern == proposal.Pattern);
|
||||
|
||||
if (existingRule == null)
|
||||
{
|
||||
var merchantId = transaction.MerchantId;
|
||||
var newMapping = new CategoryMapping
|
||||
{
|
||||
Category = proposal.Category,
|
||||
Pattern = proposal.Pattern,
|
||||
MerchantId = merchantId,
|
||||
MerchantId = transaction.MerchantId,
|
||||
Priority = proposal.Priority,
|
||||
Confidence = proposal.Confidence,
|
||||
CreatedBy = "AI",
|
||||
CreatedAt = DateTime.UtcNow
|
||||
};
|
||||
_db.CategoryMappings.Add(newMapping);
|
||||
ruleCreated = true;
|
||||
}
|
||||
else if (existingRule.Category != proposal.Category)
|
||||
{
|
||||
existingRule.Category = proposal.Category;
|
||||
existingRule.MerchantId = transaction.MerchantId;
|
||||
existingRule.Priority = proposal.Priority;
|
||||
existingRule.Confidence = proposal.Confidence;
|
||||
existingRule.CreatedBy = "AI";
|
||||
existingRule.CreatedAt = DateTime.UtcNow;
|
||||
ruleUpdated = true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -190,7 +183,8 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
return new ApplyProposalResult
|
||||
{
|
||||
Success = true,
|
||||
RuleCreated = createRule && !string.IsNullOrWhiteSpace(proposal.Pattern)
|
||||
RuleCreated = ruleCreated,
|
||||
RuleUpdated = ruleUpdated
|
||||
};
|
||||
}
|
||||
|
||||
@@ -203,10 +197,26 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
.OrderBy(c => c)
|
||||
.ToListAsync();
|
||||
|
||||
return BuildPromptWithCategories(transaction, existingCategories);
|
||||
// Load all rules and find matches in memory (pattern-in-name is hard to express in SQL)
|
||||
var allRules = await _db.CategoryMappings
|
||||
.Include(m => m.Merchant)
|
||||
.ToListAsync();
|
||||
|
||||
var matchingRules = allRules
|
||||
.Where(r => transaction.Name.Contains(r.Pattern, StringComparison.OrdinalIgnoreCase))
|
||||
.OrderByDescending(r => r.Priority)
|
||||
.ThenByDescending(r => r.Pattern.Length)
|
||||
.ToList();
|
||||
|
||||
return BuildPromptWithCategoriesAndRules(transaction, existingCategories, matchingRules);
|
||||
}
|
||||
|
||||
private string BuildPromptWithCategories(Transaction transaction, List<string> existingCategories)
|
||||
{
|
||||
return BuildPromptWithCategoriesAndRules(transaction, existingCategories, new List<CategoryMapping>());
|
||||
}
|
||||
|
||||
private string BuildPromptWithCategoriesAndRules(Transaction transaction, List<string> existingCategories, List<CategoryMapping> matchingRules)
|
||||
{
|
||||
var categoryList = existingCategories.Any()
|
||||
? string.Join(", ", existingCategories)
|
||||
@@ -243,6 +253,22 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
if (transaction.IsTransfer)
|
||||
sb.AppendLine($"- Transfer to: {transaction.TransferToAccount?.DisplayLabel ?? "Unknown"}");
|
||||
|
||||
// Include matching rules so the AI respects existing mappings
|
||||
if (matchingRules.Any())
|
||||
{
|
||||
sb.AppendLine();
|
||||
sb.AppendLine("EXISTING RULES that match this transaction (you MUST use these categories unless clearly wrong):");
|
||||
foreach (var rule in matchingRules)
|
||||
{
|
||||
var createdBy = rule.CreatedBy ?? "Unknown";
|
||||
var merchantName = rule.Merchant?.Name;
|
||||
sb.Append($" - Pattern \"{rule.Pattern}\" → Category \"{rule.Category}\"");
|
||||
if (!string.IsNullOrWhiteSpace(merchantName))
|
||||
sb.Append($", Merchant \"{merchantName}\"");
|
||||
sb.AppendLine($" (created by {createdBy})");
|
||||
}
|
||||
}
|
||||
|
||||
sb.AppendLine();
|
||||
sb.AppendLine($"Existing categories in this system: {categoryList}");
|
||||
sb.AppendLine();
|
||||
@@ -250,27 +276,49 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
sb.AppendLine("{");
|
||||
sb.AppendLine(" \"category\": \"Category name\",");
|
||||
sb.AppendLine(" \"canonical_merchant\": \"Clean merchant name (e.g., 'Walmart' from 'WAL-MART #1234')\",");
|
||||
sb.AppendLine(" \"pattern\": \"Pattern to match future transactions (e.g., 'WALMART' or 'SUBWAY')\",");
|
||||
sb.AppendLine(" \"pattern\": \"EXACT substring from the transaction Name that identifies this merchant\",");
|
||||
sb.AppendLine(" \"priority\": 0,");
|
||||
sb.AppendLine(" \"confidence\": 0.85,");
|
||||
sb.AppendLine(" \"reasoning\": \"Brief explanation\"");
|
||||
sb.AppendLine("}");
|
||||
sb.AppendLine();
|
||||
sb.AppendLine("Guidelines:");
|
||||
sb.AppendLine("- If an existing rule matches this transaction, you MUST use that rule's category and merchant. Only deviate if the existing rule is clearly incorrect.");
|
||||
sb.AppendLine("- Prefer using existing categories when appropriate");
|
||||
sb.AppendLine("- CRITICAL: The pattern MUST be a substring that actually appears in the transaction Name field above. It is used for case-insensitive contains matching. Do NOT invent or clean up the pattern. Extract the shortest distinctive substring from the Name that would identify this merchant. For example, if the Name is 'DEBIT PURCHASE -VISA Kindle Unltd*0M6888', use 'Kindle Unltd' NOT 'Kindle Unlimited'. If the Name is 'WAL-MART #1234 SPRINGFIELD', use 'WAL-MART' NOT 'WALMART'.");
|
||||
sb.AppendLine("- confidence: Your certainty in this categorization (0.0-1.0). Use ~0.9+ for obvious matches like 'WALMART' -> Groceries. Use ~0.7-0.8 for likely matches. Use ~0.5-0.6 for uncertain/ambiguous transactions.");
|
||||
sb.AppendLine("- Return ONLY valid JSON, no additional text.");
|
||||
|
||||
return sb.ToString();
|
||||
}
|
||||
|
||||
private async Task<AICategorizationResponse?> CallOpenAIAsync(string apiKey, string prompt)
|
||||
private async Task<AICategorizationResponse?> CallModelAsync(string prompt, string model)
|
||||
{
|
||||
if (model.StartsWith("llamacpp:", StringComparison.OrdinalIgnoreCase))
|
||||
{
|
||||
_logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model);
|
||||
return await CallLlamaCppAsync(prompt, model);
|
||||
}
|
||||
|
||||
// Default to OpenAI
|
||||
var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
|
||||
if (string.IsNullOrWhiteSpace(apiKey))
|
||||
{
|
||||
_logger.LogWarning("OpenAI API key not configured");
|
||||
return null;
|
||||
}
|
||||
|
||||
_logger.LogInformation("Using OpenAI for transaction categorization with model {Model}", model);
|
||||
return await CallOpenAIAsync(apiKey, prompt, model);
|
||||
}
|
||||
|
||||
private async Task<AICategorizationResponse?> CallOpenAIAsync(string apiKey, string prompt, string model = "gpt-4o-mini")
|
||||
{
|
||||
try
|
||||
{
|
||||
var requestBody = new
|
||||
{
|
||||
model = "gpt-4o-mini",
|
||||
model = model,
|
||||
messages = new[]
|
||||
{
|
||||
new { role = "system", content = "You are a financial transaction categorization expert. Always respond with valid JSON only." },
|
||||
@@ -298,7 +346,7 @@ public class TransactionAICategorizer : ITransactionAICategorizer
|
||||
if (apiResponse?.Choices == null || apiResponse.Choices.Length == 0)
|
||||
return null;
|
||||
|
||||
var content = apiResponse.Choices[0].Message?.Content;
|
||||
var content = OpenAIToolUseHelper.CleanJsonResponse(apiResponse.Choices[0].Message?.Content);
|
||||
if (string.IsNullOrWhiteSpace(content))
|
||||
return null;
|
||||
|
||||
@@ -414,5 +462,6 @@ public class ApplyProposalResult
|
||||
{
|
||||
public bool Success { get; set; }
|
||||
public bool RuleCreated { get; set; }
|
||||
public bool RuleUpdated { get; set; }
|
||||
public string? ErrorMessage { get; set; }
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user