Skip to content

Commit

Permalink
Update regularizers and Noise dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
roflmaostc committed Apr 15, 2021
1 parent 108b5f0 commit 2edbebd
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 11 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DeconvOptim"
uuid = "03e7cd2f-1a03-4ea9-b59b-760a446df67f"
authors = ["Felix Wechsler <[email protected]>"]
version = "0.1.3"
version = "0.1.4"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -21,7 +21,7 @@ ChainRulesCore = "0.9"
FFTW = "1.2"
Interpolations = "0.12.10, 0.13"
LineSearches = "7.1"
Noise = "0.2.1"
Noise = "0.2.2"
Optim = "1.2, 1.3"
Requires = "1.1.0"
Tullio = "0.2.12"
Expand Down
4 changes: 2 additions & 2 deletions src/regularizer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,15 +320,15 @@ of a n-dimensional array.
"""
function TH(; num_dims=2, ϵ=1f-8)
if num_dims == 3
reg_HES = x -> @tullio res = sqrt(1f-8 + abs2(x[i+1,j,k] + x[i-1,j,k] - 2* x[i,j,k]) +
reg_HES = x -> @tullio res = sqrt(ϵ + abs2(x[i+1,j,k] + x[i-1,j,k] - 2* x[i,j,k]) +
abs2(x[i,j+1,k] + x[i,j-1,k] - 2* x[i,j,k]) +
abs2(x[i,j,k+1] + x[i,j,k-1] - 2* x[i,j,k]) +
2 * abs2(x[i+1,j+1,k] - x[i+1,j,k] - x[i,j+1,k] + x[i, j,k]) +
2 * abs2(x[i+1,j,k+1] - x[i+1,j,k] - x[i,j,k+1] + x[i, j,k]) +
2 * abs2(x[i,j+1,k+1] - x[i,j,k+1] - x[i,j,k+1] + x[i, j,k]))
return reg_HES
elseif num_dims == 2
reg_HES = x -> @tullio res = sqrt(1f-8 + abs2(x[i+1, j] + x[i-1, j] - 2* x[i, j]) +
reg_HES = x -> @tullio res = sqrt(ϵ + abs2(x[i+1, j] + x[i-1, j] - 2* x[i, j]) +
abs2(x[i,j+1] + x[i, j-1] - 2* x[i, j]) +
2 * abs2(x[i+1, j+1] - x[i+1, j] - x[i, j+1] + x[i, j]))
return reg_HES
Expand Down
15 changes: 8 additions & 7 deletions src/regularizer_cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,35 +18,36 @@ julia> reg(CuArray([1 2 3; 4 5 6; 7 8 9]))
12.649111f0
```
"""
function TV_cuda(; num_dims=2)
function TV_cuda(; num_dims=2, weights=ones(Float32, num_dims), ϵ=1f-8)
if num_dims == 3
return TV_3D_view
return arr -> TV_3D_view(arr, weights, ϵ)
elseif num_dims == 2
return TV_2D_view
return arr -> TV_2D_view(arr, weights, ϵ)
else
throw(ArgumentError("num_dims must be 2 or 3"))
end

return reg_TV
end

function TV_2D_view(arr::AbstractArray{T, N}) where {T, N}
function TV_2D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N}
as = ntuple(i -> axes(arr, i), Val(N))
rs = map(x -> first(x):last(x)-1, as)
arr0 = view(arr, f_inds(rs, 0)...)
arr1 = view(arr, f_inds(rs, 1)...)
arr2 = view(arr, f_inds(rs, 2)...)

return @fastmath sum(sqrt.(1f-8 .+ (arr1 .- arr0).^2 .+ (arr0 .- arr2).^2))
return @fastmath sum(sqrt.(ϵ .+ weights[1] .* (arr1 .- arr0).^2 .+ weights[2] .* (arr0 .- arr2).^2))
end

function TV_3D_view(arr::AbstractArray{T, N}) where {T, N}
function TV_3D_view(arr::AbstractArray{T, N}, weights, ϵ=1f-8) where {T, N}
as = ntuple(i -> axes(arr, i), Val(N))
rs = map(x -> first(x):last(x)-1, as)
arr0 = view(arr, f_inds(rs, 0)...)
arr1 = view(arr, f_inds(rs, 1)...)
arr2 = view(arr, f_inds(rs, 2)...)
arr3 = view(arr, f_inds(rs, 3)...)

return @fastmath sum(sqrt.(1f-8 .+ (arr3 .- arr0).^2 .+ (arr1 .- arr0).^2 .+ (arr0 .- arr2).^2))
return @fastmath sum(sqrt.(ϵ .+ weights[1] .* (arr1 .- arr0).^2 .+
weights[2] .* (arr2 .- arr1).^2 .+ weights[3] .* (arr3 .- arr0).^2 ))
end

2 comments on commit 2edbebd

@roflmaostc
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/35016

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.4 -m "<description of version>" 2edbebd56183511a6e07564eb3965b63ce632e44
git push origin v0.1.4

Please sign in to comment.