115 changes: 96 additions & 19 deletions LLMlean/API/Common.lean
@@ -50,13 +50,6 @@ structure ChatGenerationOptions where
numSamples : Nat := defaultSamples
deriving ToJson, FromJson

structure ChatGenerationOptionsQed where
temperature : Float := defaultTemperature
maxTokens : Nat := defaultMaxTokens
stopSequences : List String := defaultStopProof
numSamples : Nat := defaultSamples
deriving ToJson, FromJson

structure OpenAIMessage where
role : String
content : String
@@ -240,18 +233,6 @@ def getChatGenerationOptions (api : API) (tacticKind : TacticKind) : CoreM ChatGenerationOptions
stopSequences := defaultStopTactic
}

def getChatGenerationOptionsQed (api : API) (tacticKind : TacticKind) : CoreM ChatGenerationOptionsQed := do
let numSamples ← getNumSamples api tacticKind
let maxTokens ← getMaxTokens api tacticKind
-- Print configuration in verbose mode
printConfiguration api tacticKind numSamples maxTokens
return {
numSamples := numSamples
temperature := defaultTemperature,
maxTokens := maxTokens,
stopSequences := defaultStopProof
}

/--
Parses a string consisting of Markdown text, and extracts the Lean code blocks.
The code blocks are enclosed in triple backticks.
@@ -270,4 +251,100 @@ def getMarkdownLeanCodeBlocks (markdown : String) : List String := Id.run do
blocks := blocks ++ [part.headD ""]
return blocks
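-- Illustrative behavior, per the doc comment above (the input string is
-- hypothetical; exact whitespace handling depends on the part of the body
-- folded out of the diff):
-- getMarkdownLeanCodeBlocks "Intro\n```\nexact rfl\n```\nOutro"
-- -- plausibly ["exact rfl"]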


/--
Parses a proof out of a response from the LLM.
The proof is expected to be enclosed in `[PROOF]...[/PROOF]` tags.
-/
def splitProof (text : String) : String :=
let text := ((text.splitOn "[PROOF]").tailD [text]).headD text
match (text.splitOn "[/PROOF]").head? with
| some s => s.trim
| none => text.trim
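-- Illustrative behavior (hypothetical response text):
-- splitProof "Thinking...\n[PROOF]\nintro h\nexact h\n[/PROOF]\ntrailing"
-- -- = "intro h\nexact h" (text before [PROOF] and after [/PROOF] is dropped
-- -- and the result is trimmed; without the tags, the whole input is returned trimmed)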

def generateOpenAI (prompts : List String)
(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
let req : OpenAIGenerationRequest := {
model := api.model,
messages := [
{
role := "user",
content := prompt
}
],
n := options.numSamples,
temperature := options.temperature,
max_tokens := options.maxTokens,
stop := options.stopSequences
}
let res : OpenAIResponse ← post req api.baseUrl api.key
for result in res.choices do
results := results.insert result.message.content

let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
return finalResults
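-- Note: duplicate completions across samples and prompts are collapsed by the
-- HashSet, candidates are filtered through filterGeneration, and each survivor
-- is assigned a uniform score of 1.0.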

def generateAnthropic (prompts : List String)
(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
for i in List.range options.numSamples do
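-- Note: as written, the sample at index 1 (the second of numSamples) is drawn
-- greedily at temperature 0.0, while all other samples use options.temperature.
-- The same pattern appears in generateOllama below.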
let temperature := if i == 1 then 0.0 else options.temperature
let req : AnthropicGenerationRequest := {
model := api.model,
messages := [
{
role := "user",
content := prompt
}
],
temperature := temperature,
max_tokens := options.maxTokens,
stop_sequences := options.stopSequences
}
let res : AnthropicResponse ← post req api.baseUrl api.key
for result in res.content do
results := results.insert result.text

let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
return finalResults

def generateOllama (prompts : List String)
(api : API) (options : ChatGenerationOptions) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
for i in List.range options.numSamples do
let temperature := if i == 1 then 0.0 else options.temperature
let req : OllamaGenerationRequest := {
model := api.model,
prompt := prompt,
stream := false,
options := {
temperature := temperature,
stop := options.stopSequences,
num_predict := options.maxTokens
}
}
let res : OllamaResponse ← post req api.baseUrl api.key
results := results.insert res.response

return results.toArray.map (fun x => (x, 1.0))
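-- Note: unlike generateOpenAI and generateAnthropic above, this path returns
-- the raw responses without a filterGeneration pass.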

/-!
## Main Handler
-/
def Config.API.generate
(api : API) (prompts : List String) (options : ChatGenerationOptions) : CoreM $ Array (String × Float) := do
match api.kind with
| APIKind.Ollama =>
generateOllama prompts api options
| APIKind.TogetherAI =>
generateOpenAI prompts api options
| APIKind.OpenAI =>
generateOpenAI prompts api options
| APIKind.Anthropic =>
generateAnthropic prompts api options
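-- A minimal illustrative call site (`myApi` and the prompt text are
-- hypothetical; unspecified fields of ChatGenerationOptions use their defaults):
-- def demo (myApi : API) : CoreM (Array (String × Float)) :=
--   myApi.generate ["Prove: 1 + 1 = 2"] { numSamples := 4 }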

end LLMlean
204 changes: 16 additions & 188 deletions LLMlean/API/ProofGen.lean
@@ -6,108 +6,17 @@ open Lean LLMlean.Config

namespace LLMlean

/--
Parses a proof out of a response from the LLM.
The proof is expected to be enclosed in `[PROOF]...[/PROOF]` tags.
-/
def splitProof (text : String) : String :=
let text := ((text.splitOn "[PROOF]").tailD [text]).headD text
match (text.splitOn "[/PROOF]").head? with
| some s => s.trim
| none => text.trim

/-!
## OpenAI
-/
def parseResponseQedOpenAI (res: OpenAIResponse) : Array String :=
(res.choices.map fun x => (splitProof x.message.content)).toArray

def qedOpenAI (prompts : List String)
(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
let req : OpenAIGenerationRequest := {
model := api.model,
messages := [
{
role := "user",
content := prompt
}
],
n := options.numSamples,
temperature := options.temperature,
max_tokens := options.maxTokens,
stop := options.stopSequences
}
let res : OpenAIResponse ← post req api.baseUrl api.key
for result in (parseResponseQedOpenAI res) do
results := results.insert result

let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
return finalResults

/-!
## Anthropic
-/
def parseResponseQedAnthropic (res: AnthropicResponse) : Array String :=
(res.content.map fun x => (splitProof x.text)).toArray
open API

def qedAnthropic (prompts : List String)
(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
for i in List.range options.numSamples do
let temperature := if i == 1 then 0.0 else options.temperature
let req : AnthropicGenerationRequest := {
model := api.model,
messages := [
{
role := "user",
content := prompt
}
],
temperature := temperature,
max_tokens := options.maxTokens,
stop_sequences := options.stopSequences
}
let res : AnthropicResponse ← post req api.baseUrl api.key
for result in (parseResponseQedAnthropic res) do
results := results.insert result

let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
return finalResults

/-!
## Ollama
-/
def parseResponseQedOllama (res: OllamaResponse) : String :=
splitProof res.response

def qedOllama (prompts : List String)
(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
for i in List.range options.numSamples do
let temperature := if i == 1 then 0.0 else options.temperature
let req : OllamaGenerationRequest := {
model := api.model,
prompt := prompt,
stream := false,
options := {
temperature := temperature,
stop := options.stopSequences,
num_predict := options.maxTokens
}
}
let res : OllamaResponse ← post req api.baseUrl api.key
results := results.insert (parseResponseQedOllama res)

let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
return finalResults

/-!
## Ollama with markdown output (e.g., Kimina Prover)
/--
Generates proof completions using the LLM API.
-/
def LLMlean.Config.API.proofCompletion
(api : API) (tacticState : String) (context : String) : CoreM $ Array (String × Float) := do
let prompts := makeQedPrompts api.promptKind context tacticState
let options ← getChatGenerationOptions api TacticKind.LLMQed
let responses := (← generate api prompts options).map (fun (x, p) => (splitProof x, p))
return responses.filter (fun (x, _) => filterGeneration x)
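-- Each raw response is reduced to its [PROOF]...[/PROOF] payload via splitProof
-- and then filtered through filterGeneration. Illustrative call (the theorem
-- and `myApi` are hypothetical):
-- let candidates ← myApi.proofCompletion "⊢ 1 + 1 = 2" "theorem t : 1 + 1 = 2 := by"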

/--
Extracts the proof from a markdown response by finding the last code block
@@ -138,80 +47,6 @@ def extractProofFromMarkdownResponse (context : String) (response : String) : Option String
-- If we can't find the context, return the whole block
some lastBlock.trim
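-- Illustrative behavior (hypothetical strings; most of this function is folded
-- out of the diff, so the exact result is an assumption based on the doc comment):
-- extractProofFromMarkdownResponse "theorem t : 1 + 1 = 2 := by"
--   "Reasoning...\n```lean\ntheorem t : 1 + 1 = 2 := by\n  rfl\n```"
-- -- plausibly some "rfl" (the last code block, with the matched context stripped)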

def qedOllamaMarkdown (prompts : List String) (context : String)
(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
for i in List.range options.numSamples do
let temperature := if i == 1 then 0.0 else options.temperature
let req : OllamaGenerationRequest := {
model := api.model,
prompt := prompt,
stream := false,
options := {
temperature := temperature,
stop := options.stopSequences,
num_predict := options.maxTokens
}
}
let res : OllamaResponse ← post req api.baseUrl api.key
match extractProofFromMarkdownResponse context res.response with
| some proof => results := results.insert proof
| none => results := results

let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
return finalResults

/-!
## Ollama with tactic output (e.g., BFS-Prover)
-/
def qedOllamaTactic (prompts : List String)
(api : API) (options : ChatGenerationOptionsQed) : IO $ Array (String × Float) := do
let mut results : Std.HashSet String := Std.HashSet.emptyWithCapacity
for prompt in prompts do
for i in List.range options.numSamples do
let temperature := if i == 1 then 0.0 else options.temperature
let req : OllamaGenerationRequest := {
model := api.model,
prompt := prompt,
stream := false,
options := {
temperature := temperature,
num_predict := options.maxTokens,
stop := options.stopSequences
}
}
let res : OllamaResponse ← post req api.baseUrl api.key
let tactic := res.response
results := results.insert tactic

let finalResults := (results.toArray.filter filterGeneration).map fun x => (x, 1.0)
return finalResults

/-!
## Main Handler
-/

/--
Generates proof completions using the LLM API.
-/
def LLMlean.Config.API.proofCompletion
(api : API) (tacticState : String) (context : String) : CoreM $ Array (String × Float) := do
let prompts := makeQedPrompts api.promptKind context tacticState
let options ← getChatGenerationOptionsQed api TacticKind.LLMQed
match api.kind with
| APIKind.Ollama =>
match api.responseFormat with
| ResponseFormat.Markdown =>
qedOllamaMarkdown prompts context api options
| _ =>
qedOllama prompts api options
| APIKind.TogetherAI =>
qedOpenAI prompts api options
| APIKind.OpenAI =>
qedOpenAI prompts api options
| APIKind.Anthropic =>
qedAnthropic prompts api options

/--
Generates proof completions with refinement context using the LLM API.
@@ -220,19 +55,12 @@ def LLMlean.Config.API.proofCompletionRefinement
(api : API) (tacticState : String) (context : String)
(previousAttempt : String) (errorMsg : String) : CoreM $ Array (String × Float) := do
let prompts := makeQedRefinementPrompts api.promptKind context tacticState previousAttempt errorMsg
let options ← getChatGenerationOptionsQed api TacticKind.LLMQed
match api.kind with
| APIKind.Ollama =>
match api.responseFormat with
| ResponseFormat.Markdown =>
qedOllamaMarkdown prompts context api options
| _ =>
qedOllama prompts api options
| APIKind.TogetherAI =>
qedOpenAI prompts api options
| APIKind.OpenAI =>
qedOpenAI prompts api options
| APIKind.Anthropic =>
qedAnthropic prompts api options
let options ← getChatGenerationOptions api TacticKind.LLMQed
let responses ← generate api prompts options
return Std.HashMap.toArray (responses.foldl (fun results (response, prob) =>
match extractProofFromMarkdownResponse context response with
| some proof => results.insert proof prob
| none => results
) {})
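-- Note: proofs are deduplicated via the HashMap keyed by proof text; when the
-- same proof is extracted from several responses, the probability recorded is
-- the one from the last response processed.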

end LLMlean