diff --git a/koboldcpp.py b/koboldcpp.py index 69e096fc0c1..b7ce5681ad2 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -3512,9 +3512,9 @@ def raise_exception(msg): last_assist_msg = "" if messages: last_assist_msg = messages[-1]["content"] - assist_should_prefill = (messages and messages[-1]["role"] == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content + assist_should_prefill = (messages and messages[-1]["role"].lower() == "assistant" and last_assist_msg and isinstance(last_assist_msg, str) and len(last_assist_msg.strip())>0) #avoid single character newline or space content last_assist_msg = "" if not assist_should_prefill else last_assist_msg - messages_for_render = messages[:-1] if assist_should_prefill else messages + messages_for_render = messages[:-1] if len(messages) > 1 and assist_should_prefill else messages if tools and len(tools)>0: text = jinja_compiled_template.render(messages=messages_for_render, tools=tools, add_generation_prompt=True, bos_token="", eos_token="", **chat_template_kwargs) else: