diff --git a/sorts/shell_sort.py b/sorts/shell_sort.py index b65609c974b7..43eee0b8d090 100644 --- a/sorts/shell_sort.py +++ b/sorts/shell_sort.py @@ -1,40 +1,50 @@ """ -https://en.wikipedia.org/wiki/Shellsort#Pseudocode +Shell Sort Algorithm +-------------------- + +Issue: #13887 +Implements the Shell Sort algorithm which is a generalization of insertion sort. +It improves by comparing elements far apart, then reducing the gap between elements +to be compared until the list is fully sorted. + +Time Complexity: + Worst case: O(n^2) + Best case: O(n log n) + Average: O(n^(3/2)) + +Space Complexity: O(1) """ +from __future__ import annotations -def shell_sort(collection: list[int]) -> list[int]: - """Pure implementation of shell sort algorithm in Python - :param collection: Some mutable ordered collection with heterogeneous - comparable items inside - :return: the same collection ordered by ascending - >>> shell_sort([0, 5, 3, 2, 2]) - [0, 2, 2, 3, 5] +def shell_sort(arr: list[int]) -> list[int]: + """ + Sorts the given list using Shell Sort and returns the sorted list. + + >>> shell_sort([5, 2, 9, 1]) + [1, 2, 5, 9] >>> shell_sort([]) [] - >>> shell_sort([-2, -5, -45]) - [-45, -5, -2] + >>> shell_sort([3]) + [3] + >>> shell_sort([1, 2, 3]) + [1, 2, 3] + >>> shell_sort([4, 3, 3, 1]) + [1, 3, 3, 4] """ - # Marcin Ciura's gap sequence + n = len(arr) + gap = n // 2 - gaps = [701, 301, 132, 57, 23, 10, 4, 1] - for gap in gaps: - for i in range(gap, len(collection)): - insert_value = collection[i] + # Keep reducing the gap until it becomes 0 + while gap > 0: + for i in range(gap, n): + temp = arr[i] j = i - while j >= gap and collection[j - gap] > insert_value: - collection[j] = collection[j - gap] + while j >= gap and arr[j - gap] > temp: + arr[j] = arr[j - gap] j -= gap - if j != i: - collection[j] = insert_value - return collection - - -if __name__ == "__main__": - from doctest import testmod + arr[j] = temp + gap //= 2 - testmod() - user_input = input("Enter numbers separated by a comma:\n").strip() - unsorted = [int(item) for item in user_input.split(",")] - print(shell_sort(unsorted)) + return arr diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/sorts/__init__.py b/tests/sorts/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/sorts/test_shell_sort.py b/tests/sorts/test_shell_sort.py new file mode 100644 index 000000000000..9dc612ff2741 --- /dev/null +++ b/tests/sorts/test_shell_sort.py @@ -0,0 +1,21 @@ +from sorts.shell_sort import shell_sort + + +def test_shell_sort_basic(): + assert shell_sort([5, 2, 9, 1]) == [1, 2, 5, 9] + + +def test_shell_sort_empty(): + assert shell_sort([]) == [] + + +def test_shell_sort_one_element(): + assert shell_sort([3]) == [3] + + +def test_shell_sort_sorted(): + assert shell_sort([1, 2, 3, 4]) == [1, 2, 3, 4] + + +def test_shell_sort_duplicates(): + assert shell_sort([4, 3, 3, 1]) == [1, 3, 3, 4]