Files
MoneyMap/MoneyMap.Core/Services/TransactionAICategorizer.cs
T
2026-04-20 18:18:20 -04:00

468 lines
18 KiB
C#

using Microsoft.EntityFrameworkCore;
using MoneyMap.Data;
using MoneyMap.Models;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
namespace MoneyMap.Services;
/// <summary>
/// Produces AI-generated categorization proposals for transactions and applies accepted proposals.
/// </summary>
public interface ITransactionAICategorizer
{
    /// <summary>
    /// Requests a categorization proposal for a single transaction.
    /// </summary>
    /// <param name="transaction">The transaction to categorize.</param>
    /// <param name="model">
    /// Optional model identifier. When <c>null</c>, implementations presumably fall back to a
    /// configured default (verify against the implementation).
    /// </param>
    /// <returns>The proposal, or <c>null</c> when none could be produced.</returns>
    Task<AICategoryProposal?> ProposeCategorizationAsync(
        Transaction transaction,
        string? model = null);

    /// <summary>
    /// Requests categorization proposals for several transactions.
    /// </summary>
    /// <param name="transactions">The transactions to categorize.</param>
    /// <param name="model">Optional model identifier; see <c>ProposeCategorizationAsync</c>.</param>
    /// <returns>The proposals that could be produced; this may contain fewer items than the input.</returns>
    Task<List<AICategoryProposal>> ProposeBatchCategorizationAsync(
        List<Transaction> transactions,
        string? model = null);

    /// <summary>
    /// Applies a proposal to the transaction identified by <paramref name="transactionId"/>.
    /// </summary>
    /// <param name="transactionId">Id of the transaction to update.</param>
    /// <param name="proposal">The categorization to apply.</param>
    /// <param name="createRule">Whether a categorization rule may be created or updated from the proposal.</param>
    /// <returns>Outcome flags and, on failure, an error message.</returns>
    Task<ApplyProposalResult> ApplyProposalAsync(
        long transactionId,
        AICategoryProposal proposal,
        bool createRule = true);
}
/// <summary>
/// Uses an LLM to propose a category, canonical merchant name and matching pattern for bank
/// transactions, and applies accepted proposals to the database.
/// </summary>
/// <remarks>
/// Model ids prefixed with <c>llamacpp:</c> are sent to the local llama.cpp client. Every other
/// model id is sent to the OpenAI chat-completions endpoint.
/// </remarks>
public class TransactionAICategorizer : ITransactionAICategorizer
{
    // Used when neither the caller nor configuration ("AI:ReceiptParsingModel") names a model.
    private const string DefaultModel = "gpt-4o-mini";

    // Proposals at or above this confidence are flagged to create a categorization rule.
    private const decimal RuleCreationConfidenceThreshold = 0.7m;

    private const string SystemPrompt =
        "You are a financial transaction categorization expert. Always respond with valid JSON only.";

    // Cached: JsonSerializerOptions instances are expensive to create and are meant to be reused.
    private static readonly JsonSerializerOptions CaseInsensitiveJson = new()
    {
        PropertyNameCaseInsensitive = true
    };

    private readonly HttpClient _httpClient;
    private readonly MoneyMapContext _db;
    private readonly IConfiguration _config;
    private readonly LlamaCppVisionClient _llamaClient;
    private readonly ILogger<TransactionAICategorizer> _logger;

    public TransactionAICategorizer(
        HttpClient httpClient,
        MoneyMapContext db,
        IConfiguration config,
        LlamaCppVisionClient llamaClient,
        ILogger<TransactionAICategorizer> logger)
    {
        _httpClient = httpClient;
        _db = db;
        _config = config;
        _llamaClient = llamaClient;
        _logger = logger;
    }

