diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000..580b751 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style = "sciml" diff --git a/.gitignore b/.gitignore index e7625df..1209b4f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,2 @@ -<<<<<<< HEAD -/Manifest.toml -docs/build -======= Manifest.toml ->>>>>>> docs +docs/build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..a12595d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: + - repo: https://github.com/qiaojunfeng/pre-commit-julia-format + rev: v0.2.0 # use the most recent version + hooks: + - id: julia-format # formatter for Julia code + args: [--project=.] diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index e3ab17a..0000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,265 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.10.4" -manifest_format = "2.0" -project_hash = "e0c77beb18dc1f6cce661ebd60658c0c1a77390f" - -[[deps.ANSIColoredPrinters]] -git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" -uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" -version = "0.0.1" - -[[deps.AbstractTrees]] -git-tree-sha1 = "2d9c9a55f9c93e8887ad391fbae72f8ef55e1177" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.5" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "bce6804e5e6044c6daab27bb533d1295e4a2e759" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.6" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Documenter]] -deps = ["ANSIColoredPrinters", "AbstractTrees", "Base64", "CodecZlib", "Dates", "DocStringExtensions", "Downloads", "Git", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "MarkdownAST", "Pkg", "PrecompileTools", "REPL", "RegistryInstances", "SHA", "TOML", "Test", "Unicode"] -git-tree-sha1 = "76deb8c15f37a3853f13ea2226b8f2577652de05" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "1.5.0" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "1c6317308b9dc757616f0b5cb379db10494443a7" -uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.6.2+0" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.Git]] -deps = ["Git_jll"] -git-tree-sha1 = "04eff47b1354d702c3a85e8ab23d539bb7d5957e" -uuid = "d7ba0133-e1db-5d97-8f8c-041e4b3a1eb2" -version = "1.3.1" - -[[deps.Git_jll]] -deps = ["Artifacts", "Expat_jll", "JLLWrappers", "LibCURL_jll", "Libdl", "Libiconv_jll", "OpenSSL_jll", "PCRE2_jll", "Zlib_jll"] -git-tree-sha1 = "d18fb8a1f3609361ebda9bf029b60fd0f120c809" -uuid = "f8c6e375-362e-5223-8a59-34ff63f689eb" -version = "2.44.0+2" - -[[deps.IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.5" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.JLLWrappers]] -deps = ["Artifacts", "Preferences"] -git-tree-sha1 = "7e5d6779a1e09a36db2a7b6cff50942a0a7d0fca" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.5.0" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.LazilyInitializedFields]] -git-tree-sha1 = "8f7f3cabab0fd1800699663533b6d5cb3fc0e612" -uuid = "0e77f7df-68c5-4e49-93ce-4cd80f5598bf" -version = "1.2.2" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.4" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "8.4.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "LibGit2_jll", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibGit2_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll"] -uuid = "e37daf67-58a4-590a-8e99-b0245dd2ffc5" -version = "1.6.4+0" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.11.0+1" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "f9557a255370125b405568f9767d6d195822a175" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.17.0+0" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.MarkdownAST]] -deps = ["AbstractTrees", "Markdown"] -git-tree-sha1 = "465a70f0fc7d443a00dcdc3267a497397b8a3899" -uuid = "d0879d2d-cac2-40c8-9cee-1863dc0c7391" -version = "0.1.2" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+1" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2023.1.10" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "a028ee3cb5641cccc4c24e90c36b0a4f7707bdf5" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.14+0" - -[[deps.PCRE2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+1" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "8489905bcdbcfac64d1daa51ca07c0d8f0283821" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.8.1" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.10.0" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "5aa36f7049a63a1528fe8f7c3f2113413ffd4e1f" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.2.1" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "9306f6085165d270f7e3db02af26a400d580f5c6" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.3" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.RegistryInstances]] -deps = ["LazilyInitializedFields", "Pkg", "TOML", "Tar"] -git-tree-sha1 = "ffd19052caf598b8653b99404058fce14828be51" -uuid = "2792f1a3-b283-48e8-9a74-f99dce5104f3" -version = "0.1.0" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TranscodingStreams]] -git-tree-sha1 = "e84b3a11b9bece70d14cce63406bbc79ed3464d2" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.11.2" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.52.0+1" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+2" diff --git a/docs/make.jl b/docs/make.jl index 9755d10..5230522 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,22 +1,20 @@ using Documenter, e3nn mathengine = MathJax3() +prettyurls = get(ENV, "CI", nothing) == "true" makedocs( modules = [e3nn], doctest = false, clean = true, - sitename="e3nn.jl", - format = Documenter.HTML( - canonical = "https://dsantra92.github.io/e3nn.jl/stable/", + sitename = "e3nn.jl", + format = Documenter.HTML(; mathengine, + prettyurls, assets = ["assets/favicon.ico"], - prettyurls = get(ENV, "CI", nothing) == "true", - collapselevel=3, + collapselevel = 3 ), - pages = ["Home" => "Index.md", - "API Reference" => [ - "o3" => "api/o3.md", - ], - ] - ) + pages = ["Home" => "index.md", "API Reference" => ["o3" => "api/o3.md"]] +) + +deploydocs(repo = "github.com/Dsantra92/e3nn.jl.git") diff --git a/docs/src/Index.md b/docs/src/Index.md deleted file mode 100644 index e25d663..0000000 --- a/docs/src/Index.md +++ /dev/null @@ -1,3 +0,0 @@ -# e3nn - -Julia implementation of e3nn python library. diff --git a/docs/src/index.md b/docs/src/index.md index 48b48b0..e25d663 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -1,3 +1,3 @@ -# e3nn.jl +# e3nn -E(3) equivariant neural networks in Julia. \ No newline at end of file +Julia implementation of e3nn python library. diff --git a/src/e3nn.jl b/src/e3nn.jl index 4c9e6ff..f25c53a 100644 --- a/src/e3nn.jl +++ b/src/e3nn.jl @@ -3,5 +3,4 @@ module e3nn # o3 include("o3/o3.jl") - end diff --git a/src/o3/irreps.jl b/src/o3/irreps.jl index 158de9b..a436953 100644 --- a/src/o3/irreps.jl +++ b/src/o3/irreps.jl @@ -11,10 +11,10 @@ struct Irrep end end -function Irrep(l::T) where {T<:AbstractString} +function Irrep(l::T) where {T <: AbstractString} name = strip(l) try - l = parse(Int, name[1:end-1]) + l = parse(Int, name[1:(end - 1)]) (l >= 0) || throw(ArgumentError("l must be zero or positive integer, got $l")) p = Dict('e' => 1, 'o' => -1, 'y' => (-1)^l)[name[end]] return Irrep(l, p) @@ -41,7 +41,7 @@ function Base.:*(x1::Irrep, x2::Irrep) p = x1.p * x2.p lmin = abs(x1.l - x2.l) lmax = x1.l + x2.l - return (Irrep(l, p) for l = lmin:lmax) + return (Irrep(l, p) for l in lmin:lmax) end function Base.:*(i::Int, x::Irrep) @@ -96,7 +96,7 @@ end Irreps(irrep::Irrep) = Irreps([MulIrrep(1, irrep)]) Irreps(irreps::Irreps) = irreps -function Irreps(irreps::T) where {T<:AbstractString} +function Irreps(irreps::T) where {T <: AbstractString} mulirreps = MulIrrep[] if strip(irreps) != "" for mul_irrep in split(irreps, "+") @@ -162,26 +162,28 @@ Base.firstindex(xs::Irreps) = Base.firstindex(xs.irreps) Base.lastindex(xs::Irreps) = Base.lastindex(xs.irreps) Base.in(x::Irrep, xs::Irreps) = x ∈ [mx.irrep for mx in xs.irreps] -Base.count(x::Irrep, xs::Irreps) = +function Base.count(x::Irrep, xs::Irreps) sum([mx.mul for mx in xs.irreps if mx.irrep == x], init = 0) -Base.iterate(xs::Irreps, state = 1) = +end +function Base.iterate(xs::Irreps, state = 1) state > length(xs) ? nothing : ((xs[state].mul, (xs[state].irrep.l, xs[state].irrep.p)), state + 1) +end """ Representation of spherical harmonics. """ function spherical_harmonics(lmax::Int, p::Int = -1)::Irreps - return Irreps([(1, (l, p^l)) for l = 1:lmax]) + return Irreps([(1, (l, p^l)) for l in 1:lmax]) end function Base.randn( - xs::Irreps, - dims, - normalization::String, - rng::AbstractRNG, - ::Type{T}, -) where {T<:Number} + xs::Irreps, + dims, + normalization::String, + rng::AbstractRNG, + ::Type{T} +) where {T <: Number} di = dims[end] lsize = dims[:di] rsize = dims[di + 1 :] @@ -210,8 +212,8 @@ end """ Remove any irreps with multiplicities of zero. """ -remove_zero_multiplicities(xs::Irreps) = - [(mul, irreps) for (mul, irreps) in xs if mul > 0] |> Irreps +remove_zero_multiplicities(xs::Irreps) = [(mul, irreps) for (mul, irreps) in xs if mul > 0] |> + Irreps function Base.sort(xs::Irreps)::Irreps out = [(mx.irrep, i, mx.mul) for (i, mx) in enumerate(xs)] @@ -224,7 +226,7 @@ dim(xs::Irreps) = sum([mx.mul * dim(mx.irrep) for mx in xs], init = 0) num_irreps(xs::Irreps) = sum([mx.mul for mx in xs], init = 0) -ls(xs::Irreps) = [mx.irrep.l for mx in xs for _ = 1:mx.mul] +ls(xs::Irreps) = [mx.irrep.l for mx in xs for _ in 1:(mx.mul)] function lmax(xs::Irreps)::Int if length(xs) == 0 diff --git a/src/o3/rotations.jl b/src/o3/rotations.jl index 582cfb7..01b3a83 100644 --- a/src/o3/rotations.jl +++ b/src/o3/rotations.jl @@ -10,38 +10,38 @@ export euler_angles, CartesianToSphericalAngles, SphercialAnglesToCartesian # Cross conversion -function (::Type{Q})(aa::AngleAxis) where {Q<:Quaternion} +function (::Type{Q})(aa::AngleAxis) where {Q <: Quaternion} s, c = sincos(aa.theta / 2) return Q(c, s * aa.axis_x, s * aa.axis_y, s * aa.axis_z) end -function (::Type{AA})(q::Quaternion) where {AA<:AngleAxis} +function (::Type{AA})(q::Quaternion) where {AA <: AngleAxis} AA(QuatRotation(q, false)) # already implemented method end -function (::Type{AA})(; α::Real, β::Real, γ::Real) where {AA<:AngleAxis} +function (::Type{AA})(; α::Real, β::Real, γ::Real) where {AA <: AngleAxis} AA(RotYXY(promote(α, β, γ)...)) end -function (::Type{Q})(; α::Real, β::Real, γ::Real) where {Q<:Quaternion} +function (::Type{Q})(; α::Real, β::Real, γ::Real) where {Q <: Quaternion} return QuatRotation(RotYXY(promote(α, β, γ)...)).q |> Q end -function euler_angles(x::T) where {T<:Union{QuatRotation,RotMatrix3,AngleAxis}} +function euler_angles(x::T) where {T <: Union{QuatRotation, RotMatrix3, AngleAxis}} Rotations.params(RotYXY(x)) end -euler_angles(q::Q) where {Q<:Quaternion} = euler_angles(QuatRotation(q, false)) +euler_angles(q::Q) where {Q <: Quaternion} = euler_angles(QuatRotation(q, false)) -function (::Type{R})(q::Quaternion) where {R<:RotMatrix3} +function (::Type{R})(q::Quaternion) where {R <: RotMatrix3} QuatRotation(q) |> R end -function (::Type{Q})(R::RotMatrix3) where {Q<:Quaternion} +function (::Type{Q})(R::RotMatrix3) where {Q <: Quaternion} return QuatRotation(R).q |> Q end -function CartesianToSphericalAngles(x::AbstractVector{T}) where {T<:Real} +function CartesianToSphericalAngles(x::AbstractVector{T}) where {T <: Real} length(x) == 3 || error("Spherical transform takes a 3D coordinate") # done in e3nn to remove NaNs diff --git a/src/o3/s2grid.jl b/src/o3/s2grid.jl index f654151..281f20a 100644 --- a/src/o3/s2grid.jl +++ b/src/o3/s2grid.jl @@ -6,13 +6,15 @@ using StaticArrays import Base: *, +, -, / -struct SphericalSignal{T<:AbstractArray} +struct SphericalSignal{T <: AbstractArray} grid_values::T quadrature::String p_val::Int p_arg::Int - function SphericalSignial(grid_values::T, quadrature::String; p_val::Int=1, p_arg::Int=-1, perform_checks::Bool=true) where T<:AbstractArray + function SphericalSignial( + grid_values::T, quadrature::String; p_val::Int = 1, p_arg::Int = -1, + perform_checks::Bool = true) where {T <: AbstractArray} if perform_checks if ndims(grid_values) < 2 throw(ArgumentError("Grid values should have at least 2 axes. Got grid_values of shape $(size(grid_values)).")) @@ -35,17 +37,19 @@ struct SphericalSignal{T<:AbstractArray} end end - function Base.show(io::IO, s::SphericalSignal) if ndims(s.grid_values) >= 2 - print(io, "SphericalSignal(shape=$(size(s.grid_values)), res_beta=$(s.res_beta), res_alpha=$(s.res_alpha), quadrature=$(s.quadrature), p_val=$(s.p_val), p_arg=$(s.p_arg))\n") + print(io, + "SphericalSignal(shape=$(size(s.grid_values)), res_beta=$(s.res_beta), res_alpha=$(s.res_alpha), quadrature=$(s.quadrature), p_val=$(s.p_val), p_arg=$(s.p_arg))\n") show(io, s.grid_values) else print(io, "SphericalSignal($(s.grid_values))") end end -*(s::SphericalSignal, scalar::Number) = SphericalSignal(s.grid_values * scalar, s.quadrature, p_val=s.p_val, p_arg=s.p_arg) +function *(s::SphericalSignal, scalar::Number) + SphericalSignal(s.grid_values * scalar, s.quadrature, p_val = s.p_val, p_arg = s.p_arg) +end *(scalar::Number, s::SphericalSignal) = s * scalar /(s::SphericalSignal, scalar::Number) = s * (1 / scalar) @@ -60,32 +64,32 @@ function +(s1::SphericalSignal, s2::SphericalSignal) throw(ArgumentError("Quadrature for both signals must be identical.")) end - SphericalSignal(s1.grid_values + s2.grid_values, s1.quadrature, p_val=s1.p_val, p_arg=s1.p_arg) + SphericalSignal( + s1.grid_values + s2.grid_values, s1.quadrature, p_val = s1.p_val, p_arg = s1.p_arg) end - -(s1::SphericalSignal, s2::SphericalSignal) = s1 + (-s2) --(s::SphericalSignal) = SphericalSignal(-s.grid_values, s.quadrature, p_val=s.p_val, p_arg=s.p_arg) +-(s::SphericalSignal) = SphericalSignal( + -s.grid_values, s.quadrature, p_val = s.p_val, p_arg = s.p_arg) # Properties Base.size(s::SphericalSignal) = size(s.grid_values) Base.eltype(s::SphericalSignal) = eltype(s.grid_values) Base.ndims(s::SphericalSignal) = ndims(s.grid_values) -res_beta(s::SphericalSignal) = size(s.grid_values, ndims(s.grid_values)-1) +res_beta(s::SphericalSignal) = size(s.grid_values, ndims(s.grid_values) - 1) res_alpha(s::SphericalSignal) = size(s.grid_values, ndims(s.grid_values)) grid_resolution(s::SphericalSignal) = (res_beta(s), res_alpha(s)) - function _s2grid(res_β::Int, res_α::Int, quadrature::String) - γ, qw = _quadrature_weights(res_β, quadrature=quadrature) - α = range(0, 2π, length=res_α) + γ, qw = _quadrature_weights(res_β, quadrature = quadrature) + α = range(0, 2π, length = res_α) return γ, α, qw end function _quadrature_weights(res_β::Int; quadrature::String) if quadrature == "soft" - i = 0:(res_β-1) + i = 0:(res_β - 1) β = (i .+ 0.5) / res_β * π y = -cos.(β)i_soft(res_β) elseif quadrature == "gausslegendre" @@ -97,4 +101,4 @@ function _quadrature_weights(res_β::Int; quadrature::String) return y, qw end -end \ No newline at end of file +end diff --git a/src/o3/spherical_harmonics.jl b/src/o3/spherical_harmonics.jl index 96e1ee2..3b07915 100644 --- a/src/o3/spherical_harmonics.jl +++ b/src/o3/spherical_harmonics.jl @@ -6,10 +6,10 @@ struct SphericalHarmonics _is_range_lmax::Bool function SphericalHarmonics( - irreps_out::Union{AbstractVector{Int},o3.Irreps}, - normalize::Bool, - normalization::String, - irreps_in, + irreps_out::Union{AbstractVector{Int}, o3.Irreps}, + normalize::Bool, + normalization::String, + irreps_in ) @assert normalization in ["integral", "component", "norm"] @@ -27,8 +27,8 @@ struct SphericalHarmonics if !(irreps_in in (o3.Irreps("1x1o"), o3.Irreps("1x1e"))) throw( ArgumentError( - "irreps_in for SphericalHarmonics must be either a vector (`1x1o`) or a pseudovector (`1x1e`), not `$(irreps_in)`", - ), + "irreps_in for SphericalHarmonics must be either a vector (`1x1o`) or a pseudovector (`1x1e`), not `$(irreps_in)`", + ), ) end input_p = irreps_in[1].irrep.p @@ -38,8 +38,8 @@ struct SphericalHarmonics if p != input_p^l throw( ArgumentError( - "irreps_out `$irreps_out` passed to SphericalHarmonics asked for an output of l = $l with parity p = $p, which is inconsistent with the input parity $input_p — the output parity should have been p = $(input_p^l)", - ), + "irreps_out `$irreps_out` passed to SphericalHarmonics asked for an output of l = $l with parity p = $p, which is inconsistent with the input parity $input_p — the output parity should have been p = $(input_p^l)", + ), ) end append!(repeat(l, mul)) @@ -60,10 +60,10 @@ struct SphericalHarmonics end function SphericalHarmonics( - irreps_out::String, - normalize::Bool, - normalization::String, - irreps_in::Any, + irreps_out::String, + normalize::Bool, + normalization::String, + irreps_in::Any ) irreps_out_rep = o3.Irreps(irreps_out) return SphericalHarmonics(irreps_out_rep, normalize, normalization, irreps_in) @@ -74,15 +74,13 @@ function (sh::SphericalHarmonics)(x::AbstractArray) foreach(normalize!, eachslice(x, dims = ndims(x), drop = false)) end - sh_matrix = - _spherical_harmonics(sh._lmax, eachslice(x, dims = ndims(x), drop = false)...) + sh_matrix = _spherical_harmonics( + sh._lmax, eachslice(x, dims = ndims(x), drop = false)...) if !sh._is_range_lmax return cat( - [ - eachslice(sh_matrix, dims = ndims(sh_matrix), drop = false)[l*l:(l+1)*(l+1)] - for l in sh._ls_list - ], - dims = ndims(x) + 1, + [eachslice(sh_matrix, dims = ndims(sh_matrix), drop = false)[(l * l):((l + 1) * (l + 1))] + for l in sh._ls_list], + dims = ndims(x) + 1 ) end if sh.normalization == "integral" @@ -95,10 +93,10 @@ function (sh::SphericalHarmonics)(x::AbstractArray) end function _spherical_harmonics( - lmax::Int, - x::AbstractArray, - y::AbstractArray, - z::AbstractArray, + lmax::Int, + x::AbstractArray, + y::AbstractArray, + z::AbstractArray ) sh_0_0 = fill(similar(x), 1.0) @@ -124,7 +122,7 @@ function _spherical_harmonics( if lmax == 2 return stack( [sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, sh_2_1, sh_2_2, sh_2_3, sh_2_4], - dims = ndims(sh_0_0) + 1, + dims = ndims(sh_0_0) + 1 ) end @@ -154,17 +152,16 @@ function _spherical_harmonics( sh_3_3, sh_3_4, sh_3_5, - sh_3_6, + sh_3_6 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end sh_4_0 = (3 / 4) * sqrt(2) * (sh_3_0 * z + sh_3_6 * x) - sh_4_1 = - (3 / 4) * sh_3_0 * y + - (3 / 8) * sqrt(6) * sh_3_1 * z + - (3 / 8) * sqrt(6) * sh_3_5 * x + sh_4_1 = (3 / 4) * sh_3_0 * y + + (3 / 8) * sqrt(6) * sh_3_1 * z + + (3 / 8) * sqrt(6) * sh_3_5 * x sh_4_2 = ( -3 / 56 * sqrt(14) * sh_3_0 * z + (3 / 14) * sqrt(21) * sh_3_1 * y + @@ -178,9 +175,8 @@ function _spherical_harmonics( (3 / 28) * sqrt(70) * sh_3_3 * x + (3 / 56) * sqrt(42) * sh_3_5 * x ) - sh_4_4 = - -3 / 28 * sqrt(42) * sh_3_2 * x + (3 / 7) * sqrt(7) * sh_3_3 * y - - 3 / 28 * sqrt(42) * sh_3_4 * z + sh_4_4 = -3 / 28 * sqrt(42) * sh_3_2 * x + (3 / 7) * sqrt(7) * sh_3_3 * y - + 3 / 28 * sqrt(42) * sh_3_4 * z sh_4_5 = ( -3 / 56 * sqrt(42) * sh_3_1 * x + (3 / 28) * sqrt(70) * sh_3_3 * z + @@ -191,10 +187,9 @@ function _spherical_harmonics( (3 / 56) * sqrt(210) * sh_3_4 * z + (3 / 14) * sqrt(21) * sh_3_5 * y - 3 / 56 * sqrt(14) * sh_3_6 * z ) - sh_4_7 = - -3 / 8 * sqrt(6) * sh_3_1 * x + - (3 / 8) * sqrt(6) * sh_3_5 * z + - (3 / 4) * sh_3_6 * y + sh_4_7 = -3 / 8 * sqrt(6) * sh_3_1 * x + + (3 / 8) * sqrt(6) * sh_3_5 * z + + (3 / 4) * sh_3_6 * y sh_4_8 = (3 / 4) * sqrt(2) * (-sh_3_0 * x + sh_3_6 * z) if lmax == 4 return stack( @@ -223,17 +218,16 @@ function _spherical_harmonics( sh_4_5, sh_4_6, sh_4_7, - sh_4_8, + sh_4_8 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end sh_5_0 = (1 / 10) * sqrt(110) * (sh_4_0 * z + sh_4_8 * x) - sh_5_1 = - (1 / 5) * sqrt(11) * sh_4_0 * y + - (1 / 5) * sqrt(22) * sh_4_1 * z + - (1 / 5) * sqrt(22) * sh_4_7 * x + sh_5_1 = (1 / 5) * sqrt(11) * sh_4_0 * y + + (1 / 5) * sqrt(22) * sh_4_1 * z + + (1 / 5) * sqrt(22) * sh_4_7 * x sh_5_2 = ( -1 / 30 * sqrt(22) * sh_4_0 * z + (4 / 15) * sqrt(11) * sh_4_1 * y + @@ -273,10 +267,9 @@ function _spherical_harmonics( (1 / 15) * sqrt(154) * sh_4_6 * z + (4 / 15) * sqrt(11) * sh_4_7 * y - 1 / 30 * sqrt(22) * sh_4_8 * z ) - sh_5_9 = - -1 / 5 * sqrt(22) * sh_4_1 * x + - (1 / 5) * sqrt(22) * sh_4_7 * z + - (1 / 5) * sqrt(11) * sh_4_8 * y + sh_5_9 = -1 / 5 * sqrt(22) * sh_4_1 * x + + (1 / 5) * sqrt(22) * sh_4_7 * z + + (1 / 5) * sqrt(11) * sh_4_8 * y sh_5_10 = (1 / 10) * sqrt(110) * (-sh_4_0 * x + sh_4_8 * z) if lmax == 5 return stack( @@ -316,9 +309,9 @@ function _spherical_harmonics( sh_5_7, sh_5_8, sh_5_9, - sh_5_10, + sh_5_10 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end @@ -437,17 +430,16 @@ function _spherical_harmonics( sh_6_9, sh_6_10, sh_6_11, - sh_6_12, + sh_6_12 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end sh_7_0 = (1 / 14) * sqrt(210) * (sh_6_0 * z + sh_6_12 * x) - sh_7_1 = - (1 / 7) * sqrt(15) * sh_6_0 * y + - (3 / 7) * sqrt(5) * sh_6_1 * z + - (3 / 7) * sqrt(5) * sh_6_11 * x + sh_7_1 = (1 / 7) * sqrt(15) * sh_6_0 * y + + (3 / 7) * sqrt(5) * sh_6_1 * z + + (3 / 7) * sqrt(5) * sh_6_11 * x sh_7_2 = ( -1 / 182 * sqrt(390) * sh_6_0 * z + (6 / 91) * sqrt(130) * sh_6_1 * y + @@ -512,10 +504,9 @@ function _spherical_harmonics( (6 / 91) * sqrt(130) * sh_6_11 * y - 1 / 182 * sqrt(390) * sh_6_12 * z - 3 / 91 * sqrt(715) * sh_6_2 * x ) - sh_7_13 = - -3 / 7 * sqrt(5) * sh_6_1 * x + - (3 / 7) * sqrt(5) * sh_6_11 * z + - (1 / 7) * sqrt(15) * sh_6_12 * y + sh_7_13 = -3 / 7 * sqrt(5) * sh_6_1 * x + + (3 / 7) * sqrt(5) * sh_6_11 * z + + (1 / 7) * sqrt(15) * sh_6_12 * y sh_7_14 = (1 / 14) * sqrt(210) * (-sh_6_0 * x + sh_6_12 * z) if lmax == 7 return stack( @@ -583,9 +574,9 @@ function _spherical_harmonics( sh_7_11, sh_7_12, sh_7_13, - sh_7_14, + sh_7_14 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end @@ -660,7 +651,8 @@ function _spherical_harmonics( (1 / 48) * sqrt(2) * ( - sqrt(1122) * sh_7_10 * y - sqrt(102) * sh_7_11 * z - sqrt(102) * sh_7_3 * x - sqrt(561) * sh_7_5 * x + sqrt(561) * sh_7_9 * z + sqrt(1122) * sh_7_10 * y - sqrt(102) * sh_7_11 * z - sqrt(102) * sh_7_3 * x - + sqrt(561) * sh_7_5 * x + sqrt(561) * sh_7_9 * z ) ) sh_8_12 = ( @@ -672,7 +664,8 @@ function _spherical_harmonics( (1 / 80) * sqrt(2) * ( - -sqrt(85) * sh_7_1 * x + sqrt(2210) * sh_7_11 * z + sqrt(2210) * sh_7_12 * y - sqrt(85) * sh_7_13 * z - sqrt(2210) * sh_7_3 * x + -sqrt(85) * sh_7_1 * x + sqrt(2210) * sh_7_11 * z + sqrt(2210) * sh_7_12 * y - + sqrt(85) * sh_7_13 * z - sqrt(2210) * sh_7_3 * x ) ) sh_8_14 = ( @@ -770,9 +763,9 @@ function _spherical_harmonics( sh_8_13, sh_8_14, sh_8_15, - sh_8_16, + sh_8_16 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end @@ -963,9 +956,9 @@ function _spherical_harmonics( sh_9_15, sh_9_16, sh_9_17, - sh_9_18, + sh_9_18 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end @@ -1198,9 +1191,9 @@ function _spherical_harmonics( sh_10_17, sh_10_18, sh_10_19, - sh_10_20, + sh_10_20 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end @@ -1466,17 +1459,16 @@ function _spherical_harmonics( sh_11_19, sh_11_20, sh_11_21, - sh_11_22, + sh_11_22 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end sh_12_0 = (5 / 12) * sqrt(6) * (sh_11_0 * z + sh_11_22 * x) - sh_12_1 = - (5 / 12) * sh_11_0 * y + - (5 / 24) * sqrt(22) * sh_11_1 * z + - (5 / 24) * sqrt(22) * sh_11_21 * x + sh_12_1 = (5 / 12) * sh_11_0 * y + + (5 / 24) * sqrt(22) * sh_11_1 * z + + (5 / 24) * sqrt(22) * sh_11_21 * x sh_12_2 = ( -5 / 552 * sqrt(46) * sh_11_0 * z + (5 / 138) * sqrt(253) * sh_11_1 * y + @@ -1600,10 +1592,9 @@ function _spherical_harmonics( (5 / 552) * sqrt(10626) * sh_11_20 * z + (5 / 138) * sqrt(253) * sh_11_21 * y - 5 / 552 * sqrt(46) * sh_11_22 * z ) - sh_12_23 = - -5 / 24 * sqrt(22) * sh_11_1 * x + - (5 / 24) * sqrt(22) * sh_11_21 * z + - (5 / 12) * sh_11_22 * y + sh_12_23 = -5 / 24 * sqrt(22) * sh_11_1 * x + + (5 / 24) * sqrt(22) * sh_11_21 * z + + (5 / 12) * sh_11_22 * y sh_12_24 = (5 / 12) * sqrt(6) * (-sh_11_0 * x + sh_11_22 * z) return stack( @@ -1776,8 +1767,8 @@ function _spherical_harmonics( sh_12_21, sh_12_22, sh_12_23, - sh_12_24, + sh_12_24 ], - ndims(sh_0_0) + 1, + ndims(sh_0_0) + 1 ) end diff --git a/test/o3/irreps.jl b/test/o3/irreps.jl index f629263..4e81636 100644 --- a/test/o3/irreps.jl +++ b/test/o3/irreps.jl @@ -2,7 +2,6 @@ using e3nn.o3 using Test @testset "Irreps" begin - @testset "creation" begin Irrep(2, 1) ir = Irrep("3o") @@ -66,7 +65,7 @@ using Test repeat([2], 12), repeat([1], 1), repeat([2], 2), - repeat([4], 12), + repeat([4], 12) ) @test lmax(irreps) == 4 @test num_irreps(irreps) == 4 + 6 + 12 + 1 + 2 + 12 @@ -84,5 +83,4 @@ using Test @test_throws ArgumentError Irreps("-1x1e") @test_throws ArgumentError Irreps("bla") end - end