using System.Globalization;
using System.Text;
using System.Text.Json;
using Microsoft.EntityFrameworkCore;
using MoneyMap.Data;
namespace MoneyMap.Services.AITools
{
/// <summary>
/// Executes database-backed tools on behalf of an AI model, or pre-fetches the same
/// information as plain text for providers that cannot call tools.
/// </summary>
public interface IAIToolExecutor
{
    /// <summary>
    /// Execute a single tool call and return the result as JSON.
    /// </summary>
    /// <param name="toolCall">The tool call requested by the model (name, id and arguments).</param>
    /// <returns>The tool result, tagged with the originating tool call id, whose content is a JSON string.</returns>
    Task<AIToolResult> ExecuteAsync(AIToolCall toolCall);

    /// <summary>
    /// Pre-fetch all relevant context as a text block for providers that don't support tool use (Ollama).
    /// </summary>
    /// <param name="receiptDate">Optional receipt date used to look for nearby candidate transactions.</param>
    /// <param name="total">Optional receipt total used to match candidate transaction amounts.</param>
    /// <param name="merchantHint">Optional partial merchant name used to look up matching merchants.</param>
    /// <returns>A human/LLM-readable text block describing the relevant database context.</returns>
    Task<string> GetEnrichedContextAsync(DateTime? receiptDate = null, decimal? total = null, string? merchantHint = null);
}
/// <summary>
/// Runs the database-backed tools the AI can call (<c>search_categories</c>,
/// <c>search_transactions</c>, <c>search_merchants</c>) against <see cref="MoneyMapContext"/>,
/// and builds a pre-fetched context block for providers that cannot call tools.
/// </summary>
public class AIToolExecutor : IAIToolExecutor
{
    private readonly MoneyMapContext _db;
    private readonly ILogger<AIToolExecutor> _logger;

    /// <summary>Upper bound on the number of items any single tool call returns.</summary>
    private const int MaxResults = 20;

    /// <summary>Number of transactions <c>search_transactions</c> returns when the model gives no limit.</summary>
    private const int DefaultTransactionLimit = 10;

    /// <summary>Number of merchants / candidate transactions listed in the enriched context.</summary>
    private const int EnrichedContextTake = 10;

    public AIToolExecutor(MoneyMapContext db, ILogger<AIToolExecutor> logger)
    {
        _db = db;
        _logger = logger;
    }

    /// <inheritdoc />
    /// <remarks>
    /// Never throws for tool failures: unknown tool names and exceptions raised while running a
    /// tool are returned as an <c>{"error": ...}</c> JSON payload with <c>IsError = true</c>.
    /// </remarks>
    public async Task<AIToolResult> ExecuteAsync(AIToolCall toolCall)
    {
        _logger.LogInformation("Executing AI tool: {ToolName} with args: {Args}",
            toolCall.Name, JsonSerializer.Serialize(toolCall.Arguments));

        try
        {
            string? result = toolCall.Name switch
            {
                "search_categories" => await SearchCategoriesAsync(toolCall),
                "search_transactions" => await SearchTransactionsAsync(toolCall),
                "search_merchants" => await SearchMerchantsAsync(toolCall),
                _ => null
            };

            if (result is null)
            {
                // The tool name is model-supplied, so it goes through the JSON serializer instead of
                // being spliced into a hand-written JSON literal (quotes/backslashes would break it).
                _logger.LogWarning("Unknown AI tool requested: {ToolName}", toolCall.Name);
                return ErrorResult(toolCall, $"Unknown tool: {toolCall.Name}");
            }

            _logger.LogInformation("Tool {ToolName} returned {Length} chars", toolCall.Name, result.Length);
            return new AIToolResult
            {
                ToolCallId = toolCall.Id,
                Content = result
            };
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error executing tool {ToolName}", toolCall.Name);
            return ErrorResult(toolCall, ex.Message);
        }
    }

