Skip to content

Commit

Permalink
Merge pull request #17 from MurrellGroup/tweaked-defaults
Browse files Browse the repository at this point in the history
Minor adjustments to defaults, and adding some options.
  • Loading branch information
AntonOresten authored Mar 31, 2024
2 parents 4da3dbc + d84c174 commit c0c0fc7
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
31 changes: 22 additions & 9 deletions src/animate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,42 @@ using ColorSchemes
"""
function animate_attention(
chain::Backboner.Protein.Chain, attention::AbstractArray{<:Real, 3};
azimuth_start = 1, azimuth_end = 0, output_file::String = "attention.mp4",
ribbon_colorscheme = ColorSchemes.jet,
attention_colorscheme = ColorSchemes.hawaii,
end_padding = 3,
azimuth_start = 1, azimuth_end = -6, output_file::String = "attention.mp4",
ribbon_colorscheme = ColorSchemes.glasgow,
attention_colorscheme = ColorSchemes.hsv,
end_padding = 3, grow_limits = false, from_centroid = true, frames_per_residue::Int = 10, framerate::Int = 30
)
points = Backboner.Protein.alphacarbon_coords(chain)

if from_centroid
points = Backboner.Frames(chain.backbone,Backboner.Protein.STANDARD_TRIANGLE_ANGSTROM).locations
else
points = Backboner.Protein.carbonyl_coords(chain)
end

attention = PointAttention(points, attention)

fig = Figure();
ax = Axis3(fig[1, 1], protrusions=(20, 20, 10, 10), perspectiveness=0.2, aspect=:data);
hidespines!(ax)
#hidedecorations!(ax)
hidedecorations!(ax)

if !grow_limits
plot_limits = extrema(points, dims=2)
xlims!(ax, plot_limits[1])
ylims!(ax, plot_limits[2])
zlims!(ax, plot_limits[3])
end

ax.azimuth[] = azimuth_start

# Aggregate all the plots from each frame, such that they can be deleted.
# Ribbon plots are actually made up of multiple plots, so each subplot gets added to the list.
# A possible optimization is only deleting the last segment of the ribbon plot.
# A possible optimization is only deleting changed segments of the ribbon plot (e.g. last segment, and coils with emerging beta sheets)
plots = Vector{AbstractPlot}()

n = length(chain)
k = 5 # 5 frames per residue
framerate = 30
k = frames_per_residue # 5 frames per residue

frame_indices = 2:1/k:n+end_padding*framerate/k
azimuth(t) = (t / (last(frame_indices) - first(frame_indices))) * (azimuth_end - azimuth_start) + azimuth_start
record(fig, output_file, frame_indices, framerate=framerate) do i
Expand Down
10 changes: 5 additions & 5 deletions src/attention/attention.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@ end

function draw_lines_from_point!(container,
point::P, other_points::AbstractVector{P}, linewidths::AbstractVector{<:Real};
color = RGB(1, 1, 1), linewidth_factor = 10.0, plots = nothing, kwargs...
color = RGB(1, 1, 1), linewidth_factor = 3.0, plots = nothing, kwargs...
) where P
length(other_points) == length(linewidths) || throw(ArgumentError("The number of linewidths must match the number of other points."))
xs, ys, zs = [reduce(vcat, ([point[i], other_point[i]] for other_point in other_points)) for i in 1:3]
p = linesegments!(container, xs, ys, zs; linewidth=linewidths .* linewidth_factor, color, transparency=true)
p = linesegments!(container, xs, ys, zs; linewidth=linewidths .* linewidth_factor, color, transparency=true, kwargs...)
!isnothing(plots) && push!(plots, p)
end

Expand All @@ -41,12 +41,12 @@ Take H points, and an LxH attention intensity matrix.
For each column slice of the attention matrix, draw lines to the corresponding points with the intensities in the vector that exceed a certain threshold.
The thickness of the lines should be proportional to the value in the attention matrix.
"""
function draw_attention_slice!(container,
function draw_attention!(container,
point::P, other_points::AbstractVector{P}, intensity_matrix::AbstractMatrix{<:Real};
threshold::Real=1.0, colors=fill(RGB(1, 1, 1), size(intensity_matrix, 1)), kwargs...
) where P
h, l = size(intensity_matrix)
println(h, " ", l, " ", length(other_points))
#println(h, " ", l, " ", length(other_points))
l == length(other_points) || throw(ArgumentError("The number of points must match the number of rows in the attention matrix."))
for i in 1:h
intensity_vector = @view intensity_matrix[i, :]
Expand All @@ -57,7 +57,7 @@ function draw_attention_slice!(container,
end

function draw_attention_slice!(container, i::Int, attention::PointAttention; kwargs...)
draw_attention_slice!(container, eachcol(attention.points)[i], eachcol(attention.points)[1:i], @view(attention.intensities[:, i, 1:i]); kwargs...)
draw_attention!(container, eachcol(attention.points)[i], eachcol(attention.points)[1:i], @view(attention.intensities[:, i, 1:i]); kwargs...)
return nothing
end

Expand Down

0 comments on commit c0c0fc7

Please sign in to comment.