Skip to content

Commit

Permalink
Switch to PythonCall
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonOresten committed Nov 24, 2024
1 parent a60bf77 commit 6602ab8
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
/Manifest.toml
.CondaPkg/
3 changes: 3 additions & 0 deletions CondaPkg.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

[pip.deps]
colabdesign = "@ https://github.com/sokrypton/ColabDesign/archive/refs/tags/v1.1.2.tar.gz"
11 changes: 6 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
name = "ColabMPNN"
uuid = "434c8270-3839-45f3-8f5c-492a580b2514"
authors = ["anton083 <anton.oresten42@gmail.com> and contributors"]
version = "0.0.1"
authors = ["Anton Oresten <anton.oresten42@gmail.com> and contributors"]
version = "0.0.2"

[deps]
Conda = "8f4d0f93-b110-5947-807f-2305c1781a2d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"

[compat]
CondaPkg = "0.2.24"
PythonCall = "0.9.23"
julia = "1"

[extras]
Expand Down
34 changes: 11 additions & 23 deletions src/ColabMPNN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,11 @@ export mpnn
export Samples, Score
export mk_mpnn_model, prep_inputs, sample, sample_parallel, score, get_unconditional_logits

import Pkg
using Conda, PyCall
using PythonCall

const mpnn = PyNULL()
const mpnn = PythonCall.pynew()

function __init__()
ENV["PYTHON"] = ""
Pkg.build("PyCall")

if !haskey(Conda._installed_packages_dict(), "colabdesign")
Conda.pip_interop(true)
Conda.pip("install", "git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
Conda.add("colabdesign")
end

copy!(mpnn, pyimport_conda("colabdesign.mpnn", "colabdesign"))
end
__init__() = PythonCall.pycopy!(mpnn, pyimport("colabdesign.mpnn"))

struct Samples
seq::Vector{String}
Expand All @@ -30,8 +18,8 @@ struct Samples
decoding_order::Array{Int32, 3}
S::Array{Float32, 3}

function Samples(samples::Dict{Any, Any})
new([samples[string(f)] for f in fieldnames(Samples)]...)
function Samples(samples::PyDict)
new([pyconvert(Any, samples[Py(string(f))]) for f in fieldnames(Samples)]...)
end
end

Expand All @@ -42,8 +30,8 @@ struct Score
decoding_order::Array{Int32, 1}
S::Array{Float32, 2}

function Score(scores::Dict{Any, Any})
new([scores[string(f)] for f in fieldnames(Score)]...)
function Score(scores::PyDict)
new([pyconvert(Any, scores[pystr(string(f))]) for f in fieldnames(Score)]...)
end
end

Expand Down Expand Up @@ -83,7 +71,7 @@ prep_inputs(mpnn_model, args...; kwargs...) = mpnn_model.prep_inputs(args...; kw
rescore=false,
)
"""
sample(mpnn_model, args...; kwargs...) = Samples(mpnn_model.sample(args...; kwargs...))
sample(mpnn_model, args...; kwargs...) = Samples(PyDict(mpnn_model.sample(args...; kwargs...)))

"""
sample_parallel(mpnn_model,
Expand All @@ -92,18 +80,18 @@ sample(mpnn_model, args...; kwargs...) = Samples(mpnn_model.sample(args...; kwar
rescore=false,
)
"""
sample_parallel(mpnn_model, args...; kwargs...) = Samples(mpnn_model.sample_parallel(args...; kwargs...))
sample_parallel(mpnn_model, args...; kwargs...) = Samples(PyDict(mpnn_model.sample_parallel(args...; kwargs...)))

"""
score(mpnn_model,
seq=nothing,
)
"""
score(mpnn_model, args...; kwargs...) = Score(mpnn_model.score(args...; kwargs...))
score(mpnn_model, args...; kwargs...) = Score(PyDict(mpnn_model.score(args...; kwargs...)))

"""
get_unconditional_logits(mpnn_model)
"""
get_unconditional_logits(mpnn_model) = mpnn_model.get_unconditional_logits()
get_unconditional_logits(mpnn_model) = pyconvert(Array, mpnn_model.get_unconditional_logits())

end

0 comments on commit 6602ab8

Please sign in to comment.