diff --git a/guardrails/utils/parsing_utils.py b/guardrails/utils/parsing_utils.py index 11163549e..ada4bd17b 100644 --- a/guardrails/utils/parsing_utils.py +++ b/guardrails/utils/parsing_utils.py @@ -75,6 +75,15 @@ def get_code_block( def extract_json_from_ouput( output: str, ) -> Tuple[Optional[Union[Dict, List]], Optional[Exception]]: + # try to load the whole output as json first + # there can be corner cases with code blocks + # and json/codeblocks inside json + try: + output_as_dict = json.loads(output, strict=False) + return output_as_dict, None + except json.decoder.JSONDecodeError: + pass + # Find and extract json from code blocks extracted_code_block = output has_json_block, json_start, json_end = has_code_block(output, "json") diff --git a/tests/unit_tests/utils/test_json_utils.py b/tests/unit_tests/utils/test_json_utils.py index b34e13efb..331077928 100644 --- a/tests/unit_tests/utils/test_json_utils.py +++ b/tests/unit_tests/utils/test_json_utils.py @@ -1,5 +1,5 @@ import pytest - +import json from guardrails.utils.parsing_utils import extract_json_from_ouput @@ -75,6 +75,8 @@ not_even_json = "This isn't even json..." +codeblock_inside_json = json.dumps({"data": 'hello ```json\n{"foo":"<...>"}\n```'}) + @pytest.mark.parametrize( "llm_ouput,expected_output,expected_error", @@ -84,6 +86,11 @@ (no_code_block, {"a": 1}, None), (text_with_no_code_block, {"a": 1, "b": {"c": [{"d": 2}, {"e": 3}]}}, None), (text_with_json_code_block, {"a": 1}, None), + ( + codeblock_inside_json, + {"data": 'hello ```json\n{"foo":"<...>"}\n```'}, + None, + ), (js_code_block, None, "Expecting value: line 1 column 1 (char 0)"), ( invalid_json_code_block__quotes,