    /// <summary>Builds an error result carrying <paramref name="message"/> as <c>{"error": message}</c>.</summary>
    private static AIToolResult ErrorResult(AIToolCall toolCall, string message) => new AIToolResult
    {
        ToolCallId = toolCall.Id,
        Content = JsonSerializer.Serialize(new { error = message }),
        IsError = true
    };

    /// <inheritdoc />
    /// <remarks>
    /// Always lists every category. The merchant section is added only for a non-blank
    /// <paramref name="merchantHint"/>, and the candidate-transaction section only when a date
    /// or total is given; empty sections are omitted.
    /// </remarks>
    public async Task<string> GetEnrichedContextAsync(DateTime? receiptDate = null, decimal? total = null, string? merchantHint = null)
    {
        var sb = new StringBuilder();
        sb.AppendLine("=== DATABASE CONTEXT (use this to match categories and transactions) ===");
        sb.AppendLine();

        await AppendCategoriesAsync(sb);

        if (!string.IsNullOrWhiteSpace(merchantHint))
        {
            await AppendMatchingMerchantsAsync(sb, merchantHint);
        }

        if (receiptDate.HasValue || total.HasValue)
        {
            await AppendCandidateTransactionsAsync(sb, receiptDate, total);
        }

        sb.AppendLine("=== END DATABASE CONTEXT ===");
        return sb.ToString();
    }

    /// <summary>Appends every category with up to 5 of its patterns and up to 3 linked merchant names.</summary>
    private async Task AppendCategoriesAsync(StringBuilder sb)
    {
        var categories = await _db.CategoryMappings
            .Include(cm => cm.Merchant)
            .OrderBy(cm => cm.Category)
            .ToListAsync();

        var grouped = categories.GroupBy(c => c.Category).ToList();
        sb.AppendLine($"EXISTING CATEGORIES ({grouped.Count} total):");
        foreach (var group in grouped)
        {
            var patterns = group.Select(c => c.Pattern).Take(5);
            var merchants = group
                .Where(c => c.Merchant != null)
                .Select(c => c.Merchant!.Name)
                .Distinct()
                .Take(3)
                .ToList();

            sb.Append($" - {group.Key}: patterns=[{string.Join(", ", patterns)}]");
            if (merchants.Count > 0)
            {
                sb.Append($", merchants=[{string.Join(", ", merchants)}]");
            }
            sb.AppendLine();
        }
        sb.AppendLine();
    }

    /// <summary>Appends merchants whose name contains <paramref name="merchantHint"/>, if any match.</summary>
    private async Task AppendMatchingMerchantsAsync(StringBuilder sb, string merchantHint)
    {
        var matchingMerchants = await QueryMerchantSummariesAsync(merchantHint, EnrichedContextTake);
        if (matchingMerchants.Count == 0)
        {
            return;
        }

        sb.AppendLine($"MATCHING MERCHANTS for \"{merchantHint}\":");
        foreach (var m in matchingMerchants)
        {
            sb.AppendLine($" - {m.Name} ({m.TransactionCount} transactions, typical category: {m.TopCategory ?? "none"})");
        }
        sb.AppendLine();
    }