    /// <summary>
    /// Loads the known categories and rules, then asks the model to categorize one transaction.
    /// </summary>
    /// <param name="transaction">Transaction to categorize.</param>
    /// <param name="model">Model id; null falls back to "AI:ReceiptParsingModel", then to gpt-4o-mini.</param>
    /// <returns>The proposal, or null when the model call failed or returned no category.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="transaction"/> is null.</exception>
    public async Task<AICategoryProposal?> ProposeCategorizationAsync(Transaction transaction, string? model = null)
    {
        ArgumentNullException.ThrowIfNull(transaction);

        var (existingCategories, allRules) = await LoadCategorizationContextAsync();
        return await ProposeCategorizationWithCategoriesAsync(transaction, existingCategories, allRules, model);
    }

    /// <summary>
    /// Categorizes several transactions. Categories and rules are loaded once for the whole batch.
    /// </summary>
    /// <param name="transactions">Transactions to categorize.</param>
    /// <param name="model">Model id; see <see cref="ProposeCategorizationAsync"/>.</param>
    /// <returns>Proposals for the transactions that could be categorized; failures are skipped.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="transactions"/> is null.</exception>
    public async Task<List<AICategoryProposal>> ProposeBatchCategorizationAsync(List<Transaction> transactions, string? model = null)
    {
        ArgumentNullException.ThrowIfNull(transactions);

        var proposals = new List<AICategoryProposal>();
        if (transactions.Count == 0)
            return proposals;

        // Pre-fetch categories and rules once so no DbContext work happens inside the loop.
        var (existingCategories, allRules) = await LoadCategorizationContextAsync();

        // Process transactions sequentially to avoid DbContext concurrency issues.
        foreach (var transaction in transactions)
        {
            var proposal = await ProposeCategorizationWithCategoriesAsync(transaction, existingCategories, allRules, model);
            if (proposal != null)
                proposals.Add(proposal);
        }

        return proposals;
    }

    /// <summary>
    /// Builds the prompt from pre-loaded categories and rules, calls the model, and maps its answer.
    /// </summary>
    private async Task<AICategoryProposal?> ProposeCategorizationWithCategoriesAsync(
        Transaction transaction,
        List<string> existingCategories,
        List<CategoryMapping> allRules,
        string? model = null)
    {
        var matchingRules = FindMatchingRules(transaction, allRules);
        var prompt = BuildPromptWithCategoriesAndRules(transaction, existingCategories, matchingRules);

        var response = await CallModelAsync(prompt, ResolveModel(model));
        return response == null ? null : ToProposal(transaction.Id, response);
    }

    /// <summary>
    /// Applies a proposal to a transaction. This sets its category, links (or creates) the
    /// canonical merchant, and optionally creates or updates the category-mapping rule for the
    /// proposal's pattern.
    /// </summary>
    /// <param name="transactionId">Id of the transaction to update.</param>
    /// <param name="proposal">Proposal to apply.</param>
    /// <param name="createRule">When true, a rule is created for a new pattern, or an existing rule
    /// with a different category is updated.</param>
    /// <returns>Success plus whether a rule was created or updated; failure when the transaction is missing.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="proposal"/> is null.</exception>
    public async Task<ApplyProposalResult> ApplyProposalAsync(long transactionId, AICategoryProposal proposal, bool createRule = true)
    {
        ArgumentNullException.ThrowIfNull(proposal);

        var transaction = await _db.Transactions.FindAsync(transactionId);
        if (transaction == null)
            return new ApplyProposalResult { Success = false, ErrorMessage = "Transaction not found" };

        transaction.Category = proposal.Category;

        // Trim so that "Walmart" and "Walmart " do not become two different merchants.
        var merchantName = proposal.CanonicalMerchant?.Trim();
        if (!string.IsNullOrEmpty(merchantName))
        {
            var merchant = await _db.Merchants.FirstOrDefaultAsync(m => m.Name == merchantName);
            if (merchant == null)
            {
                merchant = new Merchant { Name = merchantName };
                _db.Merchants.Add(merchant);
                // Save now so the new merchant's generated Id is available for the links below.
                await _db.SaveChangesAsync();
            }
            transaction.MerchantId = merchant.Id;
        }

        var ruleCreated = false;
        var ruleUpdated = false;

        // A rule needs both a pattern to match on and a category to assign. An empty category
        // would make every future match clear the transaction's category.
        if (createRule
            && !string.IsNullOrWhiteSpace(proposal.Pattern)
            && !string.IsNullOrWhiteSpace(proposal.Category))
        {
            var existingRule = await _db.CategoryMappings
                .FirstOrDefaultAsync(m => m.Pattern == proposal.Pattern);

            if (existingRule == null)
            {
                _db.CategoryMappings.Add(new CategoryMapping
                {
                    Category = proposal.Category,
                    Pattern = proposal.Pattern,
                    MerchantId = transaction.MerchantId,
                    Priority = proposal.Priority,
                    Confidence = proposal.Confidence,
                    CreatedBy = "AI",
                    CreatedAt = DateTime.UtcNow
                });
                ruleCreated = true;
            }
            else if (existingRule.Category != proposal.Category)
            {
                existingRule.Category = proposal.Category;
                existingRule.MerchantId = transaction.MerchantId;
                existingRule.Priority = proposal.Priority;
                existingRule.Confidence = proposal.Confidence;
                existingRule.CreatedBy = "AI";
                existingRule.CreatedAt = DateTime.UtcNow;
                ruleUpdated = true;
            }
        }

        await _db.SaveChangesAsync();

        return new ApplyProposalResult
        {
            Success = true,
            RuleCreated = ruleCreated,
            RuleUpdated = ruleUpdated
        };
    }

