Skip to content

Commit 2f8f0b2

Browse files
authored
KnowPro.NET: Bug fixes and improvements (#1764)
* Fixed bug recently introduced query compiler. Terms were being over-written. * More efficient Topic and Tag Matching; support wildcards
1 parent 7e6790d commit 2f8f0b2

File tree

11 files changed

+190
-24
lines changed

11 files changed

+190
-24
lines changed

dotnet/typeagent/examples/knowProConsole/ConversationEventHandler.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ namespace KnowProConsole;
66
public class ConversationEventHandler
77
{
88
InplaceText _inplaceUpdate;
9+
string _prevEventType = string.Empty;
910

1011
public ConversationEventHandler()
1112
{
@@ -47,6 +48,11 @@ private void KnowledgeExtractor_OnExtracted(BatchProgress item)
4748

4849
private void WriteProgress(BatchProgress progress, string label)
4950
{
51+
if (_prevEventType != label)
52+
{
53+
ConsoleWriter.WriteLine();
54+
_prevEventType = label;
55+
}
5056
_inplaceUpdate.Write($"[{label}: {progress.CountCompleted} / {progress.Count}]");
5157
}
5258
}

dotnet/typeagent/examples/knowProConsole/TestCommands.cs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -401,16 +401,10 @@ private async Task AnswerAsync(ParseResult args, CancellationToken cancellationT
401401
NamedArgs namedArgs = new NamedArgs(args);
402402
AnswerContext context = new AnswerContext();
403403

404-
List<ConcreteEntity> entities = await conversation.SemanticRefs.SelectAsync<SemanticRef, ConcreteEntity>(
405-
(sr) => sr.KnowledgeType == KnowledgeType.Entity ? sr.AsEntity() : null,
406-
cancellationToken
407-
);
404+
IList<ConcreteEntity> entities = await conversation.SemanticRefs.GetAllEntitiesAsync(cancellationToken);
408405
entities = [.. entities.ToDistinct()];
409406

410-
List<Topic> topics = await conversation.SemanticRefs.SelectAsync<SemanticRef, Topic>(
411-
(sr) => sr.KnowledgeType == KnowledgeType.Topic ? sr.AsTopic() : null,
412-
cancellationToken
413-
);
407+
IList<Topic> topics = await conversation.SemanticRefs.GetAllTopicsAsync(cancellationToken);
414408
topics = [.. topics.ToDistinct()];
415409

416410
context.Entities = entities.Map((e) => new RelevantEntity { Entity = e });

dotnet/typeagent/src/aiclient/EnvVars.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public static string Get(string key, string? keySuffix = null, string? defaultVa
4040

4141
public static int GetInt(string key, string? keySuffix = null, int? defaultValue = null)
4242
{
43-
var numString = Get(key, keySuffix, string.Empty);
43+
var numString = Get(key, keySuffix, defaultValue?.ToString());
4444
if (string.IsNullOrEmpty(numString) && defaultValue is not null)
4545
{
4646
return defaultValue.Value;
@@ -69,4 +69,4 @@ public static string ToVarName(string key, string? keySuffix = null)
6969
{
7070
return !string.IsNullOrEmpty(keySuffix) ? key + "_" + keySuffix : key;
7171
}
72-
}
72+
}

dotnet/typeagent/src/knowpro/ISemanticRefCollection.cs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,16 @@ namespace TypeAgent.KnowPro;
66
public interface ISemanticRefCollection : IAsyncCollection<SemanticRef>
77
{
88
ValueTask<TextRange> GetTextRangeAsync(int ordinal, CancellationToken cancellationToken = default);
9+
910
ValueTask<IList<TextRange>> GetTextRangeAsync(IList<int> ordinals, CancellationToken cancellationToken = default);
1011

1112
ValueTask<KnowledgeType> GetKnowledgeTypeAsync(int ordinal, CancellationToken cancellation = default);
13+
1214
ValueTask<IList<KnowledgeType>> GetKnowledgeTypeAsync(IList<int> ordinal, CancellationToken cancellation = default);
1315

14-
// TODO
15-
// Add methods to enumerate by knowledge Type, casting appropriately.
16-
// More efficient than looping over all
16+
ValueTask<IList<SemanticRef>> GetAllAsync(KnowledgeType? kType = null, CancellationToken cancellationToken = default);
17+
18+
ValueTask<IList<ScoredSemanticRefOrdinal>> GetAllOrdinalsAsync(KnowledgeType? kType = null, CancellationToken cancellationToken = default);
1719

1820
event Action<BatchProgress> OnKnowledgeExtracted;
1921
void NotifyKnowledgeProgress(BatchProgress progress);
@@ -90,4 +92,30 @@ public static async ValueTask<IList<Scored<ConcreteEntity>>> GetDistinctEntities
9092
? entitites.GetTopK(topK.Value)
9193
: [.. entitites];
9294
}
95+
96+
public static async ValueTask<IList<ConcreteEntity>> GetAllEntitiesAsync(
97+
this ISemanticRefCollection semanticRefs,
98+
CancellationToken cancellation = default
99+
)
100+
{
101+
var list = await semanticRefs.GetAllAsync(
102+
KnowledgeType.Entity,
103+
cancellation
104+
).ConfigureAwait(false);
105+
106+
return [.. list.Select((sr) => sr.AsEntity())];
107+
}
108+
109+
public static async ValueTask<IList<Topic>> GetAllTopicsAsync(
110+
this ISemanticRefCollection semanticRefs,
111+
CancellationToken cancellation = default
112+
)
113+
{
114+
var list = await semanticRefs.GetAllAsync(
115+
KnowledgeType.Topic,
116+
cancellation
117+
).ConfigureAwait(false);
118+
119+
return [.. list.Select((sr) => sr.AsTopic())];
120+
}
93121
}

dotnet/typeagent/src/knowpro/Query/LookupExtensions.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,4 +130,38 @@ public static async ValueTask<IList<ScoredSemanticRefOrdinal>> LookupPropertyAsy
130130
}
131131
return scoredRefs;
132132
}
133+
134+
public static async ValueTask<IList<ScoredSemanticRefOrdinal>?> LookupAllAsync(
135+
this ISemanticRefCollection semanticRefs,
136+
QueryEvalContext context,
137+
KnowledgeType knowledgeType,
138+
TextRangesInScope? rangesInScope,
139+
ScoreBooster? scoreBooster = null
140+
)
141+
{
142+
IList<ScoredSemanticRefOrdinal> scoredOrdinals = await semanticRefs.GetAllOrdinalsAsync(
143+
knowledgeType
144+
).ConfigureAwait(false);
145+
146+
if (rangesInScope is null)
147+
{
148+
return scoredOrdinals;
149+
}
150+
151+
// TODO: avoid this double alloction
152+
IList<TextRange> ranges = await semanticRefs.GetTextRangeAsync(
153+
scoredOrdinals.ToOrdinals()
154+
).ConfigureAwait(false);
155+
156+
List<ScoredSemanticRefOrdinal> scoredRefs = [];
157+
int count = ranges.Count;
158+
for (int i = 0; i < count; ++i)
159+
{
160+
if (rangesInScope.IsRangeInScope(ranges[i]))
161+
{
162+
scoredRefs.Add(scoredOrdinals[i]);
163+
}
164+
}
165+
return scoredRefs;
166+
}
133167
}