    /// <summary>
    /// Appends receipt-less transactions whose date falls in [receiptDate - 1 day, receiptDate + 7 days]
    /// and/or whose absolute amount is within ±10% of <paramref name="total"/>.
    /// </summary>
    private async Task AppendCandidateTransactionsAsync(StringBuilder sb, DateTime? receiptDate, decimal? total)
    {
        // Only transactions that no receipt is linked to yet are candidates.
        var txQuery = _db.Transactions
            .Include(t => t.Merchant)
            .Where(t => !_db.Receipts.Any(r => r.TransactionId == t.Id));

        if (receiptDate.HasValue)
        {
            // Asymmetric window, presumably because the bank posts after the purchase date.
            var minDate = receiptDate.Value.AddDays(-1);
            var maxDate = receiptDate.Value.AddDays(7);
            txQuery = txQuery.Where(t => t.Date >= minDate && t.Date <= maxDate);
        }

        if (total.HasValue)
        {
            // Match on magnitude so both negative (debit) and positive amounts qualify.
            var absTotal = Math.Abs(total.Value);
            var minAmt = absTotal * 0.9m;
            var maxAmt = absTotal * 1.1m;
            txQuery = txQuery.Where(t =>
                (t.Amount >= -maxAmt && t.Amount <= -minAmt) ||
                (t.Amount >= minAmt && t.Amount <= maxAmt));
        }

        var transactions = await txQuery
            .OrderBy(t => t.Date)
            .ThenBy(t => t.Id) // deterministic order for same-day transactions
            .Take(EnrichedContextTake)
            .ToListAsync();
        if (transactions.Count == 0)
        {
            return;
        }

        sb.AppendLine("CANDIDATE TRANSACTIONS (unmapped, matching date/amount):");
        foreach (var t in transactions)
        {
            // Invariant culture keeps the date Gregorian regardless of the server's culture.
            var date = t.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture);
            sb.AppendLine($" - ID={t.Id}, Date={date}, Amount={t.Amount:C}, Name=\"{t.Name}\", " +
                $"Merchant={t.Merchant?.Name ?? "none"}, Category={t.Category}");
        }
        sb.AppendLine();
    }

    /// <summary>
    /// Handles <c>search_categories</c>: returns up to <see cref="MaxResults"/> categories (optionally
    /// filtered by a substring of the category name), each with up to 5 patterns and 5 merchant names,
    /// as <c>{"categories": [...]}</c>.
    /// </summary>
    private async Task<string> SearchCategoriesAsync(AIToolCall toolCall)
    {
        var query = toolCall.GetString("query");

        var mappings = _db.CategoryMappings
            .Include(cm => cm.Merchant)
            .AsQueryable();
        if (!string.IsNullOrWhiteSpace(query))
        {
            mappings = mappings.Where(cm => cm.Category.Contains(query));
        }

        var results = await mappings
            .OrderBy(cm => cm.Category)
            .ToListAsync();

        var grouped = results
            .GroupBy(c => c.Category)
            .Take(MaxResults)
            .Select(g => new
            {
                category = g.Key,
                patterns = g.Select(c => c.Pattern).Take(5).ToList(),
                merchants = g.Where(c => c.Merchant != null)
                    .Select(c => c.Merchant!.Name)
                    .Distinct()
                    .Take(5)
                    .ToList()
            })
            .ToList();

        return JsonSerializer.Serialize(new { categories = grouped });
    }

