Feature: Add AI tool-use framework for database-aware receipt parsing

Introduce provider-agnostic tool definitions (AIToolRegistry) and an
executor (AIToolExecutor) that lets AI models query MoneyMap's database
during receipt parsing via search_categories, search_transactions, and
search_merchants tools. Includes an enriched-context fallback for
providers that don't support function calling (Ollama).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-15 19:13:47 -05:00
parent 865195ad16
commit 5c0f0f3fca
2 changed files with 439 additions and 0 deletions

View File

@@ -0,0 +1,159 @@
using System.Text.Json.Serialization;
namespace MoneyMap.Services.AITools
{
/// <summary>
/// Provider-agnostic tool definition for AI function calling.
/// </summary>
public class AIToolDefinition
{
public string Name { get; set; } = "";
public string Description { get; set; } = "";
public List<AIToolParameter> Parameters { get; set; } = new();
}
public class AIToolParameter
{
public string Name { get; set; } = "";
public string Type { get; set; } = "string"; // string, number, integer
public string Description { get; set; } = "";
public bool Required { get; set; }
}
/// <summary>
/// Represents a tool call from the AI model.
/// </summary>
public class AIToolCall
{
public string Id { get; set; } = "";
public string Name { get; set; } = "";
public Dictionary<string, object?> Arguments { get; set; } = new();
public string? GetString(string key)
{
if (Arguments.TryGetValue(key, out var val) && val != null)
return val.ToString();
return null;
}
public decimal? GetDecimal(string key)
{
if (Arguments.TryGetValue(key, out var val) && val != null)
{
if (decimal.TryParse(val.ToString(), out var d))
return d;
}
return null;
}
public int? GetInt(string key)
{
if (Arguments.TryGetValue(key, out var val) && val != null)
{
if (int.TryParse(val.ToString(), out var i))
return i;
}
return null;
}
}
/// <summary>
/// Result of executing a tool, returned to the AI.
/// </summary>
public class AIToolResult
{
public string ToolCallId { get; set; } = "";
public string Content { get; set; } = "";
public bool IsError { get; set; }
}
/// <summary>
/// Static registry of all tools available to the receipt parsing AI.
/// </summary>
public static class AIToolRegistry
{
public static List<AIToolDefinition> GetAllTools() => new()
{
new AIToolDefinition
{
Name = "search_categories",
Description = "Search existing expense categories in the system. Returns category names with their matching patterns and associated merchants. Use this to find the correct category name for line items and the overall receipt instead of inventing new ones.",
Parameters = new()
{
new AIToolParameter
{
Name = "query",
Type = "string",
Description = "Optional filter text to search category names (e.g., 'grocery', 'utility'). Omit to get all categories.",
Required = false
}
}
},
new AIToolDefinition
{
Name = "search_transactions",
Description = "Search bank transactions to find one that matches this receipt. Returns transaction ID, date, amount, name, merchant, and category. Use this to suggest which transaction this receipt belongs to.",
Parameters = new()
{
new AIToolParameter
{
Name = "merchant",
Type = "string",
Description = "Merchant or store name to search for (partial match)",
Required = false
},
new AIToolParameter
{
Name = "minDate",
Type = "string",
Description = "Earliest transaction date (YYYY-MM-DD format)",
Required = false
},
new AIToolParameter
{
Name = "maxDate",
Type = "string",
Description = "Latest transaction date (YYYY-MM-DD format)",
Required = false
},
new AIToolParameter
{
Name = "minAmount",
Type = "number",
Description = "Minimum absolute transaction amount",
Required = false
},
new AIToolParameter
{
Name = "maxAmount",
Type = "number",
Description = "Maximum absolute transaction amount",
Required = false
},
new AIToolParameter
{
Name = "limit",
Type = "integer",
Description = "Maximum results to return (default 10, max 20)",
Required = false
}
}
},
new AIToolDefinition
{
Name = "search_merchants",
Description = "Search known merchants by name. Returns merchant name, transaction count, and most common category. Use this to find the correct merchant name and see what category is typically used for them.",
Parameters = new()
{
new AIToolParameter
{
Name = "query",
Type = "string",
Description = "Merchant name to search for (partial match)",
Required = true
}
}
}
};
}
}

View File

