Feature: Add LlamaCpp provider support to TransactionAICategorizer

Enhance AI categorization with multi-provider support:
- Add configurable AI:CategorizationProvider setting (OpenAI/LlamaCpp)
- Add CallLlamaCppAsync() for local LLM categorization
- Improve prompt with existing categories for consistency
- Include additional transaction context (card, account, transfer info)
- Fix batch processing to avoid DbContext concurrency issues
- Add model parameter to interface methods

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-15 22:52:54 -05:00
parent cba20f20fc
commit c43fe12124

View File

@@ -9,8 +9,8 @@ namespace MoneyMap.Services;
public interface ITransactionAICategorizer public interface ITransactionAICategorizer
{ {
Task<AICategoryProposal?> ProposeCategorizationAsync(Transaction transaction); Task<AICategoryProposal?> ProposeCategorizationAsync(Transaction transaction, string? model = null);
Task<List<AICategoryProposal>> ProposeBatchCategorizationAsync(List<Transaction> transactions); Task<List<AICategoryProposal>> ProposeBatchCategorizationAsync(List<Transaction> transactions, string? model = null);
Task<ApplyProposalResult> ApplyProposalAsync(long transactionId, AICategoryProposal proposal, bool createRule = true); Task<ApplyProposalResult> ApplyProposalAsync(long transactionId, AICategoryProposal proposal, bool createRule = true);
} }
@@ -19,30 +19,45 @@ public class TransactionAICategorizer : ITransactionAICategorizer
private readonly HttpClient _httpClient; private readonly HttpClient _httpClient;
private readonly MoneyMapContext _db; private readonly MoneyMapContext _db;
private readonly IConfiguration _config; private readonly IConfiguration _config;
private readonly LlamaCppVisionClient _llamaClient;
private readonly ILogger<TransactionAICategorizer> _logger; private readonly ILogger<TransactionAICategorizer> _logger;
public TransactionAICategorizer( public TransactionAICategorizer(
HttpClient httpClient, HttpClient httpClient,
MoneyMapContext db, MoneyMapContext db,
IConfiguration config, IConfiguration config,
LlamaCppVisionClient llamaClient,
ILogger<TransactionAICategorizer> logger) ILogger<TransactionAICategorizer> logger)
{ {
_httpClient = httpClient; _httpClient = httpClient;
_db = db; _db = db;
_config = config; _config = config;
_llamaClient = llamaClient;
_logger = logger; _logger = logger;
} }
public async Task<AICategoryProposal?> ProposeCategorizationAsync(Transaction transaction) public async Task<AICategoryProposal?> ProposeCategorizationAsync(Transaction transaction, string? model = null)
{ {
var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY"); var provider = _config["AI:CategorizationProvider"] ?? "OpenAI";
if (string.IsNullOrWhiteSpace(apiKey)) var prompt = await BuildPromptAsync(transaction);
{
return null;
}
var prompt = BuildPrompt(transaction); AICategorizationResponse? response;
var response = await CallOpenAIAsync(apiKey, prompt);
if (provider.Equals("LlamaCpp", StringComparison.OrdinalIgnoreCase))
{
_logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model ?? "default");
response = await CallLlamaCppAsync(prompt, model);
}
else
{
var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
if (string.IsNullOrWhiteSpace(apiKey))
{
_logger.LogWarning("OpenAI API key not configured");
return null;
}
response = await CallOpenAIAsync(apiKey, prompt);
}
if (response == null) if (response == null)
return null; return null;
@@ -60,23 +75,70 @@ public class TransactionAICategorizer : ITransactionAICategorizer
}; };
} }
public async Task<List<AICategoryProposal>> ProposeBatchCategorizationAsync(List<Transaction> transactions) public async Task<List<AICategoryProposal>> ProposeBatchCategorizationAsync(List<Transaction> transactions, string? model = null)
{ {
var proposals = new List<AICategoryProposal>(); var proposals = new List<AICategoryProposal>();
// Process in batches of 5 to avoid rate limits // Pre-fetch existing categories once to avoid concurrent DbContext access
var batches = transactions.Chunk(5); var existingCategories = await _db.CategoryMappings
.Select(m => m.Category)
.Distinct()
.OrderBy(c => c)
.ToListAsync();
foreach (var batch in batches) // Process transactions sequentially to avoid DbContext concurrency issues
foreach (var transaction in transactions)
{ {
var tasks = batch.Select(t => ProposeCategorizationAsync(t)); var result = await ProposeCategorizationWithCategoriesAsync(transaction, existingCategories, model);
var results = await Task.WhenAll(tasks); if (result != null)
proposals.AddRange(results.Where(r => r != null)!); proposals.Add(result);
} }
return proposals; return proposals;
} }
private async Task<AICategoryProposal?> ProposeCategorizationWithCategoriesAsync(
Transaction transaction,
List<string> existingCategories,
string? model = null)
{
var provider = _config["AI:CategorizationProvider"] ?? "OpenAI";
var prompt = BuildPromptWithCategories(transaction, existingCategories);
AICategorizationResponse? response;
if (provider.Equals("LlamaCpp", StringComparison.OrdinalIgnoreCase))
{
_logger.LogInformation("Using LlamaCpp for transaction categorization with model {Model}", model ?? "default");
response = await CallLlamaCppAsync(prompt, model);
}
else
{
var apiKey = _config["OpenAI:ApiKey"] ?? Environment.GetEnvironmentVariable("OPENAI_API_KEY");
if (string.IsNullOrWhiteSpace(apiKey))
{
_logger.LogWarning("OpenAI API key not configured");
return null;
}
response = await CallOpenAIAsync(apiKey, prompt);
}
if (response == null)
return null;
return new AICategoryProposal
{
TransactionId = transaction.Id,
Category = response.Category ?? "",
CanonicalMerchant = response.CanonicalMerchant,
Pattern = response.Pattern,
Priority = response.Priority,
Confidence = response.Confidence,
Reasoning = response.Reasoning,
CreateRule = response.Confidence >= 0.7m
};
}
public async Task<ApplyProposalResult> ApplyProposalAsync(long transactionId, AICategoryProposal proposal, bool createRule = true) public async Task<ApplyProposalResult> ApplyProposalAsync(long transactionId, AICategoryProposal proposal, bool createRule = true)
{ {
var transaction = await _db.Transactions.FindAsync(transactionId); var transaction = await _db.Transactions.FindAsync(transactionId);
@@ -132,40 +194,74 @@ public class TransactionAICategorizer : ITransactionAICategorizer
}; };
} }
private string BuildPrompt(Transaction transaction) private async Task<string> BuildPromptAsync(Transaction transaction)
{ {
return $@"Analyze this financial transaction and suggest a category and merchant name. // Get existing categories from database for better suggestions
var existingCategories = await _db.CategoryMappings
.Select(m => m.Category)
.Distinct()
.OrderBy(c => c)
.ToListAsync();
Transaction Details: return BuildPromptWithCategories(transaction, existingCategories);
- Name: ""{transaction.Name}""
- Memo: ""{transaction.Memo}""
- Amount: {transaction.Amount:C}
- Date: {transaction.Date:yyyy-MM-dd}
Provide your analysis in JSON format:
{{
""category"": ""Category name (e.g., Restaurants, Groceries, Gas & Auto)"",
""canonical_merchant"": ""Clean merchant name (e.g., 'Walmart' from 'WAL-MART #1234')"",
""pattern"": ""Pattern to match (e.g., 'WALMART' or 'SUBWAY')"",
""priority"": 0,
""confidence"": 0.95,
""reasoning"": ""Brief explanation""
}}
Common categories:
- Restaurants, Fast Food, Coffee Shop
- Groceries, Convenience Store
- Gas & Auto, Automotive
- Online shopping, Brick/mortar store
- Health, Pharmacy
- Entertainment, Streaming
- Utilities, Banking, Insurance
- Home Improvement, School
Return ONLY valid JSON, no additional text.";
} }
private async Task<OpenAIResponse?> CallOpenAIAsync(string apiKey, string prompt) private string BuildPromptWithCategories(Transaction transaction, List<string> existingCategories)
{
var categoryList = existingCategories.Any()
? string.Join(", ", existingCategories)
: "Restaurants, Fast Food, Coffee Shop, Groceries, Convenience Store, Gas & Auto, Online shopping, Health, Entertainment, Utilities, Banking, Insurance";
var sb = new StringBuilder();
sb.AppendLine("Analyze this financial transaction and suggest a category and merchant name.");
sb.AppendLine();
sb.AppendLine("Transaction Details:");
sb.AppendLine($"- Name: \"{transaction.Name}\"");
sb.AppendLine($"- Memo: \"{transaction.Memo}\"");
sb.AppendLine($"- Amount: {transaction.Amount:C}");
sb.AppendLine($"- Date: {transaction.Date:yyyy-MM-dd}");
sb.AppendLine($"- Type: {(transaction.IsCredit ? "Credit/Income" : "Debit/Expense")}");
if (!string.IsNullOrWhiteSpace(transaction.Category))
sb.AppendLine($"- Current Category: \"{transaction.Category}\"");
if (transaction.Merchant != null)
sb.AppendLine($"- Current Merchant: \"{transaction.Merchant.Name}\"");
if (transaction.Card != null)
sb.AppendLine($"- Card: {transaction.Card.Owner} - ****{transaction.Card.Last4}");
if (transaction.Account != null)
sb.AppendLine($"- Account: {transaction.Account.DisplayLabel}");
if (!string.IsNullOrWhiteSpace(transaction.Notes))
sb.AppendLine($"- Notes: \"{transaction.Notes}\"");
if (!string.IsNullOrWhiteSpace(transaction.Last4))
sb.AppendLine($"- Last 4 digits: {transaction.Last4}");
if (transaction.IsTransfer)
sb.AppendLine($"- Transfer to: {transaction.TransferToAccount?.DisplayLabel ?? "Unknown"}");
sb.AppendLine();
sb.AppendLine("Provide your analysis in JSON format:");
sb.AppendLine("{");
sb.AppendLine(" \"category\": \"Category name\",");
sb.AppendLine(" \"canonical_merchant\": \"Clean merchant name (e.g., 'Walmart' from 'WAL-MART #1234')\",");
sb.AppendLine(" \"pattern\": \"Pattern to match future transactions (e.g., 'WALMART' or 'SUBWAY')\",");
sb.AppendLine(" \"priority\": 0,");
sb.AppendLine(" \"confidence\": 0.95,");
sb.AppendLine(" \"reasoning\": \"Brief explanation\"");
sb.AppendLine("}");
sb.AppendLine();
sb.AppendLine($"Existing categories in this system: {categoryList}");
sb.AppendLine();
sb.AppendLine("Prefer using existing categories when appropriate. Return ONLY valid JSON, no additional text.");
return sb.ToString();
}
private async Task<AICategorizationResponse?> CallOpenAIAsync(string apiKey, string prompt)
{ {
try try
{ {
@@ -203,13 +299,10 @@ Return ONLY valid JSON, no additional text.";
if (string.IsNullOrWhiteSpace(content)) if (string.IsNullOrWhiteSpace(content))
return null; return null;
// Parse the JSON response from the AI return JsonSerializer.Deserialize<AICategorizationResponse>(content, new JsonSerializerOptions
var result = JsonSerializer.Deserialize<OpenAIResponse>(content, new JsonSerializerOptions
{ {
PropertyNameCaseInsensitive = true PropertyNameCaseInsensitive = true
}); });
return result;
} }
catch (HttpRequestException ex) catch (HttpRequestException ex)
{ {
@@ -228,6 +321,39 @@ Return ONLY valid JSON, no additional text.";
} }
} }
private async Task<AICategorizationResponse?> CallLlamaCppAsync(string prompt, string? model = null)
{
try
{
var selectedModel = model ?? _config["AI:CategorizationModel"] ?? "qwen2.5-coder-32b-instruct-q6_k";
var systemPrompt = "You are a financial transaction categorization expert. Always respond with valid JSON only.";
var fullPrompt = $"{systemPrompt}\n\n{prompt}";
var result = await _llamaClient.SendTextPromptAsync(fullPrompt, selectedModel);
if (!result.IsSuccess)
{
_logger.LogWarning("LlamaCpp categorization failed: {Error}", result.ErrorMessage);
return null;
}
return JsonSerializer.Deserialize<AICategorizationResponse>(result.Content ?? "", new JsonSerializerOptions
{
PropertyNameCaseInsensitive = true
});
}
catch (JsonException ex)
{
_logger.LogError(ex, "Failed to parse LlamaCpp response JSON: {Message}", ex.Message);
return null;
}
catch (Exception ex)
{
_logger.LogError(ex, "Unexpected error calling LlamaCpp: {Message}", ex.Message);
return null;
}
}
// OpenAI API response models // OpenAI API response models
private class OpenAIChatResponse private class OpenAIChatResponse
{ {
@@ -247,7 +373,7 @@ Return ONLY valid JSON, no additional text.";
public string? Content { get; set; } public string? Content { get; set; }
} }
private class OpenAIResponse private class AICategorizationResponse
{ {
[JsonPropertyName("category")] [JsonPropertyName("category")]
public string? Category { get; set; } public string? Category { get; set; }