    /// <summary>
    /// Handles <c>search_transactions</c>: filters receipt-less transactions by merchant/name substring,
    /// date range (<c>minDate</c>/<c>maxDate</c>, invariant-culture parse) and absolute amount range,
    /// newest first, returning at most <c>limit</c> (clamped to 1..<see cref="MaxResults"/>) rows as
    /// <c>{"transactions": [...]}</c>.
    /// </summary>
    private async Task<string> SearchTransactionsAsync(AIToolCall toolCall)
    {
        var merchant = toolCall.GetString("merchant");
        var minDateStr = toolCall.GetString("minDate");
        var maxDateStr = toolCall.GetString("maxDate");
        var minAmount = toolCall.GetDecimal("minAmount");
        var maxAmount = toolCall.GetDecimal("maxAmount");

        // A zero/negative limit from the model must not reach the database: some providers
        // (e.g. SQLite) treat a negative LIMIT as "no limit".
        var limit = Math.Clamp(toolCall.GetInt("limit") ?? DefaultTransactionLimit, 1, MaxResults);

        var txQuery = _db.Transactions
            .Where(t => !_db.Receipts.Any(r => r.TransactionId == t.Id));

        if (!string.IsNullOrWhiteSpace(merchant))
        {
            txQuery = txQuery.Where(t =>
                t.Name.Contains(merchant) ||
                (t.Merchant != null && t.Merchant.Name.Contains(merchant)));
        }

        // Dates come from the model, not the user's locale, so parse them culture-independently.
        if (DateTime.TryParse(minDateStr, CultureInfo.InvariantCulture, DateTimeStyles.None, out var minDate))
        {
            txQuery = txQuery.Where(t => t.Date >= minDate);
        }
        if (DateTime.TryParse(maxDateStr, CultureInfo.InvariantCulture, DateTimeStyles.None, out var maxDate))
        {
            txQuery = txQuery.Where(t => t.Date <= maxDate);
        }

        // Amount bounds are magnitudes; Math.Abs keeps a negative bound from inverting the filter.
        if (minAmount.HasValue)
        {
            var min = Math.Abs(minAmount.Value);
            txQuery = txQuery.Where(t => t.Amount <= -min || t.Amount >= min);
        }
        if (maxAmount.HasValue)
        {
            var max = Math.Abs(maxAmount.Value);
            txQuery = txQuery.Where(t => t.Amount >= -max && t.Amount <= max);
        }

        var rows = await txQuery
            .OrderByDescending(t => t.Date)
            .ThenByDescending(t => t.Id) // deterministic order for same-day transactions
            .Take(limit)
            .Select(t => new
            {
                t.Id,
                t.Date,
                t.Amount,
                t.Name,
                MerchantName = t.Merchant != null ? t.Merchant.Name : null,
                t.Category
            })
            .ToListAsync();

        // Format in memory so the date is always Gregorian yyyy-MM-dd (invariant culture).
        var transactions = rows
            .Select(t => new
            {
                id = t.Id,
                date = t.Date.ToString("yyyy-MM-dd", CultureInfo.InvariantCulture),
                amount = t.Amount,
                name = t.Name,
                merchant = t.MerchantName,
                category = t.Category
            })
            .ToList();

        return JsonSerializer.Serialize(new { transactions });
    }

    /// <summary>
    /// Handles <c>search_merchants</c>: returns up to <see cref="MaxResults"/> merchants whose name
    /// contains <c>query</c> (all merchants when absent) as <c>{"merchants": [...]}</c>.
    /// </summary>
    private async Task<string> SearchMerchantsAsync(AIToolCall toolCall)
    {
        var query = toolCall.GetString("query") ?? "";

        var summaries = await QueryMerchantSummariesAsync(query, MaxResults);
        var merchants = summaries
            .Select(m => new
            {
                name = m.Name,
                transactionCount = m.TransactionCount,
                topCategory = m.TopCategory
            })
            .ToList();

        return JsonSerializer.Serialize(new { merchants });
    }

    /// <summary>
    /// Finds up to <paramref name="take"/> merchants whose name contains <paramref name="query"/>,
    /// ordered by name, each with its transaction count and most frequent non-empty category.
    /// </summary>
    /// <remarks>
    /// Ordering before <c>Take</c> makes the returned subset stable. Whether <c>Contains</c> is
    /// case-sensitive is decided by the database provider/collation — confirm for the target DB.
    /// </remarks>
    private Task<List<MerchantSummary>> QueryMerchantSummariesAsync(string query, int take) =>
        _db.Merchants
            .Where(m => m.Name.Contains(query))
            .OrderBy(m => m.Name)
            .Select(m => new MerchantSummary(
                m.Name,
                m.Transactions.Count,
                m.Transactions
                    .Where(t => t.Category != "")
                    .GroupBy(t => t.Category)
                    .OrderByDescending(g => g.Count())
                    .Select(g => g.Key)
                    .FirstOrDefault()))
            .Take(take)
            .ToListAsync();

    /// <summary>Merchant summary shared by <c>search_merchants</c> and the enriched context.</summary>
    private sealed record MerchantSummary(string Name, int TransactionCount, string? TopCategory);
}
}