@@ -0,0 +1,280 @@
using Microsoft.EntityFrameworkCore;
using MoneyMap.Data;
using System.Text;
using System.Text.Json;
namespace MoneyMap.Services.AITools
{
public interface IAIToolExecutor
{
/// <summary>
/// Execute a single tool call and return the result as JSON.
/// </summary>
Task<AIToolResult> ExecuteAsync(AIToolCall toolCall);
/// <summary>
/// Pre-fetch all relevant context as a text block for providers that don't support tool use (Ollama).
/// </summary>
Task<string> GetEnrichedContextAsync(DateTime? receiptDate = null, decimal? total = null, string? merchantHint = null);
}
public class AIToolExecutor : IAIToolExecutor
{
private readonly MoneyMapContext _db;
private readonly ILogger<AIToolExecutor> _logger;
private const int MaxResults = 20;
public AIToolExecutor(MoneyMapContext db, ILogger<AIToolExecutor> logger)
{
_db = db;
_logger = logger;
}
public async Task<AIToolResult> ExecuteAsync(AIToolCall toolCall)
{
_logger.LogInformation("Executing AI tool: {ToolName} with args: {Args}",
toolCall.Name, JsonSerializer.Serialize(toolCall.Arguments));
try
{
var result = toolCall.Name switch
{
"search_categories" => await SearchCategoriesAsync(toolCall),
"search_transactions" => await SearchTransactionsAsync(toolCall),
"search_merchants" => await SearchMerchantsAsync(toolCall),
_ => $"{{\"error\": \"Unknown tool: {toolCall.Name}\"}}"
};
_logger.LogInformation("Tool {ToolName} returned {Length} chars", toolCall.Name, result.Length);
return new AIToolResult
{
ToolCallId = toolCall.Id,
Content = result
};
}
catch (Exception ex)
{
_logger.LogError(ex, "Error executing tool {ToolName}", toolCall.Name);
return new AIToolResult
{
ToolCallId = toolCall.Id,
Content = JsonSerializer.Serialize(new { error = ex.Message }),
IsError = true
};
}
}
public async Task<string> GetEnrichedContextAsync(DateTime? receiptDate, decimal? total, string? merchantHint)
{
var sb = new StringBuilder();
sb.AppendLine("=== DATABASE CONTEXT (use this to match categories and transactions) ===");
sb.AppendLine();
// Categories
var categories = await _db.CategoryMappings
.Include(cm => cm.Merchant)
.OrderBy(cm => cm.Category)
.ToListAsync();
var grouped = categories.GroupBy(c => c.Category).ToList();
sb.AppendLine($"EXISTING CATEGORIES ({grouped.Count} total):");
foreach (var group in grouped)
{
var patterns = group.Select(c => c.Pattern).Take(5);
var merchants = group.Where(c => c.Merchant != null).Select(c => c.Merchant!.Name).Distinct().Take(3);
sb.Append($" - {group.Key}: patterns=[{string.Join(", ", patterns)}]");
if (merchants.Any())
sb.Append($", merchants=[{string.Join(", ", merchants)}]");
sb.AppendLine();
}
sb.AppendLine();
// Merchants matching hint
if (!string.IsNullOrWhiteSpace(merchantHint))
{
var matchingMerchants = await _db.Merchants
.Where(m => m.Name.Contains(merchantHint))
.Select(m => new
{
m.Name,
TransactionCount = m.Transactions.Count,
TopCategory = m.Transactions
.Where(t => t.Category != "")
.GroupBy(t => t.Category)
.OrderByDescending(g => g.Count())
.Select(g => g.Key)
.FirstOrDefault()
})
.Take(10)
.ToListAsync();
if (matchingMerchants.Count > 0)
{
sb.AppendLine($"MATCHING MERCHANTS for \"{merchantHint}\":");
foreach (var m in matchingMerchants)
sb.AppendLine($" - {m.Name} ({m.TransactionCount} transactions, typical category: {m.TopCategory ?? "none"})");
sb.AppendLine();
}
}
// Matching transactions
if (receiptDate.HasValue || total.HasValue)
{
var txQuery = _db.Transactions
.Include(t => t.Merchant)
.Where(t => !_db.Receipts.Any(r => r.TransactionId == t.Id))
.AsQueryable();
if (receiptDate.HasValue)
{
var minDate = receiptDate.Value.AddDays(-1);
var maxDate = receiptDate.Value.AddDays(7);
txQuery = txQuery.Where(t => t.Date >= minDate && t.Date <= maxDate);
}
if (total.HasValue)
{
var absTotal = Math.Abs(total.Value);
var minAmt = absTotal * 0.9m;
var maxAmt = absTotal * 1.1m;
txQuery = txQuery.Where(t =>
(t.Amount >= -maxAmt && t.Amount <= -minAmt) ||
(t.Amount >= minAmt && t.Amount <= maxAmt));
}
var transactions = await txQuery
.OrderBy(t => t.Date)
.Take(10)
.ToListAsync();
if (transactions.Count > 0)
{
sb.AppendLine("CANDIDATE TRANSACTIONS (unmapped, matching date/amount):");
foreach (var t in transactions)
{
sb.AppendLine($" - ID={t.Id}, Date={t.Date:yyyy-MM-dd}, Amount={t.Amount:C}, Name=\"{t.Name}\", " +
$"Merchant={t.Merchant?.Name ?? "none"}, Category={t.Category}");
}
sb.AppendLine();
}
}
sb.AppendLine("=== END DATABASE CONTEXT ===");
return sb.ToString();
}
private async Task<string> SearchCategoriesAsync(AIToolCall toolCall)
{
var query = toolCall.GetString("query");
var mappings = _db.CategoryMappings
.Include(cm => cm.Merchant)
.AsQueryable();
if (!string.IsNullOrWhiteSpace(query))
mappings = mappings.Where(cm => cm.Category.Contains(query));
var results = await mappings
.OrderBy(cm => cm.Category)
.ToListAsync();
var grouped = results
.GroupBy(c => c.Category)
.Take(MaxResults)
.Select(g => new
{
category = g.Key,
patterns = g.Select(c => c.Pattern).Take(5).ToList(),
merchants = g.Where(c => c.Merchant != null)
.Select(c => c.Merchant!.Name)
.Distinct()
.Take(5)
.ToList()
})
.ToList();
return JsonSerializer.Serialize(new { categories = grouped });
}
private async Task<string> SearchTransactionsAsync(AIToolCall toolCall)
{
var merchant = toolCall.GetString("merchant");
var minDateStr = toolCall.GetString("minDate");
var maxDateStr = toolCall.GetString("maxDate");
var minAmount = toolCall.GetDecimal("minAmount");
var maxAmount = toolCall.GetDecimal("maxAmount");
var limit = toolCall.GetInt("limit") ?? 10;
limit = Math.Min(limit, MaxResults);
var txQuery = _db.Transactions
.Include(t => t.Merchant)
.Where(t => !_db.Receipts.Any(r => r.TransactionId == t.Id))
.AsQueryable();
if (!string.IsNullOrWhiteSpace(merchant))
{
txQuery = txQuery.Where(t =>
t.Name.Contains(merchant) ||
(t.Merchant != null && t.Merchant.Name.Contains(merchant)));
}
if (DateTime.TryParse(minDateStr, out var minDate))
txQuery = txQuery.Where(t => t.Date >= minDate);
if (DateTime.TryParse(maxDateStr, out var maxDate))
txQuery = txQuery.Where(t => t.Date <= maxDate);
if (minAmount.HasValue)
{
var min = minAmount.Value;
txQuery = txQuery.Where(t => t.Amount <= -min || t.Amount >= min);
}
if (maxAmount.HasValue)
{
var max = maxAmount.Value;
txQuery = txQuery.Where(t => t.Amount >= -max && t.Amount <= max);
}
var transactions = await txQuery
.OrderByDescending(t => t.Date)
.Take(limit)
.Select(t => new
{
id = t.Id,
date = t.Date.ToString("yyyy-MM-dd"),
amount = t.Amount,
name = t.Name,
merchant = t.Merchant != null ? t.Merchant.Name : null,
category = t.Category
})
.ToListAsync();
return JsonSerializer.Serialize(new { transactions });
}
private async Task<string> SearchMerchantsAsync(AIToolCall toolCall)
{
var query = toolCall.GetString("query") ?? "";
var merchants = await _db.Merchants
.Where(m => m.Name.Contains(query))
.Select(m => new
{
name = m.Name,
transactionCount = m.Transactions.Count,
topCategory = m.Transactions
.Where(t => t.Category != "")
.GroupBy(t => t.Category)
.OrderByDescending(g => g.Count())
.Select(g => g.Key)
.FirstOrDefault()
})
.Take(MaxResults)
.ToListAsync();
return JsonSerializer.Serialize(new { merchants });
}
}
}