    /// <summary>
    /// Returns the explicit model, else the "AI:ReceiptParsingModel" setting, else <see cref="DefaultModel"/>.
    /// </summary>
    private string ResolveModel(string? model) =>
        model ?? _config["AI:ReceiptParsingModel"] ?? DefaultModel;

    /// <summary>
    /// Loads the distinct category names (ordered) and all category-mapping rules with their merchants.
    /// </summary>
    private async Task<(List<string> Categories, List<CategoryMapping> Rules)> LoadCategorizationContextAsync()
    {
        var categories = await _db.CategoryMappings
            .Select(m => m.Category)
            .Distinct()
            .OrderBy(c => c)
            .ToListAsync();

        // Rules are matched in memory because pattern-in-name matching is hard to express in SQL.
        var rules = await _db.CategoryMappings
            .Include(m => m.Merchant)
            .ToListAsync();

        return (categories, rules);
    }

    /// <summary>
    /// Returns the rules whose pattern appears in the transaction name (case-insensitive). Results
    /// are ordered by priority, then by pattern length so that more specific patterns come first.
    /// </summary>
    private static List<CategoryMapping> FindMatchingRules(Transaction transaction, IEnumerable<CategoryMapping> rules)
    {
        var name = transaction.Name ?? string.Empty;

        return rules
            // Skip blank patterns: string.Contains("") is always true, so they would match everything.
            .Where(r => !string.IsNullOrWhiteSpace(r.Pattern)
                        && name.Contains(r.Pattern, StringComparison.OrdinalIgnoreCase))
            .OrderByDescending(r => r.Priority)
            .ThenByDescending(r => r.Pattern.Length)
            .ToList();
    }

    /// <summary>
    /// Maps a parsed model response to a proposal. Returns null when the model gave no category,
    /// because such a proposal would clear the transaction's category if applied.
    /// </summary>
    private AICategoryProposal? ToProposal(long transactionId, AICategorizationResponse response)
    {
        if (string.IsNullOrWhiteSpace(response.Category))
        {
            _logger.LogWarning("AI categorization for transaction {TransactionId} returned no category", transactionId);
            return null;
        }

        return new AICategoryProposal
        {
            TransactionId = transactionId,
            Category = response.Category.Trim(),
            CanonicalMerchant = response.CanonicalMerchant,
            Pattern = response.Pattern,
            Priority = response.Priority,
            Confidence = response.Confidence,
            Reasoning = response.Reasoning,
            // High confidence = auto-create rule
            CreateRule = response.Confidence >= RuleCreationConfidenceThreshold
        };
    }

