diff --git a/libs/estdlib/src/lists.erl b/libs/estdlib/src/lists.erl index 5a256d4b5..9179ed6d7 100644 --- a/libs/estdlib/src/lists.erl +++ b/libs/estdlib/src/lists.erl @@ -452,25 +452,12 @@ any(Fun, L) -> %%----------------------------------------------------------------------------- %% @param L the list to flatten %% @returns flattened list -%% @doc recursively flattens elements of L into a single list +%% @doc Flattens elements of L into a single list %% @end %%----------------------------------------------------------------------------- -spec flatten(L :: list()) -> list(). -flatten(L) when is_list(L) -> - flatten(L, []). - -%% @private -%% pre: Accum is flattened -flatten([], Accum) -> - Accum; -flatten([H | T], Accum) when is_list(H) -> - FlattenedT = flatten(T, Accum), - flatten(H, FlattenedT); -flatten([H | T], Accum) -> - FlattenedT = flatten(T, Accum), - [H | FlattenedT]. - -%% post: return is flattened +flatten(_L) -> + erlang:nif_error(undefined). %%----------------------------------------------------------------------------- %% @param F the function to apply to elements of L diff --git a/src/libAtomVM/nifs.c b/src/libAtomVM/nifs.c index d095fd35a..3ea1e3a3e 100644 --- a/src/libAtomVM/nifs.c +++ b/src/libAtomVM/nifs.c @@ -58,6 +58,7 @@ #include "smp.h" #include "synclist.h" #include "sys.h" +#include "tempstack.h" #include "term.h" #include "term_typedef.h" #include "unicode.h" @@ -213,6 +214,7 @@ static term nif_jit_backend_module(Context *ctx, int argc, term argv[]); static term nif_jit_variant(Context *ctx, int argc, term argv[]); #endif static term nif_lists_reverse(Context *ctx, int argc, term argv[]); +static term nif_lists_flatten(Context *ctx, int argc, term argv[]); static term nif_lists_keyfind(Context *ctx, int argc, term argv[]); static term nif_lists_keymember(Context *ctx, int argc, term argv[]); static term nif_lists_member(Context *ctx, int argc, term argv[]); @@ -827,6 +829,10 @@ static const struct Nif erlang_lists_subtract_nif = { .base.type = NIFFunctionType, .nif_ptr = nif_erlang_lists_subtract }; +static const struct Nif lists_flatten_nif = { + .base.type = NIFFunctionType, + .nif_ptr = nif_lists_flatten +}; static const struct Nif lists_member_nif = { .base.type = NIFFunctionType, .nif_ptr = nif_lists_member @@ -6125,6 +6131,119 @@ static term nif_erlang_lists_subtract(Context *ctx, int argc, term argv[]) return result; } +static term nif_lists_flatten(Context *ctx, int argc, term argv[]) +{ + UNUSED(argc) + + // Compute resulting list length, as well as the length of the reusable tail + size_t result_len = 0; + size_t tail_len = 0; + term list = argv[0]; + + if (term_is_nil(list)) { + return list; + } + + VALIDATE_VALUE(list, term_is_nonempty_list); + + struct TempStack temp_stack; + if (UNLIKELY(temp_stack_init(&temp_stack) != TempStackOk)) { + RAISE_ERROR(OUT_OF_MEMORY_ATOM); + } + if (UNLIKELY(temp_stack_push(&temp_stack, list) != TempStackOk)) { + RAISE_ERROR(OUT_OF_MEMORY_ATOM); + } + while (!temp_stack_is_empty(&temp_stack)) { + term t = temp_stack_pop(&temp_stack); + + tail_len = 0; + + do { + term t_head = term_get_list_head(t); + term t_tail = term_get_list_tail(t); + // Exit if t is not a proper list + if (!term_is_list(t_tail)) { + RAISE_ERROR(BADARG_ATOM); + } + + if (term_is_nonempty_list(t_head)) { + tail_len = 0; + if (term_is_nonempty_list(t_tail)) { + if (UNLIKELY(temp_stack_push(&temp_stack, t_tail) != TempStackOk)) { + RAISE_ERROR(OUT_OF_MEMORY_ATOM); + } + } + t = t_head; + } else { + if (term_is_nil(t_head)) { + tail_len = 0; + } else { + result_len++; + tail_len++; + } + t = t_tail; + } + } while (!term_is_nil(t)); + } + + // Allocate flattened list and build it. + if (result_len > tail_len) { + if (UNLIKELY(memory_ensure_free_with_roots(ctx, CONS_SIZE * (result_len - tail_len), 1, &list, MEMORY_CAN_SHRINK) != MEMORY_GC_OK)) { + RAISE_ERROR(OUT_OF_MEMORY_ATOM); + } + } + + term result = term_nil(); + term *prev_term = NULL; + + if (UNLIKELY(temp_stack_push(&temp_stack, list) != TempStackOk)) { + RAISE_ERROR(OUT_OF_MEMORY_ATOM); + } + while (!temp_stack_is_empty(&temp_stack)) { + term t = temp_stack_pop(&temp_stack); + + do { + term t_head = term_get_list_head(t); + term t_tail = term_get_list_tail(t); + if (term_is_nonempty_list(t_head)) { + if (term_is_nonempty_list(t_tail)) { + if (UNLIKELY(temp_stack_push(&temp_stack, t_tail) != TempStackOk)) { + RAISE_ERROR(OUT_OF_MEMORY_ATOM); + } + } + t = t_head; + } else { + if (!term_is_nil(t_head)) { + // Append the tail of original list + if (result_len == tail_len) { + if (prev_term) { + prev_term[0] = t; + } else { + result = t; + } + break; + } + + term *new_list_item = term_list_alloc(&ctx->heap); + if (prev_term) { + prev_term[0] = term_list_from_list_ptr(new_list_item); + } else { + result = term_list_from_list_ptr(new_list_item); + } + prev_term = new_list_item; + new_list_item[0] = term_nil(); + new_list_item[1] = t_head; + + result_len--; + } + t = t_tail; + } + } while (!term_is_nil(t)); + } + + return result; +} + static term nif_lists_member(Context *ctx, int argc, term argv[]) { UNUSED(argc) diff --git a/src/libAtomVM/nifs.gperf b/src/libAtomVM/nifs.gperf index a4a2591fa..5cf304b88 100644 --- a/src/libAtomVM/nifs.gperf +++ b/src/libAtomVM/nifs.gperf @@ -189,6 +189,7 @@ base64:encode/1, &base64_encode_nif base64:decode/1, &base64_decode_nif base64:encode_to_string/1, &base64_encode_to_string_nif base64:decode_to_string/1, &base64_decode_to_string_nif +lists:flatten/1, &lists_flatten_nif lists:keyfind/3, &lists_keyfind_nif lists:keymember/3, &lists_keymember_nif lists:member/2, &lists_member_nif diff --git a/tests/libs/estdlib/test_lists.erl b/tests/libs/estdlib/test_lists.erl index a088f0c66..0ce366898 100644 --- a/tests/libs/estdlib/test_lists.erl +++ b/tests/libs/estdlib/test_lists.erl @@ -236,6 +236,7 @@ test_list_match() -> test_flatten() -> ?ASSERT_MATCH(lists:flatten([]), []), + ?ASSERT_MATCH(lists:flatten([[]]), []), ?ASSERT_MATCH(lists:flatten([a]), [a]), ?ASSERT_MATCH(lists:flatten([a, []]), [a]), ?ASSERT_MATCH(lists:flatten([[[[[[[[a]]]]]]]]), [a]), @@ -248,6 +249,9 @@ test_flatten() -> lists:flatten([[a, b, c], [d, e, f], [g, h, i]]), [a, b, c, d, e, f, g, h, i] ), + ?ASSERT_ERROR(lists:flatten([7 | {}])), + ?ASSERT_ERROR(lists:flatten([[] | [5 | 5]])), + ?ASSERT_ERROR(lists:flatten([[7 | 4], 2])), ok. test_flatmap() ->