Skip to content

Commit 3d7342b

Browse files
authored
Fix contract(::TTN, ::TTN) tests by fixing bug in inner(::TTN, ::TTN) (#52)
1 parent 97ce278 commit 3d7342b

File tree

3 files changed

+19
-22
lines changed

3 files changed

+19
-22
lines changed

src/treetensornetworks/abstracttreetensornetwork.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ function inner(
154154
ψ::AbstractTTN;
155155
root_vertex=default_root_vertex(ϕ, ψ),
156156
)
157-
ϕᴴ = sim(dag(ψ); sites=[])
157+
ϕᴴ = sim(dag(ϕ); sites=[])
158158
ψ = sim(ψ; sites=[])
159159
ϕψ = ϕᴴ ψ
160160
# TODO: find the largest tensor and use it as

src/treetensornetworks/solvers/contract.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ function contract(
1818
nsweeps=1,
1919
kwargs...,
2020
)
21-
@warn """`contract(::AbstractTTN, ::AbstractTTN; alg="fit")` is currently broken, you will likely get incorrect results."""
2221
n = nv(tn1)
2322
n != nv(tn2) && throw(
2423
DimensionMismatch("Number of sites operator ($n) and state ($(nv(tn2))) do not match"),

test/test_treetensornetworks/test_solvers/test_contract.jl

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ using Test
2222
H = mpo(os, s)
2323

2424
# Test basic usage with default parameters
25-
Hpsi = apply(H, psi; alg="fit")
26-
@test_broken inner(psi, Hpsi) inner(psi', H, psi) atol = 1E-5
25+
Hpsi = apply(H, psi; alg="fit", init=psi')
26+
@test inner(psi, Hpsi) inner(psi', H, psi) atol = 1E-5
2727

2828
#
2929
# Change "top" indices of MPO to be a different set
@@ -35,21 +35,19 @@ using Test
3535
psit[j] *= delta(s[j], t[j])
3636
end
3737

38-
# Test with nsweeps=2
39-
Hpsi = apply(H, psi; alg="fit", nsweeps=2)
40-
@test_broken inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5
38+
# Test with nsweeps=3
39+
Hpsi = apply(H, psi; alg="fit", nsweeps=3)
40+
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5
4141

4242
# Test with less good initial guess MPS not equal to psi
43-
psi_guess = copy(psi)
44-
psi_guess = truncate(psi_guess; maxdim=2)
43+
psi_guess = truncate(psi; maxdim=2)
4544
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init_state=psi_guess)
46-
@test_broken inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5
45+
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5
4746

4847
# Test with nsite=1
49-
Hpsi_guess = @test_broken apply(H, psi; alg="naive", cutoff=1E-4)
50-
# Hpsi = apply(H, psi; alg="fit", init_state=Hpsi_guess, nsite=1, nsweeps=2)
51-
Hpsi = apply(H, psi; alg="fit", nsite=1, nsweeps=2)
52-
@test_broken inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4
48+
Hpsi_guess = random_mps(t; internal_inds_space=32)
49+
Hpsi = apply(H, psi; alg="fit", init=Hpsi_guess, nsite=1, nsweeps=4)
50+
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4
5351
end
5452

5553
@testset "Contract TTN" begin
@@ -65,7 +63,7 @@ end
6563

6664
# Test basic usage with default parameters
6765
Hpsi = apply(H, psi; alg="fit")
68-
@test_broken inner(psi, Hpsi) inner(psi', H, psi) atol = 1E-5 # broken when rebasing local draft on remote main, fix this
66+
@test inner(psi, Hpsi) inner(psi', H, psi) atol = 1E-5
6967

7068
#
7169
# Change "top" indices of TTN to be a different set
@@ -77,17 +75,17 @@ end
7775

7876
# Test with nsweeps=2
7977
Hpsi = apply(H, psi; alg="fit", nsweeps=2)
80-
@test_broken inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5 # broken when rebasing local draft on remote main, fix this
78+
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5
8179

8280
# Test with less good initial guess MPS not equal to psi
83-
psi_guess = copy(psi)
84-
psi_guess = truncate(psi_guess; maxdim=2)
85-
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init_state=psi_guess)
86-
@test_broken inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5 # broken when rebasing local draft on remote main, fix this
81+
Hpsi_guess = truncate(psit; maxdim=2)
82+
Hpsi = apply(H, psi; alg="fit", nsweeps=4, init=Hpsi_guess)
83+
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-5
8784

8885
# Test with nsite=1
89-
Hpsi = apply(H, psi; alg="fit", nsite=1, nsweeps=2)
90-
@test_broken inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4 # broken when rebasing local draft on remote main, fix this
86+
Hpsi_guess = random_ttn(t; link_space=4)
87+
Hpsi = apply(H, psi; alg="fit", nsite=1, nsweeps=4, init=Hpsi_guess)
88+
@test inner(psit, Hpsi) inner(psit, H, psi) atol = 1E-4
9189
end
9290

9391
nothing

0 commit comments

Comments
 (0)