    /// <summary>
    /// Builds the user prompt. It contains the transaction details, any existing rules that match
    /// the transaction name, the known categories, and the JSON schema the model must return.
    /// </summary>
    /// <param name="transaction">Transaction to describe. Optional details (merchant, card, account,
    /// notes, transfer target) are included only when present.</param>
    /// <param name="existingCategories">Categories already in the system; a built-in list is used when empty.</param>
    /// <param name="matchingRules">Rules matching the transaction, highest priority first.</param>
    private static string BuildPromptWithCategoriesAndRules(Transaction transaction, List<string> existingCategories, List<CategoryMapping> matchingRules)
    {
        var categoryList = existingCategories.Any()
            ? string.Join(", ", existingCategories)
            : "Restaurants, Fast Food, Coffee Shop, Groceries, Convenience Store, Gas & Auto, Online shopping, Health, Entertainment, Utilities, Banking, Insurance";

        var sb = new StringBuilder();
        sb.AppendLine("Analyze this financial transaction and suggest a category and merchant name.");
        sb.AppendLine();
        sb.AppendLine("Transaction Details:");
        sb.AppendLine($"- Name: \"{transaction.Name}\"");
        sb.AppendLine($"- Memo: \"{transaction.Memo}\"");
        sb.AppendLine($"- Amount: {transaction.Amount:C}");
        sb.AppendLine($"- Date: {transaction.Date:yyyy-MM-dd}");
        sb.AppendLine($"- Type: {(transaction.IsCredit ? "Credit/Income" : "Debit/Expense")}");
        if (!string.IsNullOrWhiteSpace(transaction.Category))
            sb.AppendLine($"- Current Category: \"{transaction.Category}\"");
        if (transaction.Merchant != null)
            sb.AppendLine($"- Current Merchant: \"{transaction.Merchant.Name}\"");
        if (transaction.Card != null)
            sb.AppendLine($"- Card: {transaction.Card.Owner} - ****{transaction.Card.Last4}");
        if (transaction.Account != null)
            sb.AppendLine($"- Account: {transaction.Account.DisplayLabel}");
        if (!string.IsNullOrWhiteSpace(transaction.Notes))
            sb.AppendLine($"- Notes: \"{transaction.Notes}\"");
        if (!string.IsNullOrWhiteSpace(transaction.Last4))
            sb.AppendLine($"- Last 4 digits: {transaction.Last4}");
        if (transaction.IsTransfer)
            sb.AppendLine($"- Transfer to: {transaction.TransferToAccount?.DisplayLabel ?? "Unknown"}");

        // Include matching rules so the AI respects existing mappings.
        if (matchingRules.Any())
        {
            sb.AppendLine();
            sb.AppendLine("EXISTING RULES that match this transaction (you MUST use these categories unless clearly wrong):");
            foreach (var rule in matchingRules)
            {
                var createdBy = rule.CreatedBy ?? "Unknown";
                var merchantName = rule.Merchant?.Name;
                sb.Append($"  - Pattern \"{rule.Pattern}\" → Category \"{rule.Category}\"");
                if (!string.IsNullOrWhiteSpace(merchantName))
                    sb.Append($", Merchant \"{merchantName}\"");
                sb.AppendLine($" (created by {createdBy})");
            }
        }

        sb.AppendLine();
        sb.AppendLine($"Existing categories in this system: {categoryList}");
        sb.AppendLine();
        sb.AppendLine("Provide your analysis in JSON format:");
        sb.AppendLine("{");
        sb.AppendLine("  \"category\": \"Category name\",");
        sb.AppendLine("  \"canonical_merchant\": \"Clean merchant name (e.g., 'Walmart' from 'WAL-MART #1234')\",");
        sb.AppendLine("  \"pattern\": \"EXACT substring from the transaction Name that identifies this merchant\",");
        sb.AppendLine("  \"priority\": 0,");
        sb.AppendLine("  \"confidence\": 0.85,");
        sb.AppendLine("  \"reasoning\": \"Brief explanation\"");
        sb.AppendLine("}");
        sb.AppendLine();
        sb.AppendLine("Guidelines:");
        sb.AppendLine("- If an existing rule matches this transaction, you MUST use that rule's category and merchant. Only deviate if the existing rule is clearly incorrect.");
        sb.AppendLine("- Prefer using existing categories when appropriate");
        sb.AppendLine("- CRITICAL: The pattern MUST be a substring that actually appears in the transaction Name field above. It is used for case-insensitive contains matching. Do NOT invent or clean up the pattern. Extract the shortest distinctive substring from the Name that would identify this merchant. For example, if the Name is 'DEBIT PURCHASE -VISA Kindle Unltd*0M6888', use 'Kindle Unltd' NOT 'Kindle Unlimited'. If the Name is 'WAL-MART #1234 SPRINGFIELD', use 'WAL-MART' NOT 'WALMART'.");
        sb.AppendLine("- confidence: Your certainty in this categorization (0.0-1.0). Use ~0.9+ for obvious matches like 'WALMART' -> Groceries. Use ~0.7-0.8 for likely matches. Use ~0.5-0.6 for uncertain/ambiguous transactions.");
        sb.AppendLine("- Return ONLY valid JSON, no additional text.");
        return sb.ToString();
    }

