diff --git a/README.md b/README.md index ad1a9ad5..09aa90dd 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,34 @@ When `purify` runs a directory will be created, and the output images will be saved and time-stamped. Additionally, a config file with the settings used will be saved and time-stamped, helping for reproducibility and book-keeping. +Uncertainty Quantification +-------------------------- + +Bayesian hypothesis testing may be performed with the `purify_UQ` application, which can be located in the `build` directory. + +This application takes a config yaml file with the following parameters: +- `confidence_interval` or `alpha`. (alpha = 1 - confidence interval.) +- `measurements_path`: path to measurements data (.vis file) +- `reference_image_path`: path to reference image i.e. output from purify run. (.fits file) +- `surrogate_image_path`: path to surrogate image i.e. doctored image with blurring or structural change that you want to test. (.fits file) +- `sigma`: standard deviation for Gaussian likelihood. +- `gamma`: multiplicative factor for prior. +- `purify_config`: path to purify config used to generate the reference image. **This should be used if available in order to ensure consistency of things like measurement and wavelet operators.** + +You can then run the uncertainty quantification with the command: +``` +purify_UQ +``` + +The application will report the value of the objective function for each image, the threshold value calculated from the reference image, and whether the surrogate image is ruled out or not. + +Presently this is designed to work for the unconstrained problem where: +- The negative log-likelihood is a scaled squared L2-norm i.e. $\frac{1}{2 \sigma^2} \sum_i |y_i - (\Phi x)_i|^2$ for some data $y$, image $x$, and measurement operator $\Phi$. (Equivalent to independent multivariate Gaussian likelihood.) +- The negative log prior is a scaled L1-norm in _wavelet space_ i.e. 
$\gamma \sum_i |(\Psi^\dag x)_i|$ for some image $x$ and wavelet operator $\Psi$. +- The objective function is the sum of these two terms. + +Docker +------- ## Debugging the CI workflow with tmate The CI workflow has a manual dispatch trigger which allows you to log into the job while it's running. You can trigger it in diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 520cd28c..8261f4b1 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -45,6 +45,8 @@ if (docs) add_subdirectory(docs) endif() +add_subdirectory(uncertainty_quantification) + add_executable(purify main.cc) set_target_properties(purify PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) target_link_libraries(purify libpurify ${sopt_LIBRARIES}) diff --git a/cpp/benchmarks/algorithms.cc b/cpp/benchmarks/algorithms.cc index 4c20ff44..39f9121b 100644 --- a/cpp/benchmarks/algorithms.cc +++ b/cpp/benchmarks/algorithms.cc @@ -115,7 +115,7 @@ BENCHMARK_DEFINE_F(AlgoFixture, ForwardBackwardOnnx)(benchmark::State &state) { m_fb = factory::fb_factory>( factory::algo_distribution::serial, m_measurements_transform, wavelets, m_uv_data, m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, false, 1e-3, - 1e-2, 50, 1.0, tf_model_path, factory::g_proximal_type::TFGProximal); + 1e-2, 50, 1.0, tf_model_path, factory::nondiff_func_type::Denoiser); while (state.KeepRunning()) { auto start = std::chrono::high_resolution_clock::now(); diff --git a/cpp/benchmarks/algorithms_mpi.cc b/cpp/benchmarks/algorithms_mpi.cc index cf97b72d..11f4d19d 100644 --- a/cpp/benchmarks/algorithms_mpi.cc +++ b/cpp/benchmarks/algorithms_mpi.cc @@ -188,7 +188,7 @@ BENCHMARK_DEFINE_F(AlgoFixtureMPI, FbOnnxDistributeImage)(benchmark::State &stat m_fb = factory::fb_factory>( factory::algo_distribution::mpi_serial, m_measurements_distribute_image, wavelets, m_uv_data, m_sigma, beta, gamma, m_imsizey, m_imsizex, m_sara.size(), state.range(3) + 1, true, true, - false, 1e-3, 1e-2, 50, 1.0, 
tf_model_path, factory::g_proximal_type::TFGProximal); + false, 1e-3, 1e-2, 50, 1.0, tf_model_path, factory::nondiff_func_type::Denoiser); // Benchmark the application of the algorithm while (state.KeepRunning()) { diff --git a/cpp/main.cc b/cpp/main.cc index 97d74e49..420bd5e4 100644 --- a/cpp/main.cc +++ b/cpp/main.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include #include "purify/algorithm_factory.h" #include "purify/cimg.h" @@ -11,6 +12,7 @@ #include "purify/measurement_operator_factory.h" #include "purify/pfitsio.h" #include "purify/read_measurements.h" +#include "purify/setup_utils.h" #include "purify/update_factory.h" #include "purify/wavelet_operator_factory.h" #include "purify/wide_field_utilities.h" @@ -20,6 +22,11 @@ #include #include #include + +#ifdef PURIFY_ONNXRT +#include +#endif + using namespace purify; int main(int argc, const char **argv) { @@ -40,386 +47,59 @@ int main(int argc, const char **argv) { " but the configuration file expects version " + params.version() + ". Please updated the config version manually to be compatable with the new version."); - factory::distributed_measurement_operator mop_algo = - (not params.gpu()) ? factory::distributed_measurement_operator::serial - : factory::distributed_measurement_operator::gpu_serial; - factory::distributed_wavelet_operator wop_algo = factory::distributed_wavelet_operator::serial; - bool using_mpi = false; - std::vector image_index = std::vector(); - std::vector w_stacks = std::vector(); - #ifdef PURIFY_MPI auto const session = sopt::mpi::init(argc, argv); #endif - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - mop_algo = (not params.gpu()) - ? 
factory::distributed_measurement_operator::mpi_distribute_image - : factory::distributed_measurement_operator::gpu_mpi_distribute_image; - if (params.mpi_all_to_all()) - mop_algo = (not params.gpu()) - ? factory::distributed_measurement_operator::mpi_distribute_all_to_all - : factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all; - wop_algo = factory::distributed_wavelet_operator::mpi_sara; - if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) { - mop_algo = (not params.gpu()) ? factory::distributed_measurement_operator::serial - : factory::distributed_measurement_operator::serial; - wop_algo = factory::distributed_wavelet_operator::serial; - } - using_mpi = true; - } + const auto [mop_algo, wop_algo, using_mpi] = selectOperators(params); sopt::logging::set_level(params.logging()); purify::logging::set_level(params.logging()); // Read or generate input data - utilities::vis_params uv_data; - t_real sigma; - Vector measurement_op_eigen_vector = - Vector::Ones(params.width() * params.height()); - // read eigen vector for power method - if (params.eigenvector_real() != "" and params.eigenvector_imag() != "") { - t_int rows; - t_int cols; - t_int pols; - t_int chans; - Vector temp_real; - Vector temp_imag; - pfitsio::read3d(params.eigenvector_real(), temp_real, rows, cols, chans, pols); - if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) - throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); - pfitsio::read3d(params.eigenvector_imag(), temp_imag, rows, cols, chans, pols); - if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) - throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); - measurement_op_eigen_vector.real() = temp_real; - measurement_op_eigen_vector.imag() = temp_imag; - } - if (params.source() == purify::utilities::vis_source::measurements) { - PURIFY_HIGH_LOG("Input visibilities are from 
files:"); - for (size_t i = 0; i < params.measurements().size(); i++) - PURIFY_HIGH_LOG("{}", params.measurements()[i]); - sigma = params.measurements_sigma(); -#ifdef PURIFY_MPI - if (using_mpi) { - auto const world = sopt::mpi::Communicator::World(); - uv_data = read_measurements::read_measurements(params.measurements(), world, - distribute::plan::radial, true, stokes::I, - params.measurements_units()); - const t_real norm = - std::sqrt(world.all_sum_all( - (uv_data.weights.real().array() * uv_data.weights.real().array()).sum()) / - world.all_sum_all(uv_data.size())); - // normalise weights - uv_data.weights = uv_data.weights / norm; - // using no weights for now - // uv_data.weights = Vector::Ones(uv_data.size()); - } else -#endif - { - uv_data = read_measurements::read_measurements(params.measurements(), true, stokes::I, - params.measurements_units()); - const t_real norm = std::sqrt( - (uv_data.weights.real().array() * uv_data.weights.real().array()).sum() / uv_data.size()); - // normalising weights - uv_data.weights = uv_data.weights / norm; - // using no weights for now - // uv_data.weights = Vector::Ones(uv_data.size()); - } - if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); -#ifdef PURIFY_MPI - if (params.mpi_wstacking() and - (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or - mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; - const t_real du = - widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); - std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( - uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); - } else if (params.mpi_wstacking()) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return 
std::abs(x * x); }; - uv_data = utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); - } -#endif - } else if (params.source() == purify::utilities::vis_source::simulation) { - PURIFY_HIGH_LOG("Input visibilities will be generated for random coverage."); - // TODO: move this to function (in utilities.h?) - auto image = pfitsio::read2d(params.skymodel()); - if (params.height() != image.rows() || params.width() != image.cols()) - throw std::runtime_error("Input image size (" + std::to_string(image.cols()) + "x" + - std::to_string(image.rows()) + ") is not equal to the input one (" + - std::to_string(params.width()) + "x" + - std::to_string(params.height()) + ")."); - t_int const number_of_pixels = image.size(); - t_int const number_of_vis = params.number_of_measurements(); - t_real const sigma_m = constant::pi / 4; - const t_real rms_w = params.w_rms(); // lambda - if (params.measurements().at(0) == "") { - uv_data = utilities::random_sample_density(number_of_vis, 0, sigma_m, rms_w); - uv_data.units = utilities::vis_units::radians; - uv_data.weights = Vector::Ones(uv_data.size()); - } else { -#ifdef PURIFY_MPI - if (using_mpi) { - auto const world = sopt::mpi::Communicator::World(); - uv_data = read_measurements::read_measurements(params.measurements(), world, - distribute::plan::radial, true, stokes::I, - params.measurements_units()); - } else -#endif - uv_data = read_measurements::read_measurements(params.measurements(), true, stokes::I, - params.measurements_units()); - uv_data.weights = Vector::Ones(uv_data.weights.size()); - } - if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); -#ifdef PURIFY_MPI - if (params.mpi_wstacking() and - (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or - mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; - 
const t_real du = - widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); - std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( - uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); - } else if (params.mpi_wstacking()) { - auto const world = sopt::mpi::Communicator::World(); - const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; - uv_data = utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); - } -#endif - std::shared_ptr>> sky_measurements; - if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and - mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) - sky_measurements = - (not params.wprojection()) - ? factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), - params.mpi_wstacking()) - : factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - else - sky_measurements = - (not params.wprojection()) - ? 
factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), - params.mpi_wstacking()) - : factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - uv_data.vis = - ((*sky_measurements) * Vector::Map(image.data(), image.size())).eval().array(); - sigma = utilities::SNR_to_standard_deviation(uv_data.vis, params.signal_to_noise()); - uv_data.vis = utilities::add_noise(uv_data.vis, 0., sigma); - } - t_real ideal_cell_x = widefield::estimate_cell_size(uv_data.u.cwiseAbs().maxCoeff(), - params.width(), params.oversampling()); - t_real ideal_cell_y = widefield::estimate_cell_size(uv_data.v.cwiseAbs().maxCoeff(), - params.height(), params.oversampling()); -#ifdef PURIFY_MPI - if (using_mpi) { - auto const comm = sopt::mpi::Communicator::World(); - ideal_cell_x = widefield::estimate_cell_size( - comm.all_reduce(uv_data.u.cwiseAbs().maxCoeff(), MPI_MAX), params.width(), - params.oversampling()); - ideal_cell_y = widefield::estimate_cell_size( - comm.all_reduce(uv_data.v.cwiseAbs().maxCoeff(), MPI_MAX), params.height(), - params.oversampling()); - } -#endif - PURIFY_HIGH_LOG( - "Using cell size {}\" x {}\", recommended from the uv coverage and field of view is " - "{}\"x{}\".", - params.cellsizey(), params.cellsizex(), ideal_cell_y, ideal_cell_x); - PURIFY_HIGH_LOG("The equivalent miriad cell size is: {}\" x {}\"", - widefield::equivalent_miriad_cell_size(params.cellsizex(), params.width(), - params.oversampling()), - widefield::equivalent_miriad_cell_size(params.cellsizey(), 
params.height(), - params.oversampling())); + auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = + getInputData(params, mop_algo, wop_algo, using_mpi); + // create measurement operator - std::shared_ptr>> measurements_transform; - if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and - mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) - measurements_transform = - (not params.wprojection()) - ? factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), - params.mpi_wstacking()) - : factory::measurement_operator_factory>( - mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), - params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - else - measurements_transform = - (not params.wprojection()) - ? factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), - params.mpi_wstacking()) - : factory::all_to_all_measurement_operator_factory>( - mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), - params.cellsizey(), params.cellsizex(), params.oversampling(), - kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), - params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); - t_real operator_norm = 1.; -#ifdef PURIFY_MPI - if (using_mpi) { - auto const comm = sopt::mpi::Communicator::World(); - auto power_method_result = - (params.mpiAlgorithm() != factory::algo_distribution::mpi_random_updates) - ? 
sopt::algorithm::power_method>( - *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), - comm.broadcast(measurement_op_eigen_vector).eval()) - : sopt::algorithm::all_sum_all_power_method>( - comm, *measurements_transform, params.powMethod_iter(), - params.powMethod_tolerance(), comm.broadcast(measurement_op_eigen_vector).eval()); - measurement_op_eigen_vector = std::get<1>(power_method_result); - operator_norm = std::get<0>(power_method_result); - } else -#endif - { - auto power_method_result = sopt::algorithm::power_method>( - *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), - measurement_op_eigen_vector); - measurement_op_eigen_vector = std::get<1>(power_method_result); - operator_norm = std::get<0>(power_method_result); - } + auto [measurements_transform, operator_norm] = + createMeasurementOperator(params, mop_algo, wop_algo, using_mpi, image_index, w_stacks, + uv_data, measurement_op_eigen_vector); + + // create wavelet operator + const waveletInfo wavelets = createWaveletOperator(params, wop_algo); + PURIFY_LOW_LOG("Value of operator norm is {}", operator_norm); t_real const flux_scale = 1.; uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; // Save some things before applying the algorithm // the config yaml file - this also generates the output directory and the timestamp - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); - if (world.is_root()) -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - params.writeOutput(); - } else { - params.writeOutput(); - } - const std::string out_dir = params.output_prefix() + "/output_" + params.timestamp(); + initOutDirectoryWithConfig(params); + // Creating header for saving output images during iterations - const pfitsio::header_params update_header_sol = - pfitsio::header_params(out_dir + 
"/sol_update.fits", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, - params.measurements_polarization(), params.cellsizex(), - params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); - const pfitsio::header_params update_header_res = - pfitsio::header_params(out_dir + "/res_update.fits", "Jy/Beam", 1, uv_data.ra, uv_data.dec, - params.measurements_polarization(), params.cellsizex(), - params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); - const pfitsio::header_params def_header = pfitsio::header_params( - "", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, params.measurements_polarization(), - params.cellsizex(), params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); + const auto [update_header_sol, update_header_res, def_header] = genHeaders(params, uv_data); + // the eigenvector - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); - if (world.is_root()) -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - { - pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), - out_dir + "/eigenvector_real.fits", "pix", true); - pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), - out_dir + "/eigenvector_imag.fits", "pix", true); - } - } else { - pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), - out_dir + "/eigenvector_real.fits", "pix", true); - pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), - out_dir + "/eigenvector_imag.fits", "pix", true); - } + saveMeasurementEigenVector(params, measurement_op_eigen_vector); + // the psf - pfitsio::header_params psf_header = def_header; - psf_header.fits_name = out_dir + "/psf.fits"; - psf_header.pix_units = "Jy/Pixel"; - const Vector psf = measurements_transform->adjoint() * (uv_data.weights / flux_scale); - const Image psf_image = - 
Image::Map(psf.data(), params.height(), params.width()).real(); - PURIFY_HIGH_LOG( - "Peak of PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", - psf_image(static_cast(params.width() * 0.5 + params.height() * 0.5 * params.width()))); - t_real beam_units = 1.; + t_real beam_units = 1.0; if (params.mpiAlgorithm() != factory::algo_distribution::serial) { #ifdef PURIFY_MPI auto const world = sopt::mpi::Communicator::World(); beam_units = world.all_sum_all(uv_data.size()) / flux_scale / flux_scale; - PURIFY_LOW_LOG( - "Expected image domain residual RMS is {} jy/beam", - sigma * params.epsilonScaling() * operator_norm / - (std::sqrt(params.width() * params.height()) * world.all_sum_all(uv_data.size()))); - if (world.is_root()) #else throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); #endif - pfitsio::write2d(psf_image, psf_header, true); } else { beam_units = uv_data.size() / flux_scale / flux_scale; - PURIFY_LOW_LOG("Expected image domain residual RMS is {} jy/beam", - sigma * params.epsilonScaling() * operator_norm / - (std::sqrt(params.width() * params.height()) * uv_data.size())); - pfitsio::write2d(psf_image, psf_header, true); } - PURIFY_HIGH_LOG( - "Theoretical calculation for peak PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", - beam_units); - PURIFY_HIGH_LOG("Effective sigma is {} Jy", sigma * params.epsilonScaling()); + + savePSF(params, def_header, measurements_transform, uv_data, flux_scale, sigma, operator_norm, + beam_units); + // the dirty image - pfitsio::header_params dirty_header = def_header; - dirty_header.fits_name = out_dir + "/dirty.fits"; - dirty_header.pix_units = "Jy/Beam"; - const Vector dimage = measurements_transform->adjoint() * uv_data.vis; - const Image dirty_image = - Image::Map(dimage.data(), params.height(), params.width()).real(); - if (params.mpiAlgorithm() != factory::algo_distribution::serial) { -#ifdef PURIFY_MPI - auto const world = sopt::mpi::Communicator::World(); - if 
(world.is_root()) -#else - throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); -#endif - pfitsio::write2d(dirty_image / beam_units, dirty_header, true); - } else { - pfitsio::write2d(dirty_image / beam_units, dirty_header, true); - } - // create wavelet operator - std::vector> sara; - for (size_t i = 0; i < params.wavelet_basis().size(); i++) - sara.push_back(std::make_tuple(params.wavelet_basis().at(i), params.wavelet_levels())); - t_uint sara_size = 0; -#ifdef PURIFY_MPI - { - auto const world = sopt::mpi::Communicator::World(); - if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) - sara = sopt::wavelets::distribute_sara(sara, world); - } -#endif - auto const wavelets_transform = factory::wavelet_operator_factory>( - wop_algo, sara, params.height(), params.width(), sara_size); + saveDirtyImage(params, def_header, measurements_transform, uv_data, beam_units); // Create algorithm std::shared_ptr> padmm; @@ -427,30 +107,44 @@ int main(int argc, const char **argv) { std::shared_ptr> primaldual; if (params.algorithm() == "padmm") padmm = factory::padmm_factory>( - params.mpiAlgorithm(), measurements_transform, wavelets_transform, uv_data, - sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), sara_size, - params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), + params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, + sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), + wavelets.sara_size, params.iterations(), params.realValueConstraint(), + params.positiveValueConstraint(), (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, params.epsilonConvergenceScaling(), operator_norm); - if (params.algorithm() == "fb") + if (params.algorithm() == "fb") { + std::shared_ptr> f; + if 
(params.diffFuncType() == diff_func_type::L2Norm_with_CRR) { +#ifdef PURIFY_ONNXRT + f = std::make_shared>( + params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma, + params.CRR_mu(), params.CRR_lambda(), *measurements_transform); +#else + throw std::runtime_error("CRR approach cannot be used with ONNXRT off"); +#endif + } + fb = factory::fb_factory>( - params.mpiAlgorithm(), measurements_transform, wavelets_transform, uv_data, + params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, sigma * params.epsilonScaling() / flux_scale, params.stepsize() * std::pow(sigma * params.epsilonScaling() / flux_scale, 2), - params.regularisation_parameter(), params.height(), params.width(), sara_size, + params.regularisation_parameter(), params.height(), params.width(), wavelets.sara_size, params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), (params.wavelet_basis().size() < 2) and (not params.realValueConstraint()) and (not params.positiveValueConstraint()), params.relVarianceConvergence(), params.dualFBVarianceConvergence(), 50, operator_norm, - params.model_path(), params.gProximalType()); + params.model_path(), params.nondiffFuncType(), f); + } if (params.algorithm() == "primaldual") primaldual = factory::primaldual_factory>( - params.mpiAlgorithm(), measurements_transform, wavelets_transform, uv_data, - sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), sara_size, - params.iterations(), params.realValueConstraint(), params.positiveValueConstraint(), - params.relVarianceConvergence(), params.epsilonConvergenceScaling(), operator_norm); + params.mpiAlgorithm(), measurements_transform, wavelets.transform, uv_data, + sigma * params.epsilonScaling() / flux_scale, params.height(), params.width(), + wavelets.sara_size, params.iterations(), params.realValueConstraint(), + params.positiveValueConstraint(), params.relVarianceConvergence(), + params.epsilonConvergenceScaling(), 
operator_norm); // Add primal dual preconditioning if (params.algorithm() == "primaldual" and params.precondition_iters() > 0) { PURIFY_HIGH_LOG( @@ -476,28 +170,30 @@ int main(int argc, const char **argv) { // Adding step size update to algorithm factory::add_updater>( algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol, - update_header_res, params.height(), params.width(), sara_size, using_mpi, beam_units); + update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi, + beam_units); } if (params.algorithm() == "primaldual") { const std::weak_ptr> algo_weak(primaldual); // Adding step size update to algorithm factory::add_updater>( algo_weak, 1e-3, params.update_tolerance(), params.update_iters(), update_header_sol, - update_header_res, params.height(), params.width(), sara_size, using_mpi, beam_units); + update_header_res, params.height(), params.width(), wavelets.sara_size, using_mpi, + beam_units); } if (params.algorithm() == "fb") { const std::weak_ptr> algo_weak(fb); // Adding step size update to algorithm factory::add_updater>( algo_weak, 0, params.update_tolerance(), 0, update_header_sol, update_header_res, - params.height(), params.width(), sara_size, using_mpi, beam_units); + params.height(), params.width(), wavelets.sara_size, using_mpi, beam_units); } PURIFY_HIGH_LOG("Starting sopt!"); Image image; Image residual_image; pfitsio::header_params purified_header = def_header; - purified_header.fits_name = out_dir + "/purified.fits"; + purified_header.fits_name = params.output_path() + "/purified.fits"; const Vector estimate_image = (params.warm_start() != "") ? 
Vector::Map(pfitsio::read2d(params.warm_start()).data(), @@ -560,7 +256,7 @@ int main(int argc, const char **argv) { } // the residuals pfitsio::header_params residuals_header = purified_header; - residuals_header.fits_name = out_dir + "/residuals.fits"; + residuals_header.fits_name = params.output_path() + "/residuals.fits"; residuals_header.pix_units = "Jy/Beam"; if (params.mpiAlgorithm() != factory::algo_distribution::serial) { #ifdef PURIFY_MPI diff --git a/cpp/purify/CMakeLists.txt b/cpp/purify/CMakeLists.txt index 1e56b6ec..0e815cd4 100644 --- a/cpp/purify/CMakeLists.txt +++ b/cpp/purify/CMakeLists.txt @@ -35,7 +35,7 @@ set(HEADERS set(SOURCES utilities.cc pfitsio.cc logging.cc kernels.cc wproj_utilities.cc operators.cc uvfits.cc yaml-parser.cc read_measurements.cc distribute.cc integration.cc wide_field_utilities.cc wkernel_integration.cc - wproj_operators.cc uvw_utilities.cc) + wproj_operators.cc uvw_utilities.cc setup_utils.cc) if(PURIFY_CASACORE) list(APPEND SOURCES casacore.cc) diff --git a/cpp/purify/algorithm_factory.h b/cpp/purify/algorithm_factory.h index b5730b22..9f90677d 100644 --- a/cpp/purify/algorithm_factory.h +++ b/cpp/purify/algorithm_factory.h @@ -35,14 +35,11 @@ namespace purify { namespace factory { enum class algorithm { padmm, primal_dual, sdmm, forward_backward }; enum class algo_distribution { serial, mpi_serial, mpi_distributed, mpi_random_updates }; -enum class g_proximal_type { L1GProximal, TFGProximal, Indicator }; const std::map algo_distribution_string = { {"none", algo_distribution::serial}, {"serial-equivalent", algo_distribution::mpi_serial}, {"random-updates", algo_distribution::mpi_random_updates}, {"fully-distributed", algo_distribution::mpi_distributed}}; -const std::map g_proximal_type_string = { - {"l1", g_proximal_type::L1GProximal}, {"learned", g_proximal_type::TFGProximal}}; //! 
return chosen algorithm given parameters template @@ -166,7 +163,7 @@ fb_factory(const algo_distribution dist, const bool tight_frame = false, const t_real relative_variation = 1e-3, const t_real l1_proximal_tolerance = 1e-2, const t_uint maximum_proximal_iterations = 50, const t_real op_norm = 1, const std::string model_path = "", - const g_proximal_type g_proximal = g_proximal_type::L1GProximal, + const nondiff_func_type g_proximal = nondiff_func_type::L1Norm, std::shared_ptr> f_function = nullptr) { typedef typename Algorithm::Scalar t_scalar; if (sara_size > 1 and tight_frame) @@ -188,7 +185,7 @@ fb_factory(const algo_distribution dist, std::shared_ptr> g; switch (g_proximal) { - case (g_proximal_type::L1GProximal): { + case (nondiff_func_type::L1Norm): { // Create a shared pointer to an instance of the L1GProximal class // and set its properties auto l1_gp = std::make_shared>(false); @@ -208,7 +205,7 @@ fb_factory(const algo_distribution dist, g = l1_gp; break; } - case (g_proximal_type::TFGProximal): { + case (nondiff_func_type::Denoiser): { #ifdef PURIFY_ONNXRT // Create a shared pointer to an instance of the TFGProximal class g = std::make_shared>(model_path); @@ -218,7 +215,8 @@ fb_factory(const algo_distribution dist, "Type TFGProximal not recognized because purify was built with onnxrt=off"); #endif } - case (g_proximal_type::Indicator): { + + case (nondiff_func_type::RealIndicator): { g = std::make_shared>(); break; } diff --git a/cpp/purify/pfitsio.cc b/cpp/purify/pfitsio.cc index 857b2a77..5a5733e2 100644 --- a/cpp/purify/pfitsio.cc +++ b/cpp/purify/pfitsio.cc @@ -91,4 +91,30 @@ void write3d(const std::vector> &eigen_images, const std::string & write3d(eigen_images, header, overwrite); } +//! 
Read cube from fits file +std::vector> read3d(const std::string &fits_name) { + std::vector> eigen_images; + Vector image; + int rows, cols, channels, pols = 1; + read3d>(fits_name, image, rows, cols, channels, pols); + for (int i = 0; i < channels; i++) { + Vector eigen_image = Vector::Zero(rows * cols); + eigen_image.real() = image.segment(i * rows * cols, rows * cols); + eigen_images.push_back(Image::Map(eigen_image.data(), rows, cols)); + } + return eigen_images; +} + +//! Read image from fits file +Image read2d(const std::string &fits_name) { + /* + Reads in an image from a fits file and returns the image. + + fits_name:: name of fits file + */ + + const std::vector> images = read3d(fits_name); + return images.at(0); +} + } // namespace purify::pfitsio diff --git a/cpp/purify/pfitsio.h b/cpp/purify/pfitsio.h index 4312417c..6cbc2fa0 100644 --- a/cpp/purify/pfitsio.h +++ b/cpp/purify/pfitsio.h @@ -322,30 +322,10 @@ void read3d(const std::string &fits_name, Eigen::EigenBase &output, int &rows } //! Read cube from fits file -std::vector> read3d(const std::string &fits_name) { - std::vector> eigen_images; - Vector image; - int rows, cols, channels, pols = 1; - read3d>(fits_name, image, rows, cols, channels, pols); - for (int i = 0; i < channels; i++) { - Vector eigen_image = Vector::Zero(rows * cols); - eigen_image.real() = image.segment(i * rows * cols, rows * cols); - eigen_images.push_back(Image::Map(eigen_image.data(), rows, cols)); - } - return eigen_images; -} +std::vector> read3d(const std::string &fits_name); //! Read image from fits file -Image read2d(const std::string &fits_name) { - /* - Reads in an image from a fits file and returns the image. 
- - fits_name:: name of fits file - */ - - const std::vector> images = read3d(fits_name); - return images.at(0); -} +Image read2d(const std::string &fits_name); } // namespace purify::pfitsio diff --git a/cpp/purify/setup_utils.cc b/cpp/purify/setup_utils.cc new file mode 100644 index 00000000..cc763958 --- /dev/null +++ b/cpp/purify/setup_utils.cc @@ -0,0 +1,442 @@ +#include "purify/setup_utils.h" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace purify; + +waveletInfo createWaveletOperator(YamlParser ¶ms, + const factory::distributed_wavelet_operator &wop_algo) { + std::vector> sara; + for (size_t i = 0; i < params.wavelet_basis().size(); i++) + sara.push_back(std::make_tuple(params.wavelet_basis().at(i), params.wavelet_levels())); + t_uint sara_size = 0; +#ifdef PURIFY_MPI + { + auto const world = sopt::mpi::Communicator::World(); + if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) + sara = sopt::wavelets::distribute_sara(sara, world); + } +#endif + auto const wavelets_transform = factory::wavelet_operator_factory>( + wop_algo, sara, params.height(), params.width(), sara_size); + return {wavelets_transform, sara_size}; +} + +OperatorsInfo selectOperators(YamlParser ¶ms) { + factory::distributed_measurement_operator mop_algo = + (not params.gpu()) ? factory::distributed_measurement_operator::serial + : factory::distributed_measurement_operator::gpu_serial; + factory::distributed_wavelet_operator wop_algo = factory::distributed_wavelet_operator::serial; + bool using_mpi = false; + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifndef PURIFY_MPI + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + mop_algo = (not params.gpu()) + ? factory::distributed_measurement_operator::mpi_distribute_image + : factory::distributed_measurement_operator::gpu_mpi_distribute_image; + if (params.mpi_all_to_all()) + mop_algo = (not params.gpu()) + ? 
factory::distributed_measurement_operator::mpi_distribute_all_to_all + : factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all; + wop_algo = factory::distributed_wavelet_operator::mpi_sara; + if (params.mpiAlgorithm() == factory::algo_distribution::mpi_random_updates) { + mop_algo = (not params.gpu()) ? factory::distributed_measurement_operator::serial + : factory::distributed_measurement_operator::serial; + wop_algo = factory::distributed_wavelet_operator::serial; + } + using_mpi = true; + } + return {mop_algo, wop_algo, using_mpi}; +} + +inputData getInputData(const YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, const bool using_mpi) { + utilities::vis_params uv_data; + bool w_term = params.w_term(); + t_real sigma; + std::vector image_index = std::vector(); + std::vector w_stacks = std::vector(); + + Vector measurement_op_eigen_vector = + Vector::Ones(params.width() * params.height()); + // read eigen vector for power method + if (params.eigenvector_real() != "" and params.eigenvector_imag() != "") { + t_int rows; + t_int cols; + t_int pols; + t_int chans; + Vector temp_real; + Vector temp_imag; + pfitsio::read3d(params.eigenvector_real(), temp_real, rows, cols, chans, pols); + if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) + throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); + pfitsio::read3d(params.eigenvector_imag(), temp_imag, rows, cols, chans, pols); + if (rows != params.height() or cols != params.width() or chans != 1 or pols != 1) + throw std::runtime_error("Image of measurement operator eigenvector is wrong size."); + measurement_op_eigen_vector.real() = temp_real; + measurement_op_eigen_vector.imag() = temp_imag; + } + if (params.source() == purify::utilities::vis_source::measurements) { + PURIFY_HIGH_LOG("Input visibilities are from files:"); + for (size_t i = 0; i < 
params.measurements().size(); i++) + PURIFY_HIGH_LOG("{}", params.measurements()[i]); + sigma = params.measurements_sigma(); +#ifdef PURIFY_MPI + if (using_mpi) { + auto const world = sopt::mpi::Communicator::World(); + uv_data = read_measurements::read_measurements(params.measurements(), world, + distribute::plan::radial, w_term, stokes::I, + params.measurements_units()); + const t_real norm = + std::sqrt(world.all_sum_all( + (uv_data.weights.real().array() * uv_data.weights.real().array()).sum()) / + world.all_sum_all(uv_data.size())); + // normalise weights + uv_data.weights = uv_data.weights / norm; + // using no weights for now + // uv_data.weights = Vector::Ones(uv_data.size()); + } else +#endif + { + uv_data = read_measurements::read_measurements(params.measurements(), w_term, stokes::I, + params.measurements_units()); + const t_real norm = std::sqrt( + (uv_data.weights.real().array() * uv_data.weights.real().array()).sum() / uv_data.size()); + // normalising weights + uv_data.weights = uv_data.weights / norm; + // using no weights for now + // uv_data.weights = Vector::Ones(uv_data.size()); + } + if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); +#ifdef PURIFY_MPI + if (params.mpi_wstacking() and + (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or + mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + const t_real du = + widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); + std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( + uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); + } else if (params.mpi_wstacking()) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + uv_data = 
utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); + } +#endif + } else if (params.source() == purify::utilities::vis_source::simulation) { + PURIFY_HIGH_LOG("Input visibilities will be generated for random coverage."); + // TODO: move this to function (in utilities.h?) + auto image = pfitsio::read2d(params.skymodel()); + if (params.height() != image.rows() || params.width() != image.cols()) + throw std::runtime_error("Input image size (" + std::to_string(image.cols()) + "x" + + std::to_string(image.rows()) + ") is not equal to the input one (" + + std::to_string(params.width()) + "x" + + std::to_string(params.height()) + ")."); + t_int const number_of_pixels = image.size(); + t_int const number_of_vis = params.number_of_measurements(); + t_real const sigma_m = constant::pi / 4; + const t_real rms_w = params.w_rms(); // lambda + if (params.measurements().at(0) == "") { + uv_data = utilities::random_sample_density(number_of_vis, 0, sigma_m, rms_w); + uv_data.units = utilities::vis_units::radians; + uv_data.weights = Vector::Ones(uv_data.size()); + } else { +#ifdef PURIFY_MPI + if (using_mpi) { + auto const world = sopt::mpi::Communicator::World(); + uv_data = read_measurements::read_measurements(params.measurements(), world, + distribute::plan::radial, w_term, stokes::I, + params.measurements_units()); + } else +#endif + uv_data = read_measurements::read_measurements(params.measurements(), w_term, stokes::I, + params.measurements_units()); + uv_data.weights = Vector::Ones(uv_data.weights.size()); + } + if (params.conjugate_w()) uv_data = utilities::conjugate_w(uv_data); +#ifdef PURIFY_MPI + if (params.mpi_wstacking() and + (mop_algo == factory::distributed_measurement_operator::mpi_distribute_all_to_all or + mop_algo == factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all)) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + const t_real du = + 
widefield::pixel_to_lambda(params.cellsizex(), params.width(), params.oversampling()); + std::tie(uv_data, image_index, w_stacks) = utilities::w_stacking_with_all_to_all( + uv_data, du, params.Jx(), params.Jw(), world, params.kmeans_iters(), 0, cost); + } else if (params.mpi_wstacking()) { + auto const world = sopt::mpi::Communicator::World(); + const auto cost = [](t_real x) -> t_real { return std::abs(x * x); }; + uv_data = utilities::w_stacking(uv_data, world, params.kmeans_iters(), cost); + } +#endif + std::shared_ptr>> sky_measurements; + if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and + mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) + sky_measurements = + (not params.wprojection()) + ? factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), + params.mpi_wstacking()) + : factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + else + sky_measurements = + (not params.wprojection()) + ? 
factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.sim_J(), + params.mpi_wstacking()) + : factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.sim_J(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + uv_data.vis = + ((*sky_measurements) * Vector::Map(image.data(), image.size())).eval().array(); + sigma = utilities::SNR_to_standard_deviation(uv_data.vis, params.signal_to_noise()); + uv_data.vis = utilities::add_noise(uv_data.vis, 0., sigma); + } + t_real ideal_cell_x = widefield::estimate_cell_size(uv_data.u.cwiseAbs().maxCoeff(), + params.width(), params.oversampling()); + t_real ideal_cell_y = widefield::estimate_cell_size(uv_data.v.cwiseAbs().maxCoeff(), + params.height(), params.oversampling()); +#ifdef PURIFY_MPI + if (using_mpi) { + auto const comm = sopt::mpi::Communicator::World(); + ideal_cell_x = widefield::estimate_cell_size( + comm.all_reduce(uv_data.u.cwiseAbs().maxCoeff(), MPI_MAX), params.width(), + params.oversampling()); + ideal_cell_y = widefield::estimate_cell_size( + comm.all_reduce(uv_data.v.cwiseAbs().maxCoeff(), MPI_MAX), params.height(), + params.oversampling()); + } +#endif + PURIFY_HIGH_LOG( + "Using cell size {}\" x {}\", recommended from the uv coverage and field of view is " + "{}\"x{}\".", + params.cellsizey(), params.cellsizex(), ideal_cell_y, ideal_cell_x); + PURIFY_HIGH_LOG("The equivalent miriad cell size is: {}\" x {}\"", + widefield::equivalent_miriad_cell_size(params.cellsizex(), params.width(), + params.oversampling()), + widefield::equivalent_miriad_cell_size(params.cellsizey(), 
params.height(), + params.oversampling())); + + return {uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks}; +} + +measurementOpInfo createMeasurementOperator( + const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, + const std::vector &image_index, const std::vector &w_stacks, + const utilities::vis_params &uv_data, Vector &measurement_op_eigen_vector) { + std::shared_ptr>> measurements_transform; + if (mop_algo != factory::distributed_measurement_operator::mpi_distribute_all_to_all and + mop_algo != factory::distributed_measurement_operator::gpu_mpi_distribute_all_to_all) + measurements_transform = + (not params.wprojection()) + ? factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), + params.mpi_wstacking()) + : factory::measurement_operator_factory>( + mop_algo, uv_data, params.height(), params.width(), params.cellsizey(), + params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + else + measurements_transform = + (not params.wprojection()) + ? 
factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jx(), + params.mpi_wstacking()) + : factory::all_to_all_measurement_operator_factory>( + mop_algo, image_index, w_stacks, uv_data, params.height(), params.width(), + params.cellsizey(), params.cellsizex(), params.oversampling(), + kernels::kernel_from_string.at(params.kernel()), params.Jy(), params.Jw(), + params.mpi_wstacking(), 1e-6, 1e-6, dde_type::wkernel_radial); + t_real operator_norm = 1.; +#ifdef PURIFY_MPI + if (using_mpi) { + auto const comm = sopt::mpi::Communicator::World(); + auto power_method_result = + (params.mpiAlgorithm() != factory::algo_distribution::mpi_random_updates) + ? sopt::algorithm::power_method>( + *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), + comm.broadcast(measurement_op_eigen_vector).eval()) + : sopt::algorithm::all_sum_all_power_method>( + comm, *measurements_transform, params.powMethod_iter(), + params.powMethod_tolerance(), comm.broadcast(measurement_op_eigen_vector).eval()); + measurement_op_eigen_vector = std::get<1>(power_method_result); + operator_norm = std::get<0>(power_method_result); + } else +#endif + { + auto power_method_result = sopt::algorithm::power_method>( + *measurements_transform, params.powMethod_iter(), params.powMethod_tolerance(), + measurement_op_eigen_vector); + measurement_op_eigen_vector = std::get<1>(power_method_result); + operator_norm = std::get<0>(power_method_result); + } + + return {measurements_transform, operator_norm}; +} + +void setupCostFunctions(const YamlParser ¶ms, std::unique_ptr> &f, + std::unique_ptr> &g, t_real sigma, + sopt::LinearTransform> &Phi) { + switch (params.diffFuncType()) { + case purify::diff_func_type::L2Norm: + f = std::make_unique>(sigma, Phi); + break; + case 
purify::diff_func_type::L2Norm_with_CRR: + f = std::make_unique>( + params.CRR_function_model_path(), params.CRR_gradient_model_path(), sigma, params.CRR_mu(), + params.CRR_lambda(), Phi); + break; + } + + switch (params.nondiffFuncType()) { + case purify::nondiff_func_type::L1Norm: + g = std::make_unique>(); + break; + case purify::nondiff_func_type::Denoiser: + g = std::make_unique>(params.model_path()); + break; + case purify::nondiff_func_type::RealIndicator: + g = std::make_unique>(); + break; + } +} + +void initOutDirectoryWithConfig(YamlParser ¶ms) { + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + params.writeOutput(); + } else { + params.writeOutput(); + } +} + +Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data) { + const pfitsio::header_params update_header_sol = + pfitsio::header_params(params.output_path() + "/sol_update.fits", "Jy/Pixel", 1, uv_data.ra, + uv_data.dec, params.measurements_polarization(), params.cellsizex(), + params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); + const pfitsio::header_params update_header_res = + pfitsio::header_params(params.output_path() + "/res_update.fits", "Jy/Beam", 1, uv_data.ra, + uv_data.dec, params.measurements_polarization(), params.cellsizex(), + params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); + const pfitsio::header_params def_header = pfitsio::header_params( + "", "Jy/Pixel", 1, uv_data.ra, uv_data.dec, params.measurements_polarization(), + params.cellsizex(), params.cellsizey(), uv_data.average_frequency, 0, 0, false, 0, 0, 0); + + return {update_header_sol, update_header_res, def_header}; +} + +void saveMeasurementEigenVector(const YamlParser ¶ms, + const Vector &measurement_op_eigen_vector) { + if (params.mpiAlgorithm() != 
factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + { + pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), + params.output_path() + "/eigenvector_real.fits", "pix", true); + pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), + params.output_path() + "/eigenvector_imag.fits", "pix", true); + } + } else { + pfitsio::write2d(measurement_op_eigen_vector.real(), params.height(), params.width(), + params.output_path() + "/eigenvector_real.fits", "pix", true); + pfitsio::write2d(measurement_op_eigen_vector.imag(), params.height(), params.width(), + params.output_path() + "/eigenvector_imag.fits", "pix", true); + } +} + +void savePSF( + const YamlParser ¶ms, const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, + const t_real operator_norm, const t_real beam_units) { + pfitsio::header_params psf_header = def_header; + psf_header.fits_name = params.output_path() + "/psf.fits"; + psf_header.pix_units = "Jy/Pixel"; + const Vector psf = measurements_transform->adjoint() * (uv_data.weights / flux_scale); + const Image psf_image = + Image::Map(psf.data(), params.height(), params.width()).real(); + PURIFY_HIGH_LOG( + "Peak of PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", + psf_image(static_cast(params.width() * 0.5 + params.height() * 0.5 * params.width()))); + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + PURIFY_LOW_LOG( + "Expected image domain residual RMS is {} jy/beam", + sigma * params.epsilonScaling() * operator_norm / + (std::sqrt(params.width() * params.height()) * 
world.all_sum_all(uv_data.size()))); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + pfitsio::write2d(psf_image, psf_header, true); + } else { + PURIFY_LOW_LOG("Expected image domain residual RMS is {} jy/beam", + sigma * params.epsilonScaling() * operator_norm / + (std::sqrt(params.width() * params.height()) * uv_data.size())); + pfitsio::write2d(psf_image, psf_header, true); + } + PURIFY_HIGH_LOG( + "Theoretical calculation for peak PSF: {} (used to convert between Jy/Pixel and Jy/BEAM)", + beam_units); + PURIFY_HIGH_LOG("Effective sigma is {} Jy", sigma * params.epsilonScaling()); +} + +void saveDirtyImage( + const YamlParser ¶ms, const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, const t_real beam_units) { + pfitsio::header_params dirty_header = def_header; + dirty_header.fits_name = params.output_path() + "/dirty.fits"; + dirty_header.pix_units = "Jy/Beam"; + const Vector dimage = measurements_transform->adjoint() * uv_data.vis; + const Image dirty_image = + Image::Map(dimage.data(), params.height(), params.width()).real(); + if (params.mpiAlgorithm() != factory::algo_distribution::serial) { +#ifdef PURIFY_MPI + auto const world = sopt::mpi::Communicator::World(); + if (world.is_root()) +#else + throw std::runtime_error("Compile with MPI if you want to use MPI algorithm"); +#endif + pfitsio::write2d(dirty_image / beam_units, dirty_header, true); + } else { + pfitsio::write2d(dirty_image / beam_units, dirty_header, true); + } +} \ No newline at end of file diff --git a/cpp/purify/setup_utils.h b/cpp/purify/setup_utils.h new file mode 100644 index 00000000..fc8f4083 --- /dev/null +++ b/cpp/purify/setup_utils.h @@ -0,0 +1,83 @@ +#ifndef SETUP_UTILS_H +#define SETUP_UTILS_H + +#include "purify/types.h" +#include "purify/logging.h" +#include "purify/measurement_operator_factory.h" +#include "purify/pfitsio.h" 
+#include "purify/read_measurements.h" +#include "purify/wavelet_operator_factory.h" +#include "purify/yaml-parser.h" +#include +#include + +using namespace purify; + +struct waveletInfo { + std::shared_ptr> transform; + t_uint sara_size; +}; + +waveletInfo createWaveletOperator(YamlParser ¶ms, + const factory::distributed_wavelet_operator &wop_algo); + +struct OperatorsInfo { + factory::distributed_measurement_operator mop_algo; + factory::distributed_wavelet_operator wop_algo; + bool using_mpi; +}; + +OperatorsInfo selectOperators(YamlParser ¶ms); + +struct inputData { + utilities::vis_params uv_data; + t_real sigma; + Vector measurement_op_eigen_vector; + std::vector image_index; + std::vector w_stacks; +}; + +inputData getInputData(const YamlParser ¶ms, + const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, const bool using_mpi); + +struct measurementOpInfo { + std::shared_ptr>> measurement_transform; + t_real operator_norm; +}; + +measurementOpInfo createMeasurementOperator( + const YamlParser ¶ms, const factory::distributed_measurement_operator mop_algo, + const factory::distributed_wavelet_operator wop_algo, const bool using_mpi, + const std::vector &image_index, const std::vector &w_stacks, + const utilities::vis_params &uv_data, Vector &measurement_op_eigen_vector); + +void setupCostFunctions(const YamlParser ¶ms, std::unique_ptr> &f, + std::unique_ptr> &g, t_real sigma, + sopt::LinearTransform> &Phi); + +void initOutDirectoryWithConfig(YamlParser ¶ms); + +struct Headers { + pfitsio::header_params solution_header; + pfitsio::header_params residuals_header; + pfitsio::header_params def_header; +}; + +Headers genHeaders(const YamlParser ¶ms, const utilities::vis_params &uv_data); + +void saveMeasurementEigenVector(const YamlParser ¶ms, + const Vector &measurement_op_eigen_vector); + +void savePSF( + const YamlParser ¶ms, const pfitsio::header_params &def_header, + const std::shared_ptr>> 
&measurements_transform, + const utilities::vis_params &uv_data, const t_real flux_scale, const t_real sigma, + const t_real operator_norm, const t_real beam_units); + +void saveDirtyImage( + const YamlParser ¶ms, const pfitsio::header_params &def_header, + const std::shared_ptr>> &measurements_transform, + const utilities::vis_params &uv_data, const t_real beam_units); + +#endif \ No newline at end of file diff --git a/cpp/purify/types.h b/cpp/purify/types.h index 7754f983..e77c4301 100644 --- a/cpp/purify/types.h +++ b/cpp/purify/types.h @@ -22,6 +22,18 @@ typedef std::complex t_complexf; //! Root of the type hierarchy for triplet lists typedef Eigen::Triplet t_tripletList; +// Different available types of differentiable functions (f) +enum class diff_func_type { L2Norm, L2Norm_with_CRR }; +const std::map diff_type_string = { + {"l2", diff_func_type::L2Norm}, {"CRR", diff_func_type::L2Norm_with_CRR}}; + +// Different available types of non-differentiable functions (g) +enum class nondiff_func_type { L1Norm, Denoiser, RealIndicator }; +const std::map nondiff_type_string = { + {"l1", nondiff_func_type::L1Norm}, + {"denoiser", nondiff_func_type::Denoiser}, + {"realIndicator", nondiff_func_type::RealIndicator}}; + //! \brief A matrix of a given type //! \details Operates as mathematical sparse matrix. template diff --git a/cpp/purify/uvw_utilities.cc b/cpp/purify/uvw_utilities.cc index 9f9e5ef1..3527e7a2 100644 --- a/cpp/purify/uvw_utilities.cc +++ b/cpp/purify/uvw_utilities.cc @@ -177,6 +177,11 @@ utilities::vis_params read_visibility(const std::string &vis_name, const bool w_ vis_name:: name of input text file containing [u, v, real(V), imag(V)] (separated by ' '). 
*/ std::ifstream vis_file(vis_name); + if (vis_file) { + PURIFY_LOW_LOG("File {} successfully opened", vis_name); + } else { + throw std::runtime_error("Could not open file " + vis_name); + } vis_file.precision(13); t_int row = 0; std::string line; diff --git a/cpp/purify/yaml-parser.cc b/cpp/purify/yaml-parser.cc index 2a20fee0..e86b98e2 100644 --- a/cpp/purify/yaml-parser.cc +++ b/cpp/purify/yaml-parser.cc @@ -119,6 +119,13 @@ void YamlParser::parseAndSetGeneralConfiguration(const YAML::Node& generalConfig this->source_ = purify::utilities::vis_source::measurements; this->measurements_ = get_vector>( generalConfigNode, {"InputOutput", "input", "measurements", "measurements_files"}); + try { + this->w_term_ = + get(generalConfigNode, {"InputOutput", "input", "measurements", "w_term"}); + } catch (...) { + PURIFY_LOW_LOG("W-term flag not set for input measurements; defaulting to true."); + this->w_term_ = true; + } // TODO: use the enum instead of string. const std::string units_measurement_str = get( generalConfigNode, {"InputOutput", "input", "measurements", "measurements_units"}); @@ -222,9 +229,24 @@ void YamlParser::parseAndSetAlgorithmOptions(const YAML::Node& algorithmOptionsN get(algorithmOptionsNode, {"fb", "regularisation_parameter"}); this->dualFBVarianceConvergence_ = get(algorithmOptionsNode, {"fb", "dualFBVarianceConvergence"}); - this->gProximalType_ = factory::g_proximal_type_string.at( - get(algorithmOptionsNode, {"fb", "gProximalType"})); - this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); + + this->nondiffFuncType_ = nondiff_type_string.at( + get(algorithmOptionsNode, {"fb", "nonDifferentiableFunctionType"})); + if (this->nondiffFuncType_ == nondiff_func_type::Denoiser) { + this->model_path_ = get(algorithmOptionsNode, {"fb", "modelPath"}); + } + + this->diffFuncType_ = diff_type_string.at( + get(algorithmOptionsNode, {"fb", "differentiableFunctionType"})); + if (this->diffFuncType_ == diff_func_type::L2Norm_with_CRR) { + 
this->CRR_function_model_path_ = + get(algorithmOptionsNode, {"fb", "CRR_function_model_path"}); + this->CRR_gradient_model_path_ = + get(algorithmOptionsNode, {"fb", "CRR_gradient_model_path"}); + this->CRR_mu_ = get(algorithmOptionsNode, {"fb", "CRR_mu"}); + this->CRR_lambda_ = get(algorithmOptionsNode, {"fb", "CRR_lambda"}); + } + if (this->algorithm_ == "fb_joint_map") { this->jmap_iters_ = get(algorithmOptionsNode, {"fb", "joint_map_estimation", "iters"}); @@ -303,7 +325,7 @@ void YamlParser::writeOutput() { base_file_name.substr((file_path.size() ? file_path.size() + 1 : 0), base_file_name.size()); // Construct output directory structure and file name boost::filesystem::path const path(this->output_prefix_); - std::string const out_path = output_prefix_ + "/output_" + std::string(this->timestamp()); + out_path = output_prefix_ + "/output_" + std::string(this->timestamp()); mkdir_recursive(out_path); std::string out_filename = out_path + "/" + base_file_name + "_save.yaml"; diff --git a/cpp/purify/yaml-parser.h b/cpp/purify/yaml-parser.h index ede5b537..83424d34 100644 --- a/cpp/purify/yaml-parser.h +++ b/cpp/purify/yaml-parser.h @@ -81,7 +81,7 @@ class YamlParser { TYPE NAME##_ = VALUE; \ \ public: \ - TYPE NAME() { return NAME##_; }; + TYPE NAME() const { return NAME##_; }; YAML_MACRO(std::string, filepath, "") YAML_MACRO(std::string, version, "") @@ -129,6 +129,7 @@ class YamlParser { YAML_MACRO(std::string, output_prefix, "") YAML_MACRO(std::string, skymodel, "") YAML_MACRO(std::vector, measurements, {}) + YAML_MACRO(bool, w_term, true) YAML_MACRO(stokes, measurements_polarization, stokes::I) YAML_MACRO(utilities::vis_units, measurements_units, utilities::vis_units::radians) YAML_MACRO(std::string, kernel, "") @@ -141,11 +142,19 @@ class YamlParser { YAML_MACRO(t_real, jmap_beta, 1) YAML_MACRO(std::string, model_path, "") - YAML_MACRO(factory::g_proximal_type, gProximalType, factory::g_proximal_type::L1GProximal) + YAML_MACRO(nondiff_func_type, 
nondiffFuncType, nondiff_func_type::L1Norm) + YAML_MACRO(diff_func_type, diffFuncType, diff_func_type::L2Norm) + YAML_MACRO(std::string, CRR_function_model_path, "") + YAML_MACRO(std::string, CRR_gradient_model_path, "") + YAML_MACRO(t_real, CRR_mu, 20) + YAML_MACRO(t_real, CRR_lambda, 5000) + + std::string output_path() const { return out_path; } #undef YAML_MACRO private: YAML::Node config_file; + std::string out_path; template T get(const YAML::Node& node_map, const std::initializer_list indicies); diff --git a/cpp/tests/algo_factory.cc b/cpp/tests/algo_factory.cc index 1e7b1ac4..5c8b0dd5 100644 --- a/cpp/tests/algo_factory.cc +++ b/cpp/tests/algo_factory.cc @@ -223,7 +223,7 @@ TEST_CASE("tf_fb_factory") { auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, - tf_model_path, factory::g_proximal_type::TFGProximal); + tf_model_path, nondiff_func_type::Denoiser); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); @@ -285,7 +285,7 @@ TEST_CASE("onnx_fb_factory") { auto const fb = factory::fb_factory>( factory::algo_distribution::serial, measurements_transform, wavelets, uv_data, sigma, beta, gamma, imsizey, imsizex, sara.size(), 1000, true, true, false, 1e-2, 1e-3, 50, op_norm, "", - factory::g_proximal_type::Indicator, diff_function); + nondiff_func_type::RealIndicator, diff_function); auto const diagnostic = (*fb)(); const Image image = Image::Map(diagnostic.x.data(), imsizey, imsizex); diff --git a/cpp/uncertainty_quantification/CMakeLists.txt b/cpp/uncertainty_quantification/CMakeLists.txt new file mode 100644 index 00000000..5c2e7918 --- /dev/null +++ b/cpp/uncertainty_quantification/CMakeLists.txt @@ -0,0 +1,9 @@ +add_executable(purify_UQ uq_main.cc) +target_link_libraries(purify_UQ libpurify) +set_target_properties(purify_UQ PROPERTIES 
RUNTIME_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) + +install(TARGETS purify_UQ + EXPORT PurifyTargets + DESTINATION share/cmake/Purify + RUNTIME DESTINATION bin + ) \ No newline at end of file diff --git a/cpp/uncertainty_quantification/uq_main.cc b/cpp/uncertainty_quantification/uq_main.cc new file mode 100644 index 00000000..f3436591 --- /dev/null +++ b/cpp/uncertainty_quantification/uq_main.cc @@ -0,0 +1,188 @@ +#include +#include +#include +#include +#include +#include "purify/measurement_operator_factory.h" +#include "purify/pfitsio.h" +#include "purify/setup_utils.h" +#include "purify/utilities.h" +#include "purify/yaml-parser.h" +#include "yaml-cpp/yaml.h" +#include "sopt/differentiable_func.h" +#include "sopt/non_differentiable_func.h" +#include "sopt/objective_functions.h" +#include +#include +#include +#include + +#ifdef PURIFY_ONNXRT +#include +#endif + +using VectorC = sopt::Vector>; + +int main(int argc, char **argv) { + if (argc != 4) { + std::cout << "purify_UQ should be run using three additional arguments." << std::endl; + std::cout << "purify_UQ " + << std::endl; + std::cout << ": path to a .yaml config file specifying details of measurement " + "operator, wavelet operator, observations, and cost functions." + << std::endl; + std::cout << ": path to image file (.fits) which was output from running " + "purify on observed data." + << std::endl; + std::cout << ": path to modified image file (.fits) for feature analysis." + << std::endl; + std::cout << std::endl; + std::cout + << "For more information about the contents of the config file please consult the README." 
+ << std::endl; + return 1; + } + + // Load and parse the config for parameters + const std::string config_path = argv[1]; + const YAML::Node UQ_config = YAML::LoadFile(config_path); + + // Load the Reference and Surrogate images + const std::string ref_image_path = argv[2]; + const std::string surrogate_image_path = argv[3]; + const auto reference_image = purify::pfitsio::read2d(ref_image_path); + const VectorC reference_vector = VectorC::Map(reference_image.data(), reference_image.size()); + const auto surrogate_image = purify::pfitsio::read2d(surrogate_image_path); + const VectorC surrogate_vector = VectorC::Map(surrogate_image.data(), surrogate_image.size()); + + const uint imsize_x = reference_image.cols(); + const uint imsize_y = reference_image.rows(); + + std::unique_ptr> f; + std::unique_ptr> g; + + // Prepare operators and data using purify config + // If no purify config use basic version for now based on algo_factory test images + purify::utilities::vis_params measurement_data; + double regulariser_strength = 0; + std::shared_ptr> measurement_operator; + std::shared_ptr> wavelet_operator; + std::vector> const sara{ + std::make_tuple("Dirac", 3u), std::make_tuple("DB1", 3u), std::make_tuple("DB2", 3u), + std::make_tuple("DB3", 3u), std::make_tuple("DB4", 3u), std::make_tuple("DB5", 3u), + std::make_tuple("DB6", 3u), std::make_tuple("DB7", 3u), std::make_tuple("DB8", 3u)}; + if (UQ_config["purify_config_file"]) { + YamlParser purify_config = YamlParser(UQ_config["purify_config_file"].as()); + + const auto [mop_algo, wop_algo, using_mpi] = selectOperators(purify_config); + auto [uv_data, sigma, measurement_op_eigen_vector, image_index, w_stacks] = + getInputData(purify_config, mop_algo, wop_algo, using_mpi); + + auto [transform, operator_norm] = + createMeasurementOperator(purify_config, mop_algo, wop_algo, using_mpi, image_index, + w_stacks, uv_data, measurement_op_eigen_vector); + + const waveletInfo wavelets = createWaveletOperator(purify_config, 
wop_algo); + + t_real const flux_scale = 1.; + uv_data.vis = uv_data.vis.array() * uv_data.weights.array() / flux_scale; + + measurement_data = uv_data; + measurement_operator = transform; + wavelet_operator = wavelets.transform; + + // setup f and g based on config file + setupCostFunctions(purify_config, f, g, sigma, *measurement_operator); + + regulariser_strength = purify_config.regularisation_parameter(); + } else { + const std::string measurements_path = UQ_config["measurements_path"].as(); + // Load the images and measurements + measurement_data = purify::utilities::read_visibility(measurements_path, false); + + // This is the measurement operator used in the test but this should probably be selectable + measurement_operator = purify::factory::measurement_operator_factory>( + purify::factory::distributed_measurement_operator::serial, measurement_data, imsize_y, + imsize_x, 1, 1, 2, kernels::kernel_from_string.at("kb"), 4, 4); + + wavelet_operator = purify::factory::wavelet_operator_factory>( + factory::distributed_wavelet_operator::serial, sara, imsize_y, imsize_x); + + // default cost function + f = std::make_unique>( + 1, *measurement_operator); // what would a default sigma look like?? + g = std::make_unique>(); + + try { + regulariser_strength = UQ_config["regulariser_strength"].as(); + } catch (...) { + std::cout + << "Regulariser strength not provided in UQ config, and no purify config was provided.\n"; + std::cout << "Regulariser strength will be 0 by default." << std::endl; + } + } + + // Set up confidence and objective function params + double confidence; + double alpha; + if ((UQ_config["confidence_interval"]) && (UQ_config["alpha"])) { + std::cout << "Config should only contain one of 'confidence_interval' or 'alpha'." 
<< std::endl; + return 1; + } + if (UQ_config["confidence_interval"]) { + confidence = UQ_config["confidence_interval"].as(); + alpha = 1 - confidence; + } else if (UQ_config["alpha"]) { + alpha = UQ_config["alpha"].as(); + confidence = 1 - alpha; + } else { + std::cout << "Config file must contain either 'confidence_interval' or 'alpha' as a parameter." + << std::endl; + return 1; + } + + if ((imsize_x != surrogate_image.cols()) || (imsize_y != surrogate_image.rows())) { + std::cout << "Surrogate and reference images have different dimensions. Aborting." << std::endl; + return 2; + } + + if (((*measurement_operator) * reference_vector).size() != measurement_data.vis.size()) { + std::cout << "Image size is not compatible with the measurement operator and data provided." + << std::endl; + return 3; + } + + // Calculate the posterior function for the reference image + // posterior = likelihood + prior + // Likelihood = |y - Phi(x)|^2 / sigma^2 (L2 norm) + // Prior = Sum(Psi^t * |x_i|) * regulariser_strength (L1 norm) + auto Posterior = [&measurement_data, measurement_operator, wavelet_operator, regulariser_strength, + &f, &g](const VectorC &image) { + { + const auto residuals = (*measurement_operator * image) - measurement_data.vis; + auto A = f->function(image, measurement_data.vis, (*measurement_operator)); + auto B = g->function(image); + return A + regulariser_strength * B; + } + }; + + const double reference_posterior = Posterior(reference_vector); + const double surrogate_posterior = Posterior(surrogate_vector); + + // Threshold for surrogate image posterior to be within confidence limit + const double N = imsize_x * imsize_y; + const double tau = std::sqrt(16 * std::log(3 / alpha)); + const double threshold = reference_posterior + tau * std::sqrt(N) + N; + + std::cout << "Uncertainty Quantification." 
<< std::endl; + std::cout << "Reference Log Posterior = " << reference_posterior << std::endl; + std::cout << "Confidence interval = " << confidence << std::endl; + std::cout << "Log Posterior threshold = " << threshold << std::endl; + std::cout << "Surrogate Log Posterior = " << surrogate_posterior << std::endl; + std::cout << "Surrogate image is " + << ((surrogate_posterior <= threshold) ? "within the credible interval." + : "excluded by the credible interval.") + << std::endl; + + return 0; +} \ No newline at end of file diff --git a/data/config/fb_UQ.yaml b/data/config/fb_UQ.yaml new file mode 100644 index 00000000..e9e4f611 --- /dev/null +++ b/data/config/fb_UQ.yaml @@ -0,0 +1,2 @@ +purify_config_file: data/config/fb_example.yaml +confidence_interval: 0.95 \ No newline at end of file diff --git a/data/config/fb_example.yaml b/data/config/fb_example.yaml new file mode 100644 index 00000000..692dd952 --- /dev/null +++ b/data/config/fb_example.yaml @@ -0,0 +1,71 @@ +### Config file for Purify +Version: 4.2.0 +GeneralConfiguration: + InputOutput: + input: + # If the following parameter is "measurements", only the measurements block should + # be provided, and the measurements will be read from file. If it's "simulation", only the + # simulation block should be provided, and random measurements of the given + # image will be generated with the requested noise. 
+ source: measurements # one from measurements, simulation + measurements: + measurements_files: + - local/data/expected/fb/input_data.vis # path to the measurement set + w_term: false + measurements_polarization: I # one from I Q V XX LL + measurements_units: radians # one from lambda, radians, pixels + measurements_sigma: 0.0238 # the uncertainty of a visibility (RMS noise of the real or imaginary part of the visibility in Jy) + warm_start: "" #path to fits file image of previous purify solution (assuming the same cell size), otherwise leave as an empty quotation "" + output_prefix: local/data/expected/fb/ # results directory will be created inside this path, and called output_<timestamp> + + ### Execution parameters + logging: debug # How much logging you want (debug, critical, ..). debug may be the most useful in some cases because it gives the most output. + iterations: 100 # number of iterations (>= 0). More iterations can give better results. + epsilonScaling: 1 # factor to scale measurements_sigma by. We suggest 1, but it can be useful to change. + +########## Measure Operators ########## +MeasureOperators: + pixelSize: #pixel size in arcseconds. This should be estimated from the longest baseline (PURIFY should print out a suggestion) + cellsizex: 1 + cellsizey: 1 + imageSize: # for the moment: power of two. For a simulation or when using a warm start, this must be the same image size as the ground truth + width: 128 + height: 128 + J: # Interpolation pixels for the anti-aliasing kernel (default 4x4). Input either a value or pair (J == 6 is the only option for pswf) + Jx: 4 + Jy: 4 + Jw: 30 #Maximum size of w kernel + kernel: kb # kernel, choose between: kb, Gauss, box, pswf + oversampling: 2 # value > 1. Value of 2 is the standard + gpu: False #This can be used when compiled with arrayfire gpu library + powermethod: + iters: 1000 # value > 0. This is the maximum number of iterations used with the power method for calculating the measurement operator norm. 
+ tolerance: 1e-5 # value > 0. This is the tolerance for convergence of the operator norm + eigenvector: #Here we can choose real and imaginary fits images of the operator eigenvector associated with the operator norm. This speeds up convergence for repeat runs. + real: "" + imag: "" + wide-field: + wprojection: False # using radially symmetric w projection kernel + mpi_wstacking: False # applies average w-stack correction on each node (always True with wprojection) + mpi_all_to_all: False # performs all to all operation of the grid to even out computation. Highly recommended when using MPI for wide-field imaging! + conjugate_w: True #reflects measurements onto the positive w-domain (can reduce computation) + kmeans_iterations: 100 #number of iterations in w-stacking clustering algorithm + +########## SARA ########## +SARA: + # wavelet bases, to choose from: Dirac, DB1, ..., DBn (where n=0 is dirac, and 1 <= n <= 35 is DB basis) + wavelet_dict: 0..8 # sequence of desired basis: 0,23,31 or 0,12..16,20 . We typically use 0..8, this gives good results but will be expensive for large images. + wavelet_levels: 4 # maximum number of wavelet levels for each dictionary. More levels can be good for increasing quality of large scale structures in an image (we recommend 4 to 6). + realValueConstraint: True # Boolean, we recommend this to be True for Stokes I + positiveValueConstraint: True # Boolean, we recommend this to be True for Stokes I + +AlgorithmOptions: + algorithm: fb # will just read the options of that one. We recommend the primal dual algorithm, especially for large image sizes. + fb: #solve the unconstrained problem + mpiAlgorithm: none # one of none, serial-equivalent + regularisation_parameter: 0.0001 + stepsize: 0.00056448 + relVarianceConvergence: 1e-2 # (>0) relative convergence of the objective function + dualFBVarianceConvergence: 1e-3 # (>0) relative convergence tolerance of l1 proximal + nonDifferentiableFunctionType: l1 + differentiableFunctionType: l2