dotnet/typeagent/src/knowpro/Query/MatchTermExpr.cs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ Term relatedTerm
115115
}
116116
}
117117

118-
private ValueTask<IList<ScoredSemanticRefOrdinal>?> LookupTermAsync(QueryEvalContext context, Term term)
118+
protected virtual ValueTask<IList<ScoredSemanticRefOrdinal>?> LookupTermAsync(QueryEvalContext context, Term term)
119119
{
120120
return context.SemanticRefIndex.LookupTermAsync(
121121
context,
@@ -292,3 +292,36 @@ string propertyValue
292292
);
293293
}
294294
}
295+
296+
internal class MatchKnowledgeTypeExpr : MatchSearchTermExpr
297+
{
298+
KnowledgeType _kType;
299+
300+
301+
public MatchKnowledgeTypeExpr(SearchTerm searchTerm, KnowledgeType kType)
302+
: base(searchTerm)
303+
{
304+
_kType = kType;
305+
}
306+
307+
protected override ValueTask<IList<ScoredSemanticRefOrdinal>?> LookupTermAsync(QueryEvalContext context, Term term)
308+
{
309+
if (SearchTerm.IsWildcard())
310+
{
311+
return context.Conversation.SemanticRefs.LookupAllAsync(
312+
context,
313+
_kType,
314+
context.TextRangesInScope,
315+
ScoreBooster
316+
);
317+
}
318+
319+
return context.SemanticRefIndex.LookupTermAsync(
320+
context,
321+
term,
322+
context.TextRangesInScope,
323+
_kType,
324+
ScoreBooster
325+
);
326+
}
327+
}

