diff --git a/src/solvers/iterators.jl b/src/solvers/iterators.jl index 990d86d2..1fe48449 100644 --- a/src/solvers/iterators.jl +++ b/src/solvers/iterators.jl @@ -12,7 +12,10 @@ abstract type AbstractNetworkIterator end islaststep(iterator::AbstractNetworkIterator) = state(iterator) >= length(iterator) function Base.iterate(iterator::AbstractNetworkIterator, init = true) - islaststep(iterator) && return nothing + # The assumption is that first "increment!" is implicit, therefore we must skip the + # the termination check for the first iteration, i.e. `AbstractNetworkIterator` is not + # defined when length < 1, + init || islaststep(iterator) && return nothing # We seperate increment! from step! and demand that any AbstractNetworkIterator *must* # define a method for increment! This way we avoid cases where one may wish to nest # calls to different step! methods accidentaly incrementing multiple times. @@ -44,6 +47,9 @@ mutable struct RegionIterator{Problem, RegionPlan} <: AbstractNetworkIterator which_region::Int const which_sweep::Int function RegionIterator(problem::P, region_plan::R, sweep::Int) where {P, R} + if length(region_plan) == 0 + throw(BoundsError("Cannot construct a region iterator with 0 elements.")) + end return new{P, R}(problem, region_plan, 1, sweep) end end @@ -119,8 +125,15 @@ mutable struct SweepIterator{Problem, Iter} <: AbstractNetworkIterator which_sweep::Int function SweepIterator(problem::Prob, sweep_kwargs::Iter) where {Prob, Iter} stateful_sweep_kwargs = Iterators.Stateful(sweep_kwargs) - first_kwargs, _ = Iterators.peel(stateful_sweep_kwargs) + first_state = Iterators.peel(stateful_sweep_kwargs) + + if isnothing(first_state) + throw(BoundsError("Cannot construct a sweep iterator with 0 elements.")) + end + + first_kwargs, _ = first_state region_iter = RegionIterator(problem; sweep = 1, first_kwargs...) + return new{Prob, Iter}(region_iter, stateful_sweep_kwargs, 1) end end diff --git a/test/solvers/test_iterators.jl b/test/solvers/test_iterators.jl index 730eee93..b73e2189 100644 --- a/test/solvers/test_iterators.jl +++ b/test/solvers/test_iterators.jl @@ -1,5 +1,5 @@ -using Test: @test, @testset -using ITensorNetworks: SweepIterator, islaststep, state, increment!, compute!, eachregion +using Test: @test, @testset, @test_throws +using ITensorNetworks: SweepIterator, RegionIterator, islaststep, state, increment!, compute!, eachregion module TestIteratorUtils @@ -49,6 +49,24 @@ end import .TestIteratorUtils @testset "`AbstractNetworkIterator` Interface" begin + + @testset "Edge cases" begin + TI = TestIteratorUtils.TestIterator(1, 1, []) + cb = [] + @test islaststep(TI) + for _ in TI + @test islaststep(TI) + push!(cb, state(TI)) + end + @test length(cb) == 1 + @test length(TI.output) == 1 + @test only(cb) == 1 + + prob = TestIteratorUtils.TestProblem([]) + @test_throws BoundsError SweepIterator(prob, 0) + @test_throws BoundsError RegionIterator(prob, [], 1) + end + TI = TestIteratorUtils.TestIterator(1, 4, []) @test !islaststep((TI)) @@ -171,6 +189,17 @@ end @test prob.data[1:2:end] == fill(1, 5) @test prob.data[2:2:end] == fill(2, 5) + + let i = 1, prob = TestIteratorUtils.TestProblem([]) + SI = SweepIterator(prob, 1) + cb = [] + for _ in eachregion(SI) + push!(cb, i) + i += 1 + end + @test length(cb) == 2 + end + end end end