Skip to content

Commit e57293a

Browse files
authored
KnowPro.NET Updates (#1758)
* Knowledge Extraction * Knowledge Json Translator * Knowledge extraction schema * ActionParam parsing * Knowledge merging * Batch extraction * Refactoring * Bug fixes
1 parent c892ab4 commit e57293a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+647
-69
lines changed

dotnet/typeagent/examples/knowProConsole/Includes.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
global using TypeAgent.AIClient;
1919
global using TypeAgent.KnowPro;
20+
global using TypeAgent.KnowPro.KnowledgeExtractor;
2021
global using TypeAgent.KnowPro.Storage.Local;
2122
global using TypeAgent.KnowPro.Storage.Sqlite;
2223
global using TypeAgent.ConversationMemory;

dotnet/typeagent/examples/knowProConsole/PodcastCommands.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,14 @@ private async Task PodcastImportIndexAsync(ParseResult args, CancellationToken c
129129
}
130130
}
131131

132-
133132
private Podcast CreatePodcast(string name, bool createNew)
134133
{
135-
// TODO: standardize this boilerplate, esp the cache binding
136-
var model = new TextEmbeddingModelWithCache(256);
137-
ConversationSettings settings = new ConversationSettings(model);
138-
var provider = _kpContext.CreateStorageProvider<PodcastMessage, PodcastMessageMeta>(settings, name, createNew);
139-
model.Cache.PersistentCache = provider.GetEmbeddingCache();
134+
MemorySettings settings = new MemorySettings();
135+
var provider = _kpContext.CreateStorageProvider<PodcastMessage, PodcastMessageMeta>(
136+
settings.ConversationSettings,
137+
name,
138+
createNew
139+
);
140140

141141
var podcast = new Podcast(settings, provider);
142142
return podcast;

dotnet/typeagent/examples/knowProConsole/TestCommands.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public IList<Command> GetCommands()
2222
SearchMessagesTermsDef(),
2323
TestEmbeddingsDef(),
2424
SearchQueryTermsDef(),
25+
KnowledgeDef(),
2526
BuildIndexDef(),
2627
];
2728
}
@@ -308,6 +309,36 @@ private async Task SearchQueryTermsAsync(ParseResult args, CancellationToken can
308309
KnowProWriter.WriteJson(result);
309310
}
310311

