diff --git a/src/main.rs b/src/main.rs index 7c869b7..aed8837 100644 --- a/src/main.rs +++ b/src/main.rs @@ -146,7 +146,8 @@ fn restraint_bodies(input_file: &str) -> Result<(), Box> { }; // Find in-contiguous chains - let gaps = structure::find_structural_gaps(&pdb); + let bodies = structure::find_bodies(&pdb); + let gaps = structure::create_iter_body_gaps(&bodies); // Create the interactors let mut interactors: Vec = Vec::new(); @@ -162,8 +163,8 @@ fn restraint_bodies(input_file: &str) -> Result<(), Box> { interactor_i.set_chain(g.chain.as_str()); interactor_i.set_active(vec![g.res_i as i16]); - interactor_i.set_active_atoms(vec![g.atom.clone()]); - interactor_i.set_passive_atoms(vec![g.atom.clone()]); + interactor_i.set_active_atoms(vec![g.atom_i.clone()]); + interactor_i.set_passive_atoms(vec![g.atom_j.clone()]); interactor_i.set_target_distance(g.distance); interactor_i.set_lower_margin(0.0); interactor_i.set_upper_margin(0.0); diff --git a/src/structure.rs b/src/structure.rs index da7ef25..a0d530c 100644 --- a/src/structure.rs +++ b/src/structure.rs @@ -2,6 +2,9 @@ use std::collections::{HashMap, HashSet}; use kd_tree::KdTree; use pdbtbx::Residue; +use rand::rngs::StdRng; +use rand::seq::SliceRandom; +use rand::SeedableRng; pub fn neighbor_search( pdb: pdbtbx::PDB, @@ -138,15 +141,15 @@ pub fn get_chains_in_contact(pdb: &pdbtbx::PDB, cutoff: f64) -> HashSet<(String, #[derive(Debug)] pub struct Gap { pub chain: String, - pub atom: String, + pub atom_i: String, + pub atom_j: String, pub res_i: isize, pub res_j: isize, pub distance: f64, } -pub fn find_structural_gaps(pdb: &pdbtbx::PDB) -> Vec { +pub fn find_bodies(pdb: &pdbtbx::PDB) -> HashMap> { // Check if the distance of a given atom to its next one is higher than 4A - let mut gaps: Vec = Vec::new(); // Get only the `CA` atoms let mut ca_atoms: Vec<(&str, isize, &pdbtbx::Atom)> = Vec::new(); @@ -159,25 +162,56 @@ pub fn find_structural_gaps(pdb: &pdbtbx::PDB) -> Vec { }); }); }); + let mut bodies: HashMap> = HashMap::new(); + let mut body_id = 0; for (i, j) in ca_atoms.iter().zip(ca_atoms.iter().skip(1)) { let (chain_i, res_i, atom_i) = i; - let (chain_j, res_j, atom_j) = j; + let (chain_j, _res_j, atom_j) = j; if chain_i != chain_j { continue; } let distance = atom_i.distance(atom_j); if distance > 4.0 { - gaps.push(Gap { - chain: chain_i.to_string(), - atom: atom_i.name().to_string(), - res_i: *res_i, - res_j: *res_j, - distance, - }); + body_id += 1; + } + bodies + .entry(body_id) + .or_default() + .push((*res_i, chain_i, atom_i)); + } + + bodies +} + +pub fn create_iter_body_gaps( + bodies: &HashMap>, +) -> Vec { + let mut rng = StdRng::seed_from_u64(42); + let mut pairs: Vec = Vec::new(); + let body_ids: Vec = bodies.keys().cloned().collect(); + for (i, &body_id1) in body_ids.iter().enumerate() { + for &body_id2 in body_ids[i + 1..].iter() { + if let (Some(atoms1), Some(atoms2)) = (bodies.get(&body_id1), bodies.get(&body_id2)) { + if atoms1.len() >= 2 && atoms2.len() >= 2 { + let selected1: Vec<_> = atoms1.choose_multiple(&mut rng, 2).cloned().collect(); + let selected2: Vec<_> = atoms2.choose_multiple(&mut rng, 2).cloned().collect(); + + for i in 0..2 { + pairs.push(Gap { + chain: selected1[i].1.to_string(), + atom_i: selected1[i].2.name().to_string(), + atom_j: selected2[i].2.name().to_string(), + res_i: selected1[i].0, + res_j: selected2[i].0, + distance: selected1[i].2.distance(selected2[i].2), + }); + } + } + } } } - gaps + pairs } #[cfg(test)]