    /// <summary>
    /// Routes the prompt to llama.cpp (model ids prefixed "llamacpp:") or to OpenAI (everything else).
    /// </summary>
    /// <returns>The parsed response, or null when the call failed or no OpenAI key is configured.</returns>
    private async Task<AICategorizationResponse?> CallModelAsync(string prompt, string model)
    {
        if (model.StartsWith("llamacpp:", StringComparison.OrdinalIgnoreCase))
        {
            _logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model);
            return await CallLlamaCppAsync(prompt, model);
        }

        // Default to OpenAI.
        var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
        if (string.IsNullOrWhiteSpace(apiKey))
        {
            _logger.LogWarning("OpenAI API key not configured");
            return null;
        }

        _logger.LogInformation("Using OpenAI for transaction categorization with model {Model}", model);
        return await CallOpenAIAsync(apiKey, prompt, model);
    }

    /// <summary>
    /// Sends the prompt to the OpenAI chat-completions API and parses the first choice as JSON.
    /// All failures are logged and reported as null (best effort).
    /// </summary>
    private async Task<AICategorizationResponse?> CallOpenAIAsync(string apiKey, string prompt, string model = "gpt-4o-mini")
    {
        try
        {
            var requestBody = new
            {
                model = model,
                messages = new[]
                {
                    new { role = "system", content = SystemPrompt },
                    new { role = "user", content = prompt }
                },
                temperature = 0.1,
                max_tokens = 300
            };

            // Request and response are IDisposable; dispose both to release the content and connection.
            using var request = new HttpRequestMessage(HttpMethod.Post, "https://api.openai.com/v1/chat/completions");
            request.Headers.Add("Authorization", $"Bearer {apiKey}");
            request.Content = new StringContent(
                JsonSerializer.Serialize(requestBody),
                Encoding.UTF8,
                "application/json"
            );

            using var response = await _httpClient.SendAsync(request);
            if (!response.IsSuccessStatusCode)
            {
                _logger.LogWarning("OpenAI API returned status {StatusCode} for transaction categorization", (int)response.StatusCode);
                return null;
            }

            var json = await response.Content.ReadAsStringAsync();
            var apiResponse = JsonSerializer.Deserialize<OpenAIChatResponse>(json);
            if (apiResponse?.Choices == null || apiResponse.Choices.Length == 0)
                return null;

            return DeserializeCategorization(apiResponse.Choices[0].Message?.Content);
        }
        catch (HttpRequestException ex)
        {
            _logger.LogError(ex, "OpenAI API request failed: {Message}", ex.Message);
            return null;
        }
        catch (JsonException ex)
        {
            _logger.LogError(ex, "Failed to parse OpenAI response JSON: {Message}", ex.Message);
            return null;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Unexpected error calling OpenAI API: {Message}", ex.Message);
            return null;
        }
    }

    /// <summary>
    /// Sends the system prompt and user prompt (concatenated) to llama.cpp and parses the reply as JSON.
    /// All failures are logged and reported as null (best effort).
    /// </summary>
    private async Task<AICategorizationResponse?> CallLlamaCppAsync(string prompt, string? model = null)
    {
        try
        {
            var fullPrompt = $"{SystemPrompt}\n\n{prompt}";
            var result = await _llamaClient.SendTextPromptAsync(fullPrompt, ResolveModel(model));
            if (!result.IsSuccess)
            {
                _logger.LogWarning("LlamaCpp categorization failed: {Error}", result.ErrorMessage);
                return null;
            }

            // Clean the reply the same way as the OpenAI path, so that JSON wrapped in extra
            // text (e.g. markdown fences) is handled consistently.
            var parsed = DeserializeCategorization(result.Content);
            if (parsed == null)
                _logger.LogWarning("LlamaCpp categorization returned an empty response");
            return parsed;
        }
        catch (JsonException ex)
        {
            _logger.LogError(ex, "Failed to parse LlamaCpp response JSON: {Message}", ex.Message);
            return null;
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Unexpected error calling LlamaCpp: {Message}", ex.Message);
            return null;
        }
    }

    /// <summary>
    /// Cleans raw model output with <c>OpenAIToolUseHelper.CleanJsonResponse</c> and deserializes it.
    /// Returns null for empty content. Throws <see cref="JsonException"/> on malformed JSON.
    /// </summary>
    private static AICategorizationResponse? DeserializeCategorization(string? rawContent)
    {
        var content = OpenAIToolUseHelper.CleanJsonResponse(rawContent);
        if (string.IsNullOrWhiteSpace(content))
            return null;

        return JsonSerializer.Deserialize<AICategorizationResponse>(content, CaseInsensitiveJson);
    }

    // OpenAI API response models
    private class OpenAIChatResponse
    {
        [JsonPropertyName("choices")]
        public Choice[]? Choices { get; set; }
    }

    private class Choice
    {
        [JsonPropertyName("message")]
        public Message? Message { get; set; }
    }

    private class Message
    {
        [JsonPropertyName("content")]
        public string? Content { get; set; }
    }

    // Shape of the JSON object the prompt asks the model to return.
    private class AICategorizationResponse
    {
        [JsonPropertyName("category")]
        public string? Category { get; set; }

        [JsonPropertyName("canonical_merchant")]
        public string? CanonicalMerchant { get; set; }

        [JsonPropertyName("pattern")]
        public string? Pattern { get; set; }

        [JsonPropertyName("priority")]
        public int Priority { get; set; }

        [JsonPropertyName("confidence")]
        public decimal Confidence { get; set; }

        [JsonPropertyName("reasoning")]
        public string? Reasoning { get; set; }
    }
}
/// <summary>
/// A model-suggested categorization for a single transaction.
/// </summary>
public class AICategoryProposal
{
    /// <summary>Id of the transaction this proposal applies to.</summary>
    public long TransactionId { get; set; }