312+
private Command KnowledgeDef()
313+
{
314+
Command cmd = new("kpTestKnowledge")
315+
{
316+
Args.Arg<string>("text")
317+
};
318+
cmd.TreatUnmatchedTokensAsErrors = false;
319+
cmd.SetAction(this.KnowledgeAsync);
320+
return cmd;
321+
}
322+
323+
private async Task KnowledgeAsync(ParseResult args, CancellationToken cancellationToken)
324+
{
325+
// IConversation conversation = EnsureConversation();
326+
327+
NamedArgs namedArgs = new NamedArgs(args);
328+
var text = namedArgs.Get("text");
329+
if (string.IsNullOrEmpty(text))
330+
{
331+
return;
332+
}
333+
var model = new OpenAIChatModel();
334+
KnowledgeExtractor exctractor = new KnowledgeExtractor(model);
335+
var result = await exctractor.ExtractAsync(text, cancellationToken);
336+
if (result is not null)
337+
{
338+
KnowProWriter.WriteJson(result);
339+
}
340+
}
341+
311342
private IConversation EnsureConversation()
312343
{
313344
return (_kpContext.Conversation is not null)

dotnet/typeagent/src/aiclient/ITextEmbeddingModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public static async Task<List<float[]>> GenerateInBatchesAsync(
5151

5252
var embeddingChunks = await chunks.MapAsync(
5353
concurrency,
54-
(chunk) => model.GenerateAsync(chunk, cancellationToken),
54+
model.GenerateAsync,
5555
notifyProgress,
5656
cancellationToken
5757
).ConfigureAwait(false);
@@ -62,7 +62,7 @@ public static async Task<List<float[]>> GenerateInBatchesAsync(
6262
{
6363
return await texts.MapAsync(
6464
concurrency,
65-
(value) => model.GenerateAsync(value, cancellationToken),
65+
model.GenerateAsync,
6666
progress,
6767
cancellationToken
6868
).ConfigureAwait(false);

dotnet/typeagent/src/common/Async.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ public static class Async
2121
public static async Task<List<TResult>> MapAsync<T, TResult>(
2222
this IList<T> list,
2323
int concurrency,
24-
Func<T, Task<TResult>> processor,
24+
Func<T, CancellationToken, Task<TResult>> processor,
2525
Action<BatchProgress>? progress = null,
26-
CancellationToken cancellationToken = default)
26+
CancellationToken cancellationToken = default
27+
)
2728
{
2829
ArgumentVerify.ThrowIfNullOrEmpty(list, nameof(list));
2930
ArgumentVerify.ThrowIfLessThanEqualZero(concurrency, nameof(concurrency));
@@ -36,7 +37,7 @@ public static async Task<List<TResult>> MapAsync<T, TResult>(
3637

3738
private static async Task<List<TResult>> MapSequentialAsync<T, TResult>(
3839
IList<T> list,
39-
Func<T, Task<TResult>> processor,
40+
Func<T, CancellationToken, Task<TResult>> processor,
4041
Action<BatchProgress>? progress,
4142
CancellationToken cancellationToken
4243
)
@@ -46,7 +47,7 @@ CancellationToken cancellationToken
4647
{
4748
cancellationToken.ThrowIfCancellationRequested();
4849

49-
var result = await processor(list[i]);
50+
var result = await processor(list[i], cancellationToken);
5051
results.Add(result);
5152
if (progress is not null)
5253
{
@@ -59,7 +60,7 @@ CancellationToken cancellationToken
5960
private static async Task<List<TResult>> MapConcurrentAsync<T, TResult>(
6061
IList<T> list,
6162
int concurrency,
62-
Func<T, Task<TResult>> processor,
63+
Func<T, CancellationToken, Task<TResult>> processor,
6364
Action<BatchProgress>? progress,
6465
CancellationToken cancellationToken
6566
)
@@ -72,7 +73,7 @@ CancellationToken cancellationToken
7273

7374
int batchSize = Math.Min(concurrency, totalCount - startAt);
7475
var batch = list.Slice(startAt, batchSize);
75-
var tasks = batch.Map<T, Task<TResult>>(processor);
76+
var tasks = batch.Map<T, Task<TResult>>((t) => processor(t, cancellationToken));
7677

7778
var batchResults = await Task.WhenAll(tasks);
7879

dotnet/typeagent/src/common/Cache.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ public class LRUCache<TKey, TValue> : ICache<TKey, TValue>
2424
{
2525
private readonly Dictionary<TKey, LinkedListNode<KeyValuePair<TKey, TValue>>> _index;
2626
private readonly LinkedList<KeyValuePair<TKey, TValue>> _itemList;
27-
private readonly int _highWatermark;
28-
private readonly int _lowWatermark;
27+
private int _highWatermark;
28+
private int _lowWatermark;
2929

3030
public LRUCache(int maxEntries, IEqualityComparer<TKey> comparer = null)
3131
: this(maxEntries - 1, maxEntries, comparer)
@@ -47,7 +47,9 @@ public LRUCache(int lowWatermark, int highWatermark, IEqualityComparer<TKey> com
4747
}
4848

4949
public int Count => _itemList.Count;
50+
5051
public int HighWatermark => _highWatermark;
52+
5153
public int LowWatermark => _lowWatermark;
5254

5355
public event Action<KeyValuePair<TKey, TValue>> Purged;
@@ -134,6 +136,13 @@ public void Clear()
134136
_itemList.Clear();
135137
}
136138

139+
public void SetCount(int maxEntries)
140+
{
141+
_highWatermark = maxEntries;
142+
_lowWatermark = maxEntries - 1;
143+
Trim();
144+
}
145+
137146
void MakeMRU(LinkedListNode<KeyValuePair<TKey, TValue>> node)
138147
{
139148
_itemList.Remove(node);

dotnet/typeagent/src/common/ListExtensions.cs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public static int BinarySearchFirst<T, TSearchValue>(
130130
return lo;
131131
}
132132

133-
public static List<ScoredItem<T>> GetTopK<T>(this IEnumerable<ScoredItem<T>> list, int topK)
133+
public static List<Scored<T>> GetTopK<T>(this IEnumerable<Scored<T>> list, int topK)
134134
{
135135
var topNList = new TopNCollection<T>(topK);
136136
topNList.Add(list);
@@ -144,4 +144,17 @@ public static void Fill<T>(this IList<T> list, T value, int count)
144144
list.Add(value);
145145
}
146146
}
147+
148+
public static T[] Append<T>(this T[]? list, T value)
149+
{
150+
if (list.IsNullOrEmpty())
151+
{
152+
return [value];
153+
}
154+
155+
T[] copy = new T[list.Length + 1];
156+
Array.Copy(list, copy, list.Length);
157+
copy[list.Length] = value;
158+
return copy;
159+
}
147160
}

dotnet/typeagent/src/common/Retry.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@ public RetrySettings()
1515
JitterRange = 0.5;
1616
}
1717

18+
public RetrySettings(int maxRetries)
19+
{
20+
ArgumentVerify.ThrowIfLessThan(maxRetries, 1, nameof(maxRetries));
21+
MaxRetries = maxRetries;
22+
}
23+
1824
public int MaxRetries { get; set; }
1925

2026
public int RetryPauseMs { get; set; }

dotnet/typeagent/src/common/ScoredItem.cs renamed to dotnet/typeagent/src/common/Scored.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
namespace TypeAgent.Common;
55

6-
public struct ScoredItem<T> : IComparable<ScoredItem<T>>
6+
public struct Scored<T> : IComparable<Scored<T>>
77
{
8-
public ScoredItem(T item, double score)
8+
public Scored(T item, double score)
99
{
1010
this.Item = item;
1111
this.Score = score;
@@ -14,25 +14,25 @@ public ScoredItem(T item, double score)
1414
public T Item { get; set; }
1515
public double Score { get; set; }
1616

17-
public readonly int CompareTo(ScoredItem<T> other)
17+
public readonly int CompareTo(Scored<T> other)
1818
{
1919
return this.Score.CompareTo(other.Score);
2020
}
2121

2222
public override readonly string ToString() => $"{this.Score}, {this.Item}";
2323

24-
public static implicit operator double(ScoredItem<T> src)
24+
public static implicit operator double(Scored<T> src)
2525
{
2626
return src.Score;
2727
}
2828

29-
public static implicit operator T(ScoredItem<T> src)
29+
public static implicit operator T(Scored<T> src)
3030
{
3131
return src.Item;
3232
}
3333

34-
public static implicit operator ScoredItem<T>(KeyValuePair<T, double> src)
34+
public static implicit operator Scored<T>(KeyValuePair<T, double> src)
3535
{
36-
return new ScoredItem<T>(src.Key, src.Value);
36+
return new Scored<T>(src.Key, src.Value);
3737
}
3838
}

dotnet/typeagent/src/common/TopNCollection.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ namespace TypeAgent.Common;
55

66
public class TopNCollection<T>
77
{
8-
private List<ScoredItem<T>>? _items;
8+
private List<Scored<T>>? _items;
99
private int _count; // Actual count, since items always ha a
1010
private int _maxCount;
1111

12-
public TopNCollection(int maxCount, List<ScoredItem<T>> buffer = null)
12+
public TopNCollection(int maxCount, List<Scored<T>> buffer = null)
1313
{
1414
ArgumentVerify.ThrowIfLessThan(maxCount, 1, nameof(maxCount));
1515

@@ -21,7 +21,7 @@ public TopNCollection(int maxCount, List<ScoredItem<T>> buffer = null)
2121

2222
public int Count => _count;
2323

24-
public ScoredItem<T> GetTop()
24+
public Scored<T> GetTop()
2525
{
2626
VerifyNotEmpty();
2727
return _items[1];
@@ -44,17 +44,17 @@ public void Add(T item, double score)
4444
}
4545
RemoveTop();
4646
_count++;
47-
_items[_count] = new ScoredItem<T>(item, score);
47+
_items[_count] = new Scored<T>(item, score);
4848
}
4949
else
5050
{
5151
_count++;
52-
_items.Add(new ScoredItem<T>(item, score));
52+
_items.Add(new Scored<T>(item, score));
5353
}
5454
UpHeap(_count);
5555
}
5656

57-
public void Add(IEnumerable<ScoredItem<T>> items)
57+
public void Add(IEnumerable<Scored<T>> items)
5858
{
5959
ArgumentVerify.ThrowIfNull(items, nameof(items));
6060
foreach(var item in items)
@@ -68,7 +68,7 @@ public void Add(IEnumerable<ScoredItem<T>> items)
6868
/// Returns the sorted buffer, and clears the collection
6969
/// </summary>
7070
/// <returns></returns>
71-
public List<ScoredItem<T>> ByRankAndClear()
71+
public List<Scored<T>> ByRankAndClear()
7272
{
7373
if (_count == 0)
7474
{
@@ -98,7 +98,7 @@ private void SortDescending()
9898
_count = count;
9999
}
100100

101-
private ScoredItem<T> RemoveTop()
101+
private Scored<T> RemoveTop()
102102
{
103103
// At the top
104104
var item = _items[1];
@@ -158,7 +158,7 @@ private void EnsureInitialized()
158158
if (_items is null)
159159
{
160160
_items = [];
161-
_items.Add(new ScoredItem<T>() { Score = double.MinValue, Item = default });
161+
_items.Add(new Scored<T>() { Score = double.MinValue, Item = default });
162162
}
163163
}
164164

0 commit comments

Comments
 (0)