@@ -48,35 +48,11 @@ fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
4848 ) ?;
4949 }
5050
51- let sql = "
52- WITH semantic_search AS (
53- SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
54- FROM documents
55- ORDER BY embedding <=> $2
56- LIMIT 20
57- ),
58- keyword_search AS (
59- SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
60- FROM documents, plainto_tsquery('english', $1) query
61- WHERE to_tsvector('english', content) @@ query
62- ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
63- LIMIT 20
64- )
65- SELECT
66- COALESCE(semantic_search.id, keyword_search.id) AS id,
67- COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
68- COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
69- FROM semantic_search
70- FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
71- ORDER BY score DESC
72- LIMIT 5
73- " ;
74-
7551 let query = "growling bear" ;
7652 let query_embedding = model. embed ( query) ?;
7753 let k = 60.0 ;
7854
79- for row in client. query ( sql , & [ & query, & Vector :: from ( query_embedding) , & k] ) ? {
55+ for row in client. query ( HYBRID_SQL , & [ & query, & Vector :: from ( query_embedding) , & k] ) ? {
8056 let id: i32 = row. get ( 0 ) ;
8157 let score: f64 = row. get ( 1 ) ;
8258 println ! ( "document: {}, RRF score: {}" , id, score) ;
@@ -85,6 +61,30 @@ fn main() -> Result<(), Box<dyn Error + Send + Sync>> {
8561 Ok ( ( ) )
8662}
8763
/// SQL text for the hybrid (vector + full-text) document search that `main`
/// passes to `client.query`; the caller prints the second column as the "RRF score".
///
/// Bind parameters, in the order they appear in the statement:
/// * `$1` - the search text, fed to `plainto_tsquery('english', $1)`.
/// * `$2` - the value compared against `documents.embedding` with `<=>`
///   (the caller binds a `Vector` built from the query embedding).
/// * `$3` - cast to `double precision` and added to each rank in the score
///   denominator (the caller binds `k`).
///
/// Shape of the statement:
/// * `semantic_search` ranks documents by `embedding <=> $2` and keeps at most 20 rows.
/// * `keyword_search` ranks documents matching the text query by `ts_rank_cd`
///   (descending) and keeps at most 20 rows.
/// * The two sets are combined with a FULL OUTER JOIN on `id`, so an id found by
///   only one search still appears; the missing side contributes `0.0` via COALESCE.
/// * Each row's `score` is the sum of `1.0 / ($3 + rank)` over the searches that
///   returned it; the 5 highest-scoring rows are returned as `(id, score)`.
///
/// NOTE(review): the `RANK ()` column in `keyword_search` has no explicit alias, yet
/// the outer query reads `keyword_search.rank`; this presumably relies on the database
/// naming an unaliased window-function column after the function - confirm.
64+ const HYBRID_SQL : & str = "
65+ WITH semantic_search AS (
66+ SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
67+ FROM documents
68+ ORDER BY embedding <=> $2
69+ LIMIT 20
70+ ),
71+ keyword_search AS (
72+ SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
73+ FROM documents, plainto_tsquery('english', $1) query
74+ WHERE to_tsvector('english', content) @@ query
75+ ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
76+ LIMIT 20
77+ )
78+ SELECT
79+ COALESCE(semantic_search.id, keyword_search.id) AS id,
80+ COALESCE(1.0 / ($3::double precision + semantic_search.rank), 0.0) +
81+ COALESCE(1.0 / ($3::double precision + keyword_search.rank), 0.0) AS score
82+ FROM semantic_search
83+ FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
84+ ORDER BY score DESC
85+ LIMIT 5
86+ " ;
87+
8888struct EmbeddingModel {
8989 tokenizer : Tokenizer ,
9090 model : BertModel ,
0 commit comments