    /// <summary>Proposed category name. Never null; defaults to an empty string.</summary>
    public string Category { get; set; } = string.Empty;

    /// <summary>Proposed clean merchant name, if any.</summary>
    public string? CanonicalMerchant { get; set; }

    /// <summary>Proposed substring pattern for a categorization rule, if any.</summary>
    public string? Pattern { get; set; }

    /// <summary>Proposed rule priority.</summary>
    public int Priority { get; set; }

    /// <summary>Model-reported confidence in the categorization.</summary>
    public decimal Confidence { get; set; }

    /// <summary>Model's explanation of its choice, if any.</summary>
    public string? Reasoning { get; set; }

    /// <summary>Whether a categorization rule should be created from this proposal.</summary>
    public bool CreateRule { get; set; }
}
/// <summary>
/// The outcome of applying an <c>AICategoryProposal</c> to a transaction.
/// </summary>
public class ApplyProposalResult
{
    /// <summary>True when the proposal was applied.</summary>
    public bool Success { get; set; }

    /// <summary>True when a new categorization rule was added.</summary>
    public bool RuleCreated { get; set; }

    /// <summary>True when an existing categorization rule was changed.</summary>
    public bool RuleUpdated { get; set; }

    /// <summary>Reason for failure when <see cref="Success"/> is false; otherwise typically null.</summary>
    public string? ErrorMessage { get; set; }
}