dotnet/typeagent/src/knowpro/Query/QueryCompiler.cs

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ public async ValueTask<QueryOpExpr<IList<ScoredMessageOrdinal>>> CompileMessageS
157157
matchFilter
158158
);
159159
termExpressions.Add(searchTermExpr);
160-
compiledTerms[0].Terms.Add(searchTerm.Clone());
160+
compiledTerms[0].Terms.Add(searchTerm);
161161
break;
162162

163163
}
@@ -215,7 +215,7 @@ public async ValueTask<QueryOpExpr<IList<ScoredMessageOrdinal>>> CompileMessageS
215215
matchFilter
216216
);
217217
termExpressions.Add(searchTermExpr);
218-
compiledTerms[0].Terms.Add(searchTerm.Clone());
218+
compiledTerms[0].Terms.Add(searchTerm);
219219
break;
220220
}
221221
}
@@ -275,10 +275,12 @@ private QueryOpExpr<SemanticRefAccumulator> CompileSelect(
275275
propertyTerm.PropertyValue.Term.Weight ??= Settings.EntityTermMatchWeight;
276276
}
277277
return new MatchPropertySearchTermExpr(propertyTerm);
278+
278279
case "tag":
280+
return new MatchKnowledgeTypeExpr(propertyTerm.PropertyValue, KnowledgeType.Tag);
281+
279282
case "topic":
280-
// TODO
281-
throw new NotImplementedException();
283+
return new MatchKnowledgeTypeExpr(propertyTerm.PropertyValue, KnowledgeType.Topic);
282284
}
283285
}
284286
else
@@ -372,7 +374,8 @@ private async ValueTask<QueryOpExpr<MessageAccumulator>> CompileMessageReRankAsy
372374
QueryOpExpr<MessageAccumulator> srcExpr,
373375
string? rawQueryText,
374376
SearchOptions? options
375-
) {
377+
)
378+
{
376379
var messageIndex = _conversation.SecondaryIndexes.MessageIndex;
377380
int messageCount = await messageIndex.GetCountAsync(
378381
_cancellationToken

dotnet/typeagent/src/knowpro/SearchTerm.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,16 @@ public void AddRelated(Term term)
6161

6262
internal bool RelatedTermsRequired { get; set; }
6363

64+
/*
6465
internal SearchTerm Clone()
6566
{
6667
return new SearchTerm(Term);
6768
}
68-
69+
*/
6970
internal SearchTerm ToRequired()
7071
{
71-
var copy = this.Clone();
72-
copy.RelatedTermsRequired = true;
73-
return copy;
72+
RelatedTermsRequired = true;
73+
return this;
7474
}
7575

7676
public static implicit operator SearchTerm(string value)

dotnet/typeagent/src/knowproStorage/Sqlite/SqliteDatabase.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ public List<T> GetList<T>(string commandText, Func<SqliteDataReader, T> cb)
7171
return reader.GetList<T>(cb);
7272
}
7373

74+
public List<T> GetList<T>(string commandText, Action<SqliteCommand> addParams, Func<SqliteDataReader, T> cb)
75+
{
76+
ArgumentVerify.ThrowIfNullOrEmpty(commandText, nameof(commandText));
77+
78+
using var command = CreateCommand(commandText);
79+
if (addParams is not null)
80+
{
81+
addParams(command);
82+
}
83+
using var reader = command.ExecuteReader();
84+
85+
return reader.GetList<T>(cb);
86+
}
87+
7488
public List<T>? GetListOrNull<T>(string commandText, Func<SqliteDataReader, T> cb)
7589
{
7690
ArgumentVerify.ThrowIfNullOrEmpty(commandText, nameof(commandText));

dotnet/typeagent/src/knowproStorage/Sqlite/SqliteMessageTextIndex.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ public IEnumerable<KeyValuePair<int, NormalizedEmbeddingB>> GetSubset(IList<int>
208208
string[] placeholderIds = SqliteDatabase.MakeInPlaceholderParamIds(batch.Count);
209209
var rows = _db.Enumerate(
210210
$@"
211-
SELECT msg_id, embedding FROM MessageTextIndex
211+
SELECT msg_id, embedding
212212
FROM MessageTextIndex WHERE msg_id IN ({SqliteDatabase.MakeInStatement(placeholderIds)})
213213
ORDER BY msg_id",
214214
(cmd) => cmd.AddPlaceholderParameters(placeholderIds, batch),

0 commit comments

Comments
 (0)