Skip to content

Commit 61425ed

Browse files
Cleaned up GHWT_tf_1d.jl
1 parent efbf53d commit 61425ed

File tree

1 file changed

+18
-201
lines changed

1 file changed

+18
-201
lines changed

src/GHWT_tf_1d.jl

Lines changed: 18 additions & 201 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ using ..GraphSignal, ..GraphPartition, ..BasisSpecification, LinearAlgebra
66

77
include("common.jl")
88

9-
export ghwt_tf_bestbasis, tf_threshold, tf_synthesis
9+
export ghwt_tf_bestbasis
1010

1111

1212
"""
13-
coeffdict = tf_init(dmatrix::Matrix{Float64},GP::GraphPart)
13+
coeffdict = tf_init(dmatrix::Matrix{Float64},GP::GraphPart)
1414
1515
Store the expanding coeffcients from matrix into a list of dictionary (inbuilt hashmap in Julia)
1616
@@ -47,18 +47,11 @@ function tf_init(dmatrix::Matrix{Float64},GP::GraphPart)
4747
end
4848

4949

50-
51-
52-
5350
"""
54-
coeffdict_new,tag_tf = tf_core_new(coeffdict::Array{Dict{Tuple{Int,Int},Float64},1})
55-
56-
57-
One forward iteration of time-frequency adapted GHWT method. For each entry in `coeffdict_new`, we compare
58-
two (or one) entries in 'coeffdict' on time-direction and two (or one) entries in 'coeffdict' on frequency-direction.
59-
Those two groups reprensent the same subspace. We compare the cost-functional value of them and choose the smaller one
60-
as a new entry in 'coeffdict_new'.
51+
coeffdict_new,tag_tf = tf_core_new(coeffdict::Array{Dict{Tuple{Int,Int},Float64},1})
6152
53+
One forward iteration of time-frequency adapted GHWT method. For each entry in `coeffdict_new`, we compare two (or one) entries in 'coeffdict' on time-direction and two (or one) entries in 'coeffdict' on frequency-direction.
54+
Those two groups reprensent the same subspace. We compare the cost-functional value of them and choose the smaller one as a new entry in 'coeffdict_new'.
6255
6356
### Input Arguments
6457
* `coeffdict`: The entries of which reprensents the cost functional value of some basis-vectors' coefficients.
@@ -67,7 +60,6 @@ as a new entry in 'coeffdict_new'.
6760
* `coeffdict_new`: The entries of which represents the cost functional value of some basis-vectors' coefficients
6861
* `tag_tf`: Indicating whether the time-direction (0) or frequency direction (1) was chosen for each entry in coeffdict_new.
6962
70-
7163
Copyright 2018 The Regents of the University of California
7264
7365
Implemented by Yiqun Shao (Adviser: Dr. Naoki Saito)
@@ -134,25 +126,18 @@ function tf_core_new(coeffdict::Array{Dict{Tuple{Int,Int},Float64},1})
134126
end
135127

136128

137-
138-
139-
140-
141129
"""
142130
tag_tf_b_new = tf_basisrecover_new(tag_tf_b::Array{Dict{Tuple{Int,Int},Bool}},tag_tf_f::Array{Dict{Tuple{Int,Int},Bool}})
143131
144-
145-
One backward iteration of time-frequency adapted GHWT method to recover the best-basis from the `tag_tf`s recorded.
132+
One backward iteration of time-frequency adapted GHWT method to recover the best-basis from the `tag_tf`s recorded.
146133
147134
### Input Arguments
148135
* `tag_tf_b`: The `dictionary` recording the time-or-frequency information on some iteration 'i' in the main algorithm
149136
* `tag_tf_f`: The `dictionary` recording the time-or-frequency information on some iteration 'i+1' in the main algorithm
150137
151-
152138
### Output Arguments
153139
* `tag_tf_b_new`: The updated 'tag_tf_b'. Eventually the 'tag_tf' on iteration 1 will represent the selected best-basis
154140
155-
156141
Copyright 2018 The Regents of the University of California
157142
158143
Implemented by Yiqun Shao (Adviser: Dr. Naoki Saito)
@@ -183,7 +168,7 @@ function tf_basisrecover_new(tag_tf_b::Array{Dict{Tuple{Int,Int},Bool}},tag_tf_f
183168
tag_tf_b_new[j][(k,2*l+1)] = tag_tf_b[j][(k,2*l+1)]
184169
end
185170
else
186-
# The entries on time-direction are selected
171+
# The entries on time-direction are selected
187172
if ~haskey(tag_tf_b[j+1],(2*k,l))
188173
tag_tf_b_new[j+1][(2*k+1,l)] = tag_tf_b[j+1][(2*k+1,l)]
189174
elseif ~haskey(tag_tf_b[j+1],(2*k+1,l))
@@ -200,27 +185,27 @@ end
200185

201186

202187
"""
203-
bestbasis_tag_matrix, bestbasis = ghwt_tf_bestbasis(dmatrix::Matrix{Float64},GP::GraphPart)
188+
(dvec, BS) = ghwt_tf_bestbasis(dmatrix::Array{Float64,3}, GP::GraphPart; cfspec::Float64 = 1.0, flatten::Any = 1.0)
204189
205-
Implementation of time-frequency adapted GHWT method.
206-
Modified from the algorithm in paper 'A Fast Algorithm for Adapted Time Frequency Tilings' by Christoph M Thiele and Lars F Villemoes.
190+
Implementation of time-frequency adapted GHWT method = eGHWT.
191+
Modified from the algorithm in the paper: "A Fast Algorithm for Adapted Time Frequency Tilings" by Christoph M. Thiele and Lars F. Villemoes.
207192
208193
### Input Arguments
209-
### Input Arguments
210-
* `dmatrix`: The expanding GHWT coefficients of all levels corresponding to input GP.
211-
* `GP::GraphPart`: an input GraphPart object.
194+
* `dmatrix::Array{Float64,3}`: the matrix of expansion coefficients
195+
* `GP::GraphPart`: an input GraphPart object
196+
* `cfspec::Any`: the specification of cost functional to be used (default: 1.0, i.e., 1-norm)
197+
* `flatten::Any`: the method for flattening vector-valued data to scalar-valued data (default: 1.0, i.e., 1-norm)
212198
213199
### Output Arguments
214-
* `bestbasis_tag_matrix`: binary 0-1 matrix indicating the location of best-basis in dmatrix
215-
* `bestbasis`: same size as dmatrix, but only coefficients of best-basis vectors are nonzero
200+
* `dvec::Matrix{Float64}`: the vector of expansion coefficients corresponding to the eGHWT best basis
201+
* `BS::BasisSpec`: a BasisSpec object which specifies the eGHWT best basis
216202
217203
Copyright 2018 The Regents of the University of California
218204
219205
Implemented by Yiqun Shao (Adviser: Dr. Naoki Saito)
220206
"""
221207
function ghwt_tf_bestbasis(dmatrix::Array{Float64,3}, GP::GraphPart; cfspec::Float64 = 1.0, flatten::Any = 1.0)
222208

223-
224209
# determine the cost functional to be used
225210
costfun = cost_functional(cfspec)
226211

@@ -294,174 +279,6 @@ function ghwt_tf_bestbasis(dmatrix::Array{Float64,3}, GP::GraphPart; cfspec::Flo
294279
BS = BasisSpec(levlist, c2f = true, description = "eGHWT Best Basis")
295280
dvec = dmatrix2dvec(dmatrix0, GP, BS)
296281
return dvec, BS
297-
end
298-
299-
300-
"""
301-
bestbasis_new = tf_threshold(bestbasis::Matrix{Float64}, GP::GraphPart, keep::Float64, SORH::String)
302-
303-
Thresholding the coefficients of bestbasis.
304-
305-
### Input Arguments
306-
* `bestbasis::Matrix{Float64}` the matrix of expansion coefficients
307-
* `SORH::String` use soft ('s') or hard ('h') thresholding
308-
* `keep::Float64` a fraction between 0 and 1 which says how many coefficients should be kept
309-
* `GP::GraphPart` a GraphPart object, used to identify scaling coefficients
310-
311-
### Output Argument
312-
* `bestbasis_new::Matrix{Float64}` the thresholded expansion coefficients
313-
"""
314-
function tf_threshold(bestbasis::Matrix{Float64}, GP::GraphPart, keep::Float64, SORH::String)
315-
316-
tag = GP.tag
317-
if keep > 1 || keep < 0
318-
error("keep should be floating point between 0~1")
319-
end
320-
kept = UInt32(round(keep*size(bestbasis,1)))
321-
dvec_S = sort(abs.(bestbasis[:]), rev = true)
322-
T = dvec_S[kept + 1]
323-
bestbasis_new = deepcopy(bestbasis[:])
324-
indp = bestbasis_new .> T #index for coefficients > T
325-
indn = bestbasis_new .< -1*T #index for coefficients < -T
326-
327-
# hard thresholding
328-
if SORH == "h" || SORH == "hard"
329-
bestbasis_new[.~(indp .| indn)] .= 0
330-
331-
# soft thresholding
332-
elseif SORH == "s" || SORH == "soft"
333-
bestbasis_new[(.~indp) .& (.~indn) .& (tag[:].!=0)] .= 0
334-
bestbasis_new[indp .& (tag[:].!=0)] = bestbasis[indp .& (tag[:].!=0)] .- T
335-
bestbasis_new[indn .& (tag[:].!=0)] = bestbasis[indn .& (tag[:].!=0)] .+ T
336-
end
337-
338-
bestbasis_new = reshape(bestbasis_new,size(bestbasis))
339-
end
282+
end # of function ghwt_tf_bestbasis
340283

341-
342-
343-
"""
344-
(f, GS) = tf_synthesis(bestbasis::Matrix{Float64},bestbasis_tag::Matrix{Int},GP::GraphPart,G::GraphSig)
345-
346-
Given a vector of GHWT expansion coefficients and info about the graph
347-
partitioning and the choice of basis, reconstruct the signal
348-
349-
### Input Arguments
350-
* `bestbasis::Matrix{Float64}`: the expansion coefficients corresponding to the chosen basis
351-
* 'bestbasis_tag::Matrix{Int}': the location of the best basis coefficients in bestbasis matrix
352-
* `GP::GraphPart`: an input GraphPart object
353-
* `G::GraphSig`: an input GraphSig object
354-
355-
### Output Arguments
356-
* `f::Matrix{Float64}`: the reconstructed signal(s)
357-
* `GS::GraphSig`: the reconstructed GraphSig object
358-
"""
359-
function tf_synthesis(bestbasis::Matrix{Float64},bestbasis_tag::Matrix{Int},GP::GraphPart,G::GraphSig)
360-
tag = GP.tag
361-
rs = GP.rs
362-
bestbasis_new = deepcopy(bestbasis)
363-
bestbasis_tag = deepcopy(bestbasis_tag)
364-
jmax = size(rs,2)
365-
for j = 1:(jmax-1)
366-
regioncount = count(!iszero, rs[:,j]) - 1
367-
for r = 1:regioncount
368-
# the index that marks the start of the first subregion
369-
rs1 = rs[r,j]
370-
371-
# the index that is one after the end of the second subregion
372-
rs3 = rs[r+1,j]
373-
374-
# the number of points in the current region
375-
n = rs3 - rs1
376-
377-
# only proceed forward if the coefficients do not exist
378-
if count(!iszero, bestbasis_tag[rs1:(rs3-1),j]) !=0
379-
if n == 1
380-
# scaling coefficient
381-
if bestbasis_tag[rs1,j] == 1 # check ind
382-
bestbasis_new[rs1,j+1] = bestbasis_new[rs1,j]
383-
bestbasis_tag[rs1,j+1] = 1
384-
end
385-
elseif n > 1
386-
# the index that marks the start of the second subregion
387-
rs2 = rs1 + 1
388-
while rs2 < rs3 && tag[rs2, j+1] != 0
389-
rs2 = rs2 +1
390-
end
391-
392-
# only one child
393-
if rs2 == rs3
394-
if bestbasis_tag[rs1:rs3-1,j] == 1
395-
bestbasis_new[rs1:rs3-1,j+1] = bestbasis_new[rs1:rs3-1,j]
396-
bestbasis_tag[rs1:rs3-1,j+1] = 1
397-
end
398-
399-
else
400-
401-
# the number of points in the first subregion
402-
n1 = rs2-rs1
403-
# the number of points in the second subregion
404-
n2 = rs3-rs2
405-
406-
# scaling coefficients
407-
if bestbasis_tag[rs1,j] == 1 && bestbasis_tag[rs1+1,j] == 1 # check if it is the coefficients of best basis
408-
bestbasis_new[rs1,j+1] = (sqrt(n1) *bestbasis_new[rs1,j] + sqrt(n2)*bestbasis_new[rs1+1,j])/sqrt(n)
409-
bestbasis_new[rs2,j+1] = (sqrt(n2) *bestbasis_new[rs1,j] - sqrt(n1)*bestbasis_new[rs1+1,j])/sqrt(n)
410-
bestbasis_tag[rs1,j+1] = 1
411-
bestbasis_tag[rs2,j+1] = 1
412-
end
413-
414-
### HAAR-LIKE & WALSH-LIKE coefficients
415-
416-
# search through the remaining coefficients in each subregion
417-
parent = rs1 + 2
418-
child1 = rs1 + 1
419-
child2 = rs2 + 1
420-
while child1 < rs2 || child2 < rs3
421-
# subregion 1 has the smaller tag
422-
if child2 == rs3 || (tag[child1,j+1] < tag[child2, j+1] && child1 < rs2)
423-
if bestbasis_tag[parent,j]==1 # check if it is the coefficients of best basis
424-
bestbasis_new[child1,j+1] = bestbasis_new[parent,j]
425-
bestbasis_tag[child1, j+1] =1
426-
end
427-
child1 = child1 + 1
428-
parent = parent + 1
429-
430-
# subregion 2 has the smaller tag
431-
elseif child1 == rs2 || (tag[child2, j+1] < tag[child1, j+1] && child2 < rs3)
432-
if bestbasis_tag[parent, j] == 1 # check if it is the coefficients of best basis
433-
bestbasis_new[child2, j+1] = bestbasis_new[parent, j]
434-
bestbasis_tag[child2, j+1] = 1
435-
end
436-
child2 = child2 + 1
437-
parent = parent + 1
438-
439-
# both subregions have the same tag
440-
else
441-
if bestbasis_tag[parent,j] == 1 && bestbasis_tag[parent+1, j] == 1 # check if it is the coefficients of best basis
442-
bestbasis_new[child1,j+1] = (bestbasis_new[parent,j] + bestbasis_new[parent+1,j])/sqrt(2)
443-
bestbasis_new[child2,j+1] = (bestbasis_new[parent,j] - bestbasis_new[parent+1,j])/sqrt(2)
444-
bestbasis_tag[child1,j+1]=1
445-
bestbasis_tag[child2,j+1]=1
446-
end
447-
child1 = child1 + 1
448-
child2 = child2 + 1
449-
parent = parent + 2
450-
end
451-
end
452-
end
453-
end
454-
end
455-
end
456-
end
457-
ftemp = bestbasis_new[:,end]
458-
459-
f = zeros(size(ftemp))
460-
f[GP.ind] = ftemp
461-
f = reshape(f,(length(f),1)) # reorder f
462-
GS = deepcopy(G)
463-
replace_data!(GS, f) # create the new graph signal
464-
return f, GS
465-
end
466-
467-
end
284+
end # of module GHWT_tf_1d

0 commit comments

Comments
 (0)