diff --git a/.codecov.yml b/.codecov.yml new file mode 100644 index 0000000000..c4b44c360b --- /dev/null +++ b/.codecov.yml @@ -0,0 +1,29 @@ +comment: + layout: "reach, diff, files" + behavior: default + require_changes: false # if true: only post the comment if coverage changes + require_base: no # [yes :: must have a base report to post] + require_head: yes # [yes :: must have a head report to post] + branches: null + +ignore: + - "*/benchmarks/*" + - "setup.py" + - "*/setup.py" + - "*/tests/*" + - "*/fixes/*" + - "*/external/*" + +coverage: + status: + project: + default: + # Drops on the order 0.01% are typical even when no change occurs + # Having this threshold set a little higher (0.1%) than that makes it + # a little more tolerant to fluctuations + target: auto + threshold: 0.1% + patch: + default: + target: auto + threshold: 0.1% diff --git a/.gitignore b/.gitignore index 9474a416ca..6da02b865d 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,5 @@ __config__.py .buildbot.patch .eggs/ dipy/.idea/ -.idea/ +.idea +.vscode diff --git a/.mailmap b/.mailmap index 6e88a5e40a..fd2c5a3ec1 100644 --- a/.mailmap +++ b/.mailmap @@ -34,6 +34,7 @@ Shahnawaz Ahmed Your Name smerlet Mauro Zucchelli Mauro Mauro Zucchelli maurozucchelli +Mauro Zucchelli maurozucchelli Andrew Lawrence AndrewLawrence Samuel St-Jean samuelstjean Samuel St-Jean samuelstjean @@ -65,6 +66,7 @@ Alexandre Gauvin Alexandre Gauvin Nil Goyette Eric Peterson etpeterson Rutger Fick Rutger Fick +Rutger Fick Rutger Fick Demian Wassermann Demian Wassermann Sourav Singh Sourav Sven Dorkenwald @@ -75,3 +77,13 @@ Matthieu Dumont unknown Adam Rybinski Bennet Fauber +Aman Arya +Ricci Woo RicciWoo +Francois Rheault +David Hunt David +David Hunt davhunt +Parichit Sharma Parichit Sharma +Chandan Gangwar +Naveen Kumarmarri +Jacob Wasserthal +Shreyas Fadnavis diff --git a/.pep8speaks.yml b/.pep8speaks.yml new file mode 100644 index 0000000000..d4cd9bdead --- /dev/null +++ b/.pep8speaks.yml @@ -0,0 +1,24 @@ +# File : .pep8speaks.yml + +message: # Customize the comment made by the bot + opened: # Messages when a new PR is submitted + header: "Hello @{name}, Thank you for submitting the Pull Request !" + # The keyword {name} is converted into the author's username + footer: "Do see the [DIPY coding Style guideline](https://github.com/nipy/dipy/blob/master/doc/devel/coding_style_guideline.rst)" + # The messages can be written as they would over GitHub + updated: # Messages when new commits are added to the PR + header: "Hello @{name}, Thank you for updating !" + footer: "" # Why to comment the link to the style guide everytime? :) + no_errors: "Cheers ! There are no PEP8 issues in this Pull Request. :beers: " + +scanner: + diff_only: True # If True, errors caused by only the patch are shown + +pycodestyle: + max-line-length: 80 # Default is 79 in PEP8 + # ignore: # Errors and warnings to ignore + # - W391 + # - E203 + +only_mention_files_with_errors: True # If False, a separate status comment for each file is made. 
+descending_issues_order: False # If True, PEP8 issues in message will be displayed in descending order of line numbers in the file \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 554ed253eb..33b9ecc9b5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -27,15 +27,20 @@ python: - 3.4 - 3.5 - 3.6 + # - "3.7" # TODO: Re-enable after https://github.com/travis-ci/travis-ci/issues/9815 is fixed matrix: include: + # TODO: Disable the local workaround + - python: 3.7 + dist: xenial + sudo: true - python: 2.7 # To test minimum dependencies - python: 2.7 env: # Check these values against requirements.txt and dipy/info.py - - DEPENDS="cython==0.25.1 numpy==1.7.1 scipy==0.9.0 nibabel==2.1.0 h5py==2.4.0" + - DEPENDS="cython==0.25.1 numpy==1.7.1 scipy==0.9.0 nibabel==2.3.0 h5py==2.4.0" - python: 2.7 env: - DEPENDS="$DEPENDS scikit_learn" @@ -53,7 +58,7 @@ matrix: - LIBGL_ALWAYS_INDIRECT=y - VENV_ARGS="--system-site-packages --python=/usr/bin/python2.7" - TEST_WITH_XVFB=true - - DEPENDS="$DEPENDS scikit_learn" + - DEPENDS="$DEPENDS scikit_learn fury" - python: 2.7 env: @@ -74,6 +79,10 @@ matrix: # Check against latest available pre-release version of all packages env: - USE_PRE=1 + allow_failures: + - python: 3.5 + env: + - USE_PRE=1 before_install: - PIPI="pip install $EXTRA_PIP_FLAGS" @@ -125,6 +134,7 @@ script: - 'echo "backend : agg" > matplotlibrc' - if [ "${COVERAGE}" == "1" ]; then cp ../.coveragerc .; + cp ../.codecov.yml .; COVER_ARGS="--with-coverage --cover-package dipy"; fi - nosetests --with-doctest --verbose $COVER_ARGS dipy diff --git a/AUTHOR b/AUTHOR index 92832cd1f5..30d0dffcef 100644 --- a/AUTHOR +++ b/AUTHOR @@ -1,33 +1,92 @@ Eleftherios Garyfallidis -Ian Nimmo-Smith +Ariel Rokem Matthew Brett Bago Amirbekian -Stefan Van der Walt -Ariel Rokem -Christopher Nguyen -Yaroslav Halchenko -Emanuele Olivetti -Mauro Zucchelli -Samuel St-Jean -Maxime Descoteaux +Omar Ocegueda +Rafael Neto Henriques +Serge Koudoro +Samuel St-Jean Gabriel Girard +Marc-Alexandre Côté +Rutger Fick +Shahnawaz Ahmed +Ian Nimmo-Smith +Mauro Zucchelli Matthieu Dumont -Kimberly Chan -Erik Ziegler -Emmanuel Caruyer -Matthias Ekman +Stefan van der Walt +Kesshi Jordan +Ranveer Aggarwal +Maxime Descoteaux +Riddhish Bhalodia +Bramsh Qamar +Karandeep +Bishakh Ghosh +Christopher Nguyen +Stephan Meesters +Ricci Woo +Eric Peterson +Manu Tej Sharma +Sourav Singh +Julio Villalon Jean-Christophe Houde -Michael Paquette -Sylvain Merlet -Omar Ocegueda -Marc-Alexandre Cote +Jon Haitz Legarreta Gorroño +Kumar Ashutosh +Shreyas Fadnavis +David Reagan +Parichit Sharma +Guillaume Theaud +Aman Arya +Dimitris Rozakis +Gregory R. Lee +Saber Sheybani +ChantalTax +Nil Goyette +Rohan Prinja +Antonio Ossa Demian Wassermann +Michael Paquette +Tingyi Wanyan +Jiri Borovec +Yaroslav Halchenko +Conor Corbin +Kimberly Chan +ArjitJ <32598699+ArjitJ@users.noreply.github.com> +Enes Albay +Etienne St-Onge +Erik Ziegler +David Qixiang Chen +Francois Rheault +Emanuele Olivetti +David Hunt +Alexandre Gauvin +Pradeep Reddy Raamana +theaverageguy +Julio Villalon endolith +Matthias Ekman +Oscar Esteban +Emmanuel Caruyer +Tom Wright +Jon Haitz Legarreta Gorroño Andrew Lawrence -Gregory R. 
Lee +Naveen Kumarmarri +Chandan Gangwar +Pradeep Reddy Raamana +Bennet Fauber +Matt Cieslak +Sylvain Merlet +Gonzalo Sanguinetti +Vatsala Swaroop +Vibhatha Abeykoon +Adam Rybinski Maria Luisa Mandelli -Kesshi jordan -Chantal Tax -Qiyuan Tian -Shahnawaz Ahmed -Eric Peterson +Sven Dorkenwald +Qiyuan Tian +Chris Filo Gorgolewski +Bennet Fauber +Daniel Enrico Cahall +Jon Mendoza +Sagun Pai +Javier Guaje +Jacob Wasserthal +Himanshu Mishra \ No newline at end of file diff --git a/Changelog b/Changelog index ef35eb6711..50a3344d1d 100644 --- a/Changelog +++ b/Changelog @@ -24,6 +24,19 @@ Dipy The code found in Dipy was created by the people found in the AUTHOR file. +* 0.15 (Wednesday, 12 December 2018) + +- Updated RecoBundles for automatic anatomical bundle segmentation. +- New Reconstruction Model: qtau-dMRI. +- New command line interfaces (e.g. dipy_slr). +- New continuous integration with AppVeyor CI. +- Nibabel Streamlines API now used almost everywhere for better memory management. +- Compatibility with Python 3.7. +- Many tutorials added or updated (5 New). +- Large documentation update. +- Moved visualization module to a new library: FURY. +- Closed 287 issues and merged 93 pull requests. + * 0.14 (Tuesday, 1 May 2018) - RecoBundles: anatomically relevant segmentation of bundles diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index c59d436b70..79097ef65e 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -12,10 +12,10 @@ - [ ] Operating system and version (run `python -c "import platform; print(platform.platform())"`) - [ ] Python version (run `python -c "import sys; print("Python", sys.version)"`) - [ ] dipy version (run `python -c "import dipy; print(dipy.__version__)"`) -- [ ] dependency version (numpy, scipy, nibabel, h5py, cvxpy, vtk) +- [ ] dependency version (numpy, scipy, nibabel, h5py, cvxpy, fury) * import numpy; print("NumPy", numpy.__version__) * import scipy; print("SciPy", scipy.__version__) * import nibabel; print("Nibabel", nibabel.__version__) * import h5py; print("H5py", h5py.__version__) * import cvxpy; print("Cvxpy", cvxpy.__version__) - * import vtk; print(vtk.vtkVersion.GetVTKSourceVersion()) + * import fury; print("fury", fury.__version__) diff --git a/LICENSE b/LICENSE index b90260ef06..c67326cfdf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,7 +1,7 @@ Unless otherwise specified by LICENSE.txt files in individual directories, or within individual files or functions, all code is: -Copyright (c) 2008-2016, dipy developers +Copyright (c) 2008-2019, dipy developers All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.rst b/README.rst index 5dc21d012b..da04c1c7b4 100644 --- a/README.rst +++ b/README.rst @@ -23,7 +23,10 @@ .. image:: https://img.shields.io/badge/License-BSD%203--Clause-blue.svg :target: https://github.com/nipy/dipy/blob/master/LICENSE -DIPY is a python toolbox for analysis of MR diffusion imaging. +.. image:: https://ci.appveyor.com/api/projects/status/github/nipy/dipy?branch=master&svg=true + :target: https://ci.appveyor.com/project/nipy/dipy + +DIPY [DIPYREF]_ is a python library for analysis of MR diffusion imaging. DIPY is for research only; please do not use results from DIPY for clinical decisions. @@ -40,7 +43,7 @@ Please see the developers' list at https://mail.python.org/mailman/listinfo/neuroimaging Please see the users' forum at -https://neurostars.org +https://neurostars.org/tags/dipy Please join the gitter chatroom `here `_. 
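The dependency checklist in ISSUE_TEMPLATE.md above runs one `python -c` command per package; the same report can be collected in a single pass. A minimal sketch (it assumes only that each listed package exposes the conventional `__version__` attribute)::

    import importlib

    # Same package list as ISSUE_TEMPLATE.md; print each version or absence.
    for name in ("numpy", "scipy", "nibabel", "h5py", "cvxpy", "fury"):
        try:
            module = importlib.import_module(name)
            print(name, getattr(module, "__version__", "unknown"))
        except ImportError:
            print(name, "not installed")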
@@ -67,10 +70,10 @@ DIPY can be installed using `pip`:: or using `conda`:: - conda install -c conda-forge dipy vtk + conda install -c conda-forge dipy For detailed installation instructions, including instructions for installing -from source, please read our `documentation `_. +from source, please read our `installation documentation `_. License @@ -83,3 +86,11 @@ Contributing ============ We welcome contributions from the community. Please read our `Contributing guidelines `_. + +Reference +========= + +.. [DIPYREF] E. Garyfallidis, M. Brett, B. Amirbekian, A. Rokem, + S. Van Der Walt, M. Descoteaux, I. Nimmo-Smith and DIPY contributors, + "DIPY, a library for the analysis of diffusion MRI data", + Frontiers in Neuroinformatics, vol. 8, p. 8, Frontiers, 2014. diff --git a/appveyor.yml b/appveyor.yml new file mode 100644 index 0000000000..3bd1ff40c5 --- /dev/null +++ b/appveyor.yml @@ -0,0 +1,92 @@ +# vim ft=yaml +# CI on Windows via appveyor + +environment: + global: + # SDK v7.0 MSVC Express 2008's SetEnv.cmd script will fail if the + # /E:ON and /V:ON options are not enabled in the batch script interpreter + # See: http://stackoverflow.com/a/13751649/163740 + CMD_IN_ENV: "cmd /E:ON /V:ON /C .\\tools\\run_with_env.cmd" + DEPENDS: "cython numpy scipy matplotlib h5py" + INSTALL_TYPE: "requirements" + EXTRA_PIP_FLAGS: "--timeout=60" + + matrix: + - PYTHON: C:\Python27-x64 + - PYTHON: C:\Python35-x64 + - PYTHON: C:\Python36 + - PYTHON: C:\Python36-x64 + - PYTHON: C:\Python36-x64 + INSTALL_TYPE: "pip" + COVERAGE: 1 + +platform: + - x64 + +init: + - systeminfo + - ps: iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1')) + +install: + # If there is a newer build queued for the same PR, cancel this one. + # The AppVeyor 'rollout builds' option is supposed to serve the same + # purpose but is problematic because it tends to cancel builds pushed + # directly to master instead of just PR builds. + # credits: JuliaLang developers. + - ps: if ($env:APPVEYOR_PULL_REQUEST_NUMBER -and $env:APPVEYOR_BUILD_NUMBER -ne ((Invoke-RestMethod ` + https://ci.appveyor.com/api/projects/$env:APPVEYOR_ACCOUNT_NAME/$env:APPVEYOR_PROJECT_SLUG/history?recordsNumber=50).builds | ` + Where-Object pullRequestId -eq $env:APPVEYOR_PULL_REQUEST_NUMBER)[0].buildNumber) { ` + throw "There are newer queued builds for this pull request, failing early." } + + - "set PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" + - ps: $env:PIPI = "pip install $env:EXTRA_PIP_FLAGS" + - echo %PIPI% + # Check that we have the expected version and architecture for Python + - "python --version" + - ps: $env:PYTHON_ARCH = python -c "import struct; print(struct.calcsize('P') * 8)" + - ps: $env:PYTHON_VERSION = python -c "import platform;print(platform.python_version())" + - cmd: echo %PYTHON_VERSION% %PYTHON_ARCH% + + - ps: | + if($env:PYTHON -match "conda") + { + conda update -yq conda + Invoke-Expression "conda install -yq pip $env:DEPENDS" + pip install nibabel cvxpy scikit-learn + } + else + { + python -m pip install -U pip + pip --version + if($env:INSTALL_TYPE -match "requirements") + { + Invoke-Expression "$env:PIPI -r requirements.txt" + } + else + { + Invoke-Expression "$env:PIPI $env:DEPENDS" + } + Invoke-Expression "$env:PIPI nibabel matplotlib scikit-learn cvxpy" + } + - "%CMD_IN_ENV% python setup.py build_ext --inplace" + - "%CMD_IN_ENV% %PIPI% --user -e ." + +build: false # Not a C# project, build stuff at the test step instead. 
+ +test_script: + - pip install nose coverage coveralls codecov + - mkdir for_testing + - cd for_testing + - echo backend:Agg > matplotlibrc + - if exist ../.coveragerc (cp ../.coveragerc .) else (echo no .coveragerc) + - ps: | + if ($env:COVERAGE) + { + $env:COVER_ARGS = "--with-coverage --cover-package dipy" + } + - cmd: echo %COVER_ARGS% + - nosetests --with-doctest --verbose %COVER_ARGS% dipy + +cache: + # Avoid re-downloading large packages + - '%APPDATA%\pip\Cache' \ No newline at end of file diff --git a/bin/dipy_fit_mapmri b/bin/dipy_fit_mapmri old mode 100644 new mode 100755 diff --git a/bin/dipy_info b/bin/dipy_info index eba05b7668..f3e57a956e 100755 --- a/bin/dipy_info +++ b/bin/dipy_info @@ -2,8 +2,8 @@ from __future__ import division, print_function -from dipy.workflows.io import IoInfoFlow from dipy.workflows.flow_runner import run_flow +from dipy.workflows.io import IoInfoFlow if __name__ == "__main__": run_flow(IoInfoFlow()) diff --git a/bin/dipy_labelsbundles b/bin/dipy_labelsbundles new file mode 100755 index 0000000000..b616ad2687 --- /dev/null +++ b/bin/dipy_labelsbundles @@ -0,0 +1,9 @@ +#!python + +from __future__ import division, print_function + +from dipy.workflows.flow_runner import run_flow +from dipy.workflows.segment import LabelsBundlesFlow + +if __name__ == "__main__": + run_flow(LabelsBundlesFlow()) diff --git a/bin/dipy_nlmeans b/bin/dipy_nlmeans index fd1bc852a2..5423c131fe 100755 --- a/bin/dipy_nlmeans +++ b/bin/dipy_nlmeans @@ -2,8 +2,9 @@ from __future__ import division, print_function -from dipy.workflows.denoise import NLMeansFlow from dipy.workflows.flow_runner import run_flow +from dipy.workflows.denoise import NLMeansFlow + if __name__ == "__main__": run_flow(NLMeansFlow()) diff --git a/bin/dipy_recobundles b/bin/dipy_recobundles new file mode 100755 index 0000000000..18d3e1fc15 --- /dev/null +++ b/bin/dipy_recobundles @@ -0,0 +1,9 @@ +#!python + +from __future__ import division, print_function + +from dipy.workflows.flow_runner import run_flow +from dipy.workflows.segment import RecoBundlesFlow + +if __name__ == "__main__": + run_flow(RecoBundlesFlow()) diff --git a/bin/dipy_reslice b/bin/dipy_reslice index 4af16f4164..321ad63d89 100755 --- a/bin/dipy_reslice +++ b/bin/dipy_reslice @@ -2,8 +2,9 @@ from __future__ import division, print_function -from dipy.workflows.align import ResliceFlow from dipy.workflows.flow_runner import run_flow +from dipy.workflows.align import ResliceFlow + if __name__ == "__main__": run_flow(ResliceFlow()) \ No newline at end of file diff --git a/bin/dipy_slr b/bin/dipy_slr new file mode 100755 index 0000000000..c5b8135304 --- /dev/null +++ b/bin/dipy_slr @@ -0,0 +1,9 @@ +#!python + +from __future__ import division, print_function + +from dipy.workflows.flow_runner import run_flow +from dipy.workflows.align import SlrWithQbxFlow + +if __name__ == "__main__": + run_flow(SlrWithQbxFlow()) \ No newline at end of file diff --git a/bin/dipy_snr_in_cc b/bin/dipy_snr_in_cc new file mode 100755 index 0000000000..6109473803 --- /dev/null +++ b/bin/dipy_snr_in_cc @@ -0,0 +1,9 @@ +#!python + +from __future__ import division, print_function + +from dipy.workflows.flow_runner import run_flow +from dipy.workflows.stats import SNRinCCFlow + +if __name__ == "__main__": + run_flow(SNRinCCFlow()) diff --git a/dipy/align/imwarp.py b/dipy/align/imwarp.py index 67a1aef7a3..0b21e14096 100644 --- a/dipy/align/imwarp.py +++ b/dipy/align/imwarp.py @@ -1373,9 +1373,9 @@ def _get_energy_derivative(self): x = range(self.energy_window) y = 
self.energy_list[(n_iter - self.energy_window):n_iter]
         ss = sum(y)
-        if(ss > 0):
-            ss *= -1
-        y = [v / ss for v in y]
+        if ss != 0:  # avoid division by zero
+            ss = -ss if ss > 0 else ss
+            y = [v / ss for v in y]
         der = self._approximate_derivative_direct(x, y)
         return der
diff --git a/dipy/align/reslice.py b/dipy/align/reslice.py
index 4157efeb49..78ae35ffcc 100644
--- a/dipy/align/reslice.py
+++ b/dipy/align/reslice.py
@@ -50,8 +50,8 @@ def reslice(data, affine, zooms, new_zooms, order=1, mode='constant', cval=0,
     --------
     >>> import nibabel as nib
     >>> from dipy.align.reslice import reslice
-    >>> from dipy.data import get_data
-    >>> fimg = get_data('aniso_vox')
+    >>> from dipy.data import get_fnames
+    >>> fimg = get_fnames('aniso_vox')
     >>> img = nib.load(fimg)
     >>> data = img.get_data()
     >>> data.shape == (58, 58, 24)
diff --git a/dipy/align/streamlinear.py b/dipy/align/streamlinear.py
index 165d787dfa..4b11051108 100644
--- a/dipy/align/streamlinear.py
+++ b/dipy/align/streamlinear.py
@@ -10,17 +10,15 @@
                                       center_streamlines,
                                       set_number_of_points,
                                       select_random_set_of_streamlines,
-                                      length)
-from dipy.segment.clustering import QuickBundles
+                                      length,
+                                      Streamlines)
+from dipy.segment.clustering import qbx_and_merge
 from dipy.core.geometry import (compose_transformations,
                                 compose_matrix,
                                 decompose_matrix)
 from dipy.utils.six import string_types
 from time import time

-MAX_DIST = 1e10
-LOG_MAX_DIST = np.log(MAX_DIST)
-
 DEFAULT_BOUNDS = [(-35, 35), (-35, 35), (-35, 35),
                   (-45, 45), (-45, 45), (-45, 45),
                   (0.6, 1.4), (0.6, 1.4), (0.6, 1.4),
@@ -169,10 +167,20 @@
 class BundleMinDistanceAsymmetricMetric(BundleMinDistanceMetric):
     """ Asymmetric Bundle-based Minimum distance
+
+    This is a cost function that can be used by the
+    StreamlineLinearRegistration class.
+
     """

     def distance(self, xopt):
+        """ Distance calculated from this Metric
+
+        Parameters
+        ----------
+        xopt : sequence
+            List of affine parameters as a 1D vector
+        """
         return bundle_min_distance_asymmetric_fast(xopt,
                                                    self.static_centered_pts,
                                                    self.moving_centered_pts,
@@ -693,8 +701,15 @@
 def remove_clusters_by_size(clusters, min_size=0):
+
     by_size = lambda c: len(c) >= min_size
-    return filter(by_size, clusters)
+    ob = filter(by_size, clusters)
+
+    centroids = Streamlines()
+    for cluster in ob:
+        centroids.append(cluster.centroid)
+
+    return centroids

 def progressive_slr(static, moving, metric, x0, bounds,
@@ -829,17 +844,17 @@
     return slm


-def slr_with_qb(static, moving,
-                x0='affine',
-                rm_small_clusters=50,
-                maxiter=100,
-                select_random=None,
-                verbose=False,
-                greater_than=50,
-                less_than=250,
-                qb_thr=15,
-                nb_pts=20,
-                progressive=True, num_threads=None):
+def slr_with_qbx(static, moving,
+                 x0='affine',
+                 rm_small_clusters=50,
+                 maxiter=100,
+                 select_random=None,
+                 verbose=False,
+                 greater_than=50,
+                 less_than=250,
+                 qbx_thr=[40, 30, 20, 15],
+                 nb_pts=20,
+                 progressive=True, rng=None, num_threads=None):
     """ Utility function for registering large tractograms.

     For efficiency we apply the registration on cluster centroids and remove
@@ -849,21 +864,38 @@
     ----------
     static : Streamlines
     moving : Streamlines
+
     x0 : str
         rigid, similarity or affine transformation model (default affine)

     rm_small_clusters : int
         Remove clusters that have less than `rm_small_clusters` (default 50)

-    verbose : bool,
-        If True then information about the optimization is shown.
-    select_random : int
-        If not None select a random number of streamlines to apply
-        clustering
-        Default None.
-    options : None or dict,
-        Extra options to be used with the selected method.
+    verbose : bool,
+        If True then information about the optimization is shown.
+
+    greater_than : int, optional
+        Keep streamlines that have length greater than
+        this value (default 50)
+
+    less_than : int, optional
+        Keep streamlines that have length less than this value (default 250)
+
+    qbx_thr : variable int
+        Thresholds for QuickBundlesX (default [40, 30, 20, 15])
+
+    nb_pts : int, optional
+        Number of points for discretizing each streamline (default 20)
+
+    progressive : boolean, optional
+        If True, apply the registration progressively (default True)
+
+    rng : RandomState
+        If None, a new RandomState is created inside the function.

     num_threads : int
         Number of threads. If None (default) then all available threads
@@ -878,14 +910,17 @@
     References
     ----------
     .. [Garyfallidis15] Garyfallidis et al. "Robust and efficient linear
-       registration of white-matter fascicles in the space of streamlines"
-       , NeuroImage, 117, 124--140, 2015
+       registration of white-matter fascicles in the space of streamlines",
+       NeuroImage, 117, 124--140, 2015
     .. [Garyfallidis14] Garyfallidis et al., "Direct native-space fiber
        bundle alignment for group comparisons", ISMRM, 2014.
     .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
-       bundles using local and global streamline-based registration and
-       clustering, Neuroimage, 2017.
+       bundles using local and global streamline-based registration and
+       clustering, Neuroimage, 2017.
     """
+    if rng is None:
+        rng = np.random.RandomState()
+
     if verbose:
         print('Static streamlines size {}'.format(len(static)))
         print('Moving streamlines size {}'.format(len(moving)))
@@ -897,9 +932,9 @@ def check_range(streamline, gt=greater_than, lt=less_than):
         else:
             return False

-    # TODO change this to the new Streamlines API
-    streamlines1 = [s for s in static if check_range(s)]
-    streamlines2 = [s for s in moving if check_range(s)]
+
+    streamlines1 = Streamlines(static[np.array([check_range(s) for s in static])])
+    streamlines2 = Streamlines(moving[np.array([check_range(s) for s in moving])])

     if verbose:
@@ -910,29 +945,31 @@
     if select_random is not None:
         rstreamlines1 = select_random_set_of_streamlines(streamlines1,
-                                                         select_random)
+                                                         select_random,
+                                                         rng=rng)
     else:
         rstreamlines1 = streamlines1

     rstreamlines1 = set_number_of_points(rstreamlines1, nb_pts)
-    qb1 = QuickBundles(threshold=qb_thr)
-    rstreamlines1 = [s.astype('f4') for s in rstreamlines1]
-    cluster_map1 = qb1.cluster(rstreamlines1)
-    clusters1 = remove_clusters_by_size(cluster_map1, rm_small_clusters)
-    qb_centroids1 = [cluster.centroid for cluster in clusters1]
+
+    rstreamlines1._data = rstreamlines1._data.astype('f4')
+
+    cluster_map1 = qbx_and_merge(rstreamlines1, thresholds=qbx_thr, rng=rng)
+    qb_centroids1 = remove_clusters_by_size(cluster_map1, rm_small_clusters)

     if select_random is not None:
         rstreamlines2 = select_random_set_of_streamlines(streamlines2,
-                                                         select_random)
+                                                         select_random,
+                                                         rng=rng)
     else:
         rstreamlines2 = streamlines2

     rstreamlines2 = set_number_of_points(rstreamlines2, nb_pts)
-    qb2 = QuickBundles(threshold=qb_thr)
-    rstreamlines2 = [s.astype('f4') for s in rstreamlines2]
-    cluster_map2 = qb2.cluster(rstreamlines2)
-    clusters2 = remove_clusters_by_size(cluster_map2, rm_small_clusters)
-    qb_centroids2 = [cluster.centroid for cluster in clusters2]
+    rstreamlines2._data = rstreamlines2._data.astype('f4')
+
+    cluster_map2 =
qbx_and_merge(rstreamlines2, thresholds=qbx_thr, rng=rng) + + qb_centroids2 = remove_clusters_by_size(cluster_map2, rm_small_clusters) if verbose: t = time() @@ -967,7 +1004,7 @@ def check_range(streamline, gt=greater_than, lt=less_than): # Garyfallidis et al. Recognition of white matter # bundles using local and global streamline-based registration and # clustering, Neuroimage, 2017. -whole_brain_slr = slr_with_qb +whole_brain_slr = slr_with_qbx def _threshold(x, th): @@ -1004,6 +1041,7 @@ def compose_matrix44(t, dtype=np.double): if size not in [3, 6, 7, 9, 12]: raise ValueError('Accepted number of parameters is 3, 6, 7, 9 and 12') + MAX_DIST = 1e10 scale, shear, angles, translate = (None, ) * 4 translate = _threshold(t[0:3], MAX_DIST) if size in [6, 7, 9, 12]: diff --git a/dipy/align/tests/test_imaffine.py b/dipy/align/tests/test_imaffine.py index 57305e46a6..297ea4a7eb 100644 --- a/dipy/align/tests/test_imaffine.py +++ b/dipy/align/tests/test_imaffine.py @@ -186,12 +186,10 @@ def test_affreg_all_transforms(): # Test affine registration using all transforms with typical settings # Make sure dictionary entries are processed in the same order regardless - # of the platform. - # Otherwise any random numbers drawn within the loop would make - # the test non-deterministic even if we fix the seed before the loop. - # Right now, this test does not draw any samples, - # but we still sort the entries - # to prevent future related failures. + # of the platform. Otherwise any random numbers drawn within the loop would + # make the test non-deterministic even if we fix the seed before the loop. + # Right now, this test does not draw any samples, but we still sort the + # entries to prevent future related failures. for ttype in sorted(factors): dim = ttype[1] if dim == 2: @@ -200,9 +198,14 @@ def test_affreg_all_transforms(): nslices = 45 factor = factors[ttype][0] sampling_pc = factors[ttype][1] - transform = regtransforms[ttype] - static, moving, static_grid2world, moving_grid2world, smask, mmask, T = \ - setup_random_transform(transform, factor, nslices, 1.0) + trans = regtransforms[ttype] + # Shorthand: + srt = setup_random_transform + static, moving, static_g2w, moving_g2w, smask, mmask, T = srt( + trans, + factor, + nslices, + 1.0) # Sum of absolute differences start_sad = np.abs(static - moving).sum() metric = imaffine.MutualInformationMetric(32, sampling_pc) @@ -213,9 +216,9 @@ def test_affreg_all_transforms(): 'L-BFGS-B', None, options=None) - x0 = transform.get_identity_parameters() - affine_map = affreg.optimize(static, moving, transform, x0, - static_grid2world, moving_grid2world) + x0 = trans.get_identity_parameters() + affine_map = affreg.optimize(static, moving, trans, x0, + static_g2w, moving_g2w) transformed = affine_map.transform(moving) # Sum of absolute differences end_sad = np.abs(static - transformed).sum() @@ -470,7 +473,7 @@ def test_affine_map(): # compatibility with previous versions assert_array_equal(affine, affine_map.affine) # new getter - new_copy_affine = affine_map.get_affine() + new_copy_affine = affine_map.affine # value must be the same assert_array_equal(affine, new_copy_affine) # but not its reference @@ -512,12 +515,12 @@ def test_affine_map(): aff_map = AffineMap(affine_mat) if affine_mat is None: continue - bad_aug = aff_map.get_affine() + bad_aug = aff_map.affine # no zeros in the first n-1 columns on last row bad_aug[-1,:] = 1 assert_raises(AffineInvalidValuesError, AffineMap, bad_aug) - bad_aug = aff_map.get_affine() + bad_aug = aff_map.affine bad_aug[-1, 
-1] = 0 # lower right not 1 assert_raises(AffineInvalidValuesError, AffineMap, bad_aug) diff --git a/dipy/align/tests/test_imwarp.py b/dipy/align/tests/test_imwarp.py index 60b17e117c..10c15877f9 100644 --- a/dipy/align/tests/test_imwarp.py +++ b/dipy/align/tests/test_imwarp.py @@ -5,7 +5,7 @@ assert_array_equal, assert_array_almost_equal, assert_raises) -from dipy.data import get_data +from dipy.data import get_fnames from dipy.align import floating from dipy.align import imwarp as imwarp from dipy.align import metrics as metrics @@ -375,8 +375,8 @@ def test_ssd_2d_demons(): Classical Circle-To-C experiment for 2D monomodal registration. We verify that the final registration is of good quality. ''' - fname_moving = get_data('reg_o') - fname_static = get_data('reg_c') + fname_moving = get_fnames('reg_o') + fname_static = get_fnames('reg_c') moving = np.load(fname_moving) static = np.load(fname_static) @@ -444,8 +444,8 @@ def test_ssd_2d_gauss_newton(): Classical Circle-To-C experiment for 2D monomodal registration. We verify that the final registration is of good quality. ''' - fname_moving = get_data('reg_o') - fname_static = get_data('reg_c') + fname_moving = get_fnames('reg_o') + fname_static = get_fnames('reg_c') moving = np.load(fname_moving) static = np.load(fname_static) @@ -563,7 +563,7 @@ def get_warped_stacked_image(image, nslices, b, m): def get_synthetic_warped_circle(nslices): # get a subsampled circle - fname_cicle = get_data('reg_o') + fname_cicle = get_fnames('reg_o') circle = np.load(fname_cicle)[::4, ::4].astype(floating) # create a synthetic invertible map and warp the circle @@ -695,7 +695,7 @@ def test_cc_2d(): it under a synthetic invertible map. We verify that the final registration is of good quality. ''' - fname = get_data('t1_coronal_slice') + fname = get_fnames('t1_coronal_slice') nslices = 1 b = 0.1 m = 4 @@ -732,7 +732,7 @@ def test_cc_3d(): invertible map. We verify that the final registration is of good quality. ''' - fname = get_data('t1_coronal_slice') + fname = get_fnames('t1_coronal_slice') nslices = 21 b = 0.1 m = 4 @@ -782,7 +782,7 @@ def test_em_3d_gauss_newton(): invertible map. We verify that the final registration is of good quality. ''' - fname = get_data('t1_coronal_slice') + fname = get_fnames('t1_coronal_slice') nslices = 21 b = 0.1 m = 4 @@ -835,7 +835,7 @@ def test_em_2d_gauss_newton(): registration is of good quality. ''' - fname = get_data('t1_coronal_slice') + fname = get_fnames('t1_coronal_slice') nslices = 1 b = 0.1 m = 4 @@ -876,7 +876,7 @@ def test_em_3d_demons(): invertible map. We verify that the final registration is of good quality. ''' - fname = get_data('t1_coronal_slice') + fname = get_fnames('t1_coronal_slice') nslices = 21 b = 0.1 m = 4 @@ -928,7 +928,7 @@ def test_em_2d_demons(): it under a synthetic invertible map. We verify that the final registration is of good quality. 
''' - fname = get_data('t1_coronal_slice') + fname = get_fnames('t1_coronal_slice') nslices = 1 b = 0.1 m = 4 diff --git a/dipy/align/tests/test_parzenhist.py b/dipy/align/tests/test_parzenhist.py index fc5268d406..b08299523c 100644 --- a/dipy/align/tests/test_parzenhist.py +++ b/dipy/align/tests/test_parzenhist.py @@ -3,7 +3,7 @@ from functools import reduce from operator import mul from dipy.core.ndindex import ndindex -from dipy.data import get_data +from dipy.data import get_fnames from dipy.align import vector_fields as vf from dipy.align.transforms import regtransforms from dipy.align.parzenhist import (ParzenJointHistogram, @@ -279,7 +279,7 @@ def setup_random_transform(transform, rfactor, nslices=45, sigma=1): np.random.seed(3147702) zero_slices = nslices // 3 - fname = get_data('t1_coronal_slice') + fname = get_fnames('t1_coronal_slice') moving_slice = np.load(fname) moving_slice = moving_slice[40:180, 50:210] diff --git a/dipy/align/tests/test_reslice.py b/dipy/align/tests/test_reslice.py index 047542a530..d2374497a6 100644 --- a/dipy/align/tests/test_reslice.py +++ b/dipy/align/tests/test_reslice.py @@ -4,13 +4,13 @@ assert_, assert_equal, assert_almost_equal) -from dipy.data import get_data +from dipy.data import get_fnames from dipy.align.reslice import reslice from dipy.denoise.noise_estimate import estimate_sigma def test_resample(): - fimg, _, _ = get_data("small_25") + fimg, _, _ = get_fnames("small_25") img = nib.load(fimg) data = img.get_data() affine = img.affine diff --git a/dipy/align/tests/test_streamlinear.py b/dipy/align/tests/test_streamlinear.py index 2991b71dcf..c50cbf4df1 100644 --- a/dipy/align/tests/test_streamlinear.py +++ b/dipy/align/tests/test_streamlinear.py @@ -22,7 +22,7 @@ from dipy.core.geometry import compose_matrix -from dipy.data import get_data, two_cingulum_bundles +from dipy.data import get_fnames, two_cingulum_bundles from nibabel import trackvis as tv from dipy.align.bundlemin import (_bundle_minimum_distance_matrix, _bundle_minimum_distance, @@ -45,7 +45,7 @@ def simulated_bundle(no_streamlines=10, waves=False, no_pts=12): def fornix_streamlines(no_pts=12): - fname = get_data('fornix') + fname = get_fnames('fornix') streams, hdr = tv.read(fname) streamlines = [set_number_of_points(i[0], no_pts) for i in streams] return streamlines diff --git a/dipy/align/tests/test_whole_brain_slr.py b/dipy/align/tests/test_whole_brain_slr.py index 56f9ecb324..89fa147f19 100644 --- a/dipy/align/tests/test_whole_brain_slr.py +++ b/dipy/align/tests/test_whole_brain_slr.py @@ -2,16 +2,16 @@ import nibabel as nib from numpy.testing import (assert_equal, run_module_suite, assert_array_almost_equal) -from dipy.data import get_data +from dipy.data import get_fnames from dipy.tracking.streamline import Streamlines -from dipy.align.streamlinear import whole_brain_slr, slr_with_qb +from dipy.align.streamlinear import whole_brain_slr, slr_with_qbx from dipy.tracking.distances import bundles_distances_mam from dipy.align.streamlinear import transform_streamlines from dipy.align.streamlinear import compose_matrix44, decompose_matrix44 def test_whole_brain_slr(): - streams, hdr = nib.trackvis.read(get_data('fornix')) + streams, hdr = nib.trackvis.read(get_fnames('fornix')) fornix = [s[0] for s in streams] f = Streamlines(fornix) @@ -22,8 +22,9 @@ def test_whole_brain_slr(): f2._data += np.array([50, 0, 0]) moved, transform, qb_centroids1, qb_centroids2 = whole_brain_slr( - f1, f2, verbose=True, rm_small_clusters=2, greater_than=0, - less_than=np.inf, qb_thr=5, 
progressive=False) + f1, f2, x0='affine', verbose=True, rm_small_clusters=2, + greater_than=0, less_than=np.inf, + qbx_thr=[5, 2, 1], progressive=False) # we can check the quality of registration by comparing the matrices # MAM streamline distances before and after SLR @@ -33,31 +34,37 @@ def test_whole_brain_slr(): d12_minsum = np.sum(np.min(D12, axis=0)) d1m_minsum = np.sum(np.min(D1M, axis=0)) + print("distances= ", d12_minsum, " ", d1m_minsum) + assert_equal(d1m_minsum < d12_minsum, True) - assert_array_almost_equal(transform[:3, 3], [-50, -0, -0], 3) + assert_array_almost_equal(transform[:3, 3], [-50, -0, -0], 2) # check rotation + mat = compose_matrix44([0, 0, 0, 15, 0, 0]) f3 = f.copy() f3 = transform_streamlines(f3, mat) - moved, transform, qb_centroids1, qb_centroids2 = slr_with_qb( + moved, transform, qb_centroids1, qb_centroids2 = slr_with_qbx( f1, f3, verbose=False, rm_small_clusters=1, greater_than=20, - less_than=np.inf, qb_thr=2, progressive=True) + less_than=np.inf, qbx_thr=[2], + progressive=True) # we can also check the quality by looking at the decomposed transform + assert_array_almost_equal(decompose_matrix44(transform)[3], -15, 2) - moved, transform, qb_centroids1, qb_centroids2 = slr_with_qb( + moved, transform, qb_centroids1, qb_centroids2 = slr_with_qbx( f1, f3, verbose=False, rm_small_clusters=1, select_random=400, - greater_than=20, - less_than=np.inf, qb_thr=2, progressive=True) + greater_than=20, less_than=np.inf, qbx_thr=[2], + progressive=True) # we can also check the quality by looking at the decomposed transform - assert_array_almost_equal(decompose_matrix44(transform)[3], -15, 2) + assert_array_almost_equal(decompose_matrix44(transform)[3], -15, 2) if __name__ == '__main__': - run_module_suite() + # run_module_suite() + test_whole_brain_slr() diff --git a/dipy/core/gradients.py b/dipy/core/gradients.py index 417dedf650..92072c7d0d 100644 --- a/dipy/core/gradients.py +++ b/dipy/core/gradients.py @@ -1,4 +1,5 @@ from __future__ import division, print_function, absolute_import +from warnings import warn from dipy.utils.six import string_types @@ -14,6 +15,8 @@ from dipy.core.geometry import vector_norm from dipy.core.sphere import disperse_charges, HemiSphere +WATER_GYROMAGNETIC_RATIO = 267.513e6 # 1/(sT) + class GradientTable(object): """Diffusion gradient information @@ -57,7 +60,7 @@ class GradientTable(object): """ def __init__(self, gradients, big_delta=None, small_delta=None, - b0_threshold=0): + b0_threshold=50): """Constructor for GradientTable class""" gradients = np.asarray(gradients) if gradients.ndim != 2 or gradients.shape[1] != 3: @@ -73,11 +76,23 @@ def __init__(self, gradients, big_delta=None, small_delta=None, def bvals(self): return vector_norm(self.gradients) + @auto_attr + def tau(self): + return self.big_delta - self.small_delta / 3.0 + @auto_attr def qvals(self): tau = self.big_delta - self.small_delta / 3.0 return np.sqrt(self.bvals / tau) / (2 * np.pi) + @auto_attr + def gradient_strength(self): + tau = self.big_delta - self.small_delta / 3.0 + qvals = np.sqrt(self.bvals / tau) / (2 * np.pi) + gradient_strength = (qvals * (2 * np.pi) / + (self.small_delta * WATER_GYROMAGNETIC_RATIO)) + return gradient_strength + @auto_attr def b0s_mask(self): return self.bvals <= self.b0_threshold @@ -100,7 +115,7 @@ def info(self): print(' max %f ' % self.bvecs.max()) -def gradient_table_from_bvals_bvecs(bvals, bvecs, b0_threshold=0, atol=1e-2, +def gradient_table_from_bvals_bvecs(bvals, bvecs, b0_threshold=50, atol=1e-2, **kwargs): """Creates a 
GradientTable from a bvals array and a bvecs array

@@ -141,6 +156,19 @@
         raise ValueError("bvals and bvecs should be (N,) and (N, 3) arrays "
                          "respectively, where N is the number of diffusion "
                          "gradients")
+    # a negative b0_threshold could never be met by any (non-negative) bval
+    if b0_threshold < 0:
+        raise ValueError("Negative bvals in the data are not feasible")
+
+    # Upper bound for the b0_threshold
+    if b0_threshold >= 200:
+        warn("b0_threshold has a value > 199")
+
+    # warn if no b-value falls at or below the b0_threshold
+    if b0_threshold < bvals.min():
+        warn("b0_threshold (value: {0}) is too low, increase your "
+             "b0_threshold. It should be higher than the first b0 value "
+             "({1}).".format(b0_threshold, bvals.min()))

     bvecs = np.where(np.isnan(bvecs), 0, bvecs)
     bvecs_close_to_1 = abs(vector_norm(bvecs) - 1) <= atol
@@ -162,8 +190,163 @@
     return grad_table


+def gradient_table_from_qvals_bvecs(qvals, bvecs, big_delta, small_delta,
+                                    b0_threshold=50, atol=1e-2):
+    """A general function for creating diffusion MR gradients.
+
+    It reads, loads and prepares scanner parameters like the b-values and
+    b-vectors so that they can be useful during the reconstruction process.
+
+    Parameters
+    ----------
+
+    qvals : an array of shape (N,),
+        q-value given in 1/mm
+
+    bvecs : can be any of two options
+
+        1. an array of shape (N, 3) or (3, N) with the b-vectors.
+        2. a path for the file which contains an array like the previous.
+
+    big_delta : float or array of shape (N,)
+        acquisition pulse separation time in seconds
+
+    small_delta : float
+        acquisition pulse duration time in seconds
+
+    b0_threshold : float
+        All b-values with values less than or equal to `b0_threshold` are
+        considered as b0s i.e. without diffusion weighting.
+
+    atol : float
+        All b-vectors need to be unit vectors up to a tolerance.
+
+    Returns
+    -------
+    gradients : GradientTable
+        A GradientTable with all the gradient information.
+
+    Examples
+    --------
+    >>> from dipy.core.gradients import gradient_table_from_qvals_bvecs
+    >>> qvals = 30. * np.ones(7)
+    >>> big_delta = .03  # pulse separation of 30ms
+    >>> small_delta = 0.01  # pulse duration of 10ms
+    >>> qvals[0] = 0
+    >>> sq2 = np.sqrt(2) / 2
+    >>> bvecs = np.array([[0, 0, 0],
+    ...                   [1, 0, 0],
+    ...                   [0, 1, 0],
+    ...                   [0, 0, 1],
+    ...                   [sq2, sq2, 0],
+    ...                   [sq2, 0, sq2],
+    ...                   [0, sq2, sq2]])
+    >>> gt = gradient_table_from_qvals_bvecs(qvals, bvecs,
+    ...                                      big_delta, small_delta)
+
+    Notes
+    -----
+    1. Often b0s (b-values which correspond to images without diffusion
+       weighting) have 0 values however in some cases the scanner cannot
+       provide b0s of an exact 0 value and it gives a bit higher values
+       e.g. 6 or 12. This is the purpose of the b0_threshold in the __init__.
+    2. We assume that the minimum number of b-values is 7.
+    3. B-vectors should be unit vectors.
+
+    """
+    qvals = np.asarray(qvals)
+    bvecs = np.asarray(bvecs)
+
+    if (bvecs.shape[1] > bvecs.shape[0]) and bvecs.shape[0] > 1:
+        bvecs = bvecs.T
+    bvals = (qvals * 2 * np.pi) ** 2 * (big_delta - small_delta / 3.)
+    return gradient_table_from_bvals_bvecs(bvals, bvecs, big_delta=big_delta,
+                                           small_delta=small_delta,
+                                           b0_threshold=b0_threshold,
+                                           atol=atol)
+
+
+def gradient_table_from_gradient_strength_bvecs(gradient_strength, bvecs,
+                                                big_delta, small_delta,
+                                                b0_threshold=50, atol=1e-2):
+    """A general function for creating diffusion MR gradients.
+ + It reads, loads and prepares scanner parameters like the b-values and + b-vectors so that they can be useful during the reconstruction process. + + Parameters + ---------- + + gradient_strength : an array of shape (N,), + gradient strength given in T/mm + + bvecs : can be any of two options + + 1. an array of shape (N, 3) or (3, N) with the b-vectors. + 2. a path for the file which contains an array like the previous. + + big_delta : float or array of shape (N,) + acquisition pulse separation time in seconds + + small_delta : float + acquisition pulse duration time in seconds + + b0_threshold : float + All b-values with values less than or equal to `bo_threshold` are + considered as b0s i.e. without diffusion weighting. + + atol : float + All b-vectors need to be unit vectors up to a tolerance. + + Returns + ------- + gradients : GradientTable + A GradientTable with all the gradient information. + + Examples + -------- + >>> from dipy.core.gradients import ( + ... gradient_table_from_gradient_strength_bvecs) + >>> gradient_strength = .03e-3 * np.ones(7) # clinical strength at 30 mT/m + >>> big_delta = .03 # pulse separation of 30ms + >>> small_delta = 0.01 # pulse duration of 10ms + >>> gradient_strength[0] = 0 + >>> sq2 = np.sqrt(2) / 2 + >>> bvecs = np.array([[0, 0, 0], + ... [1, 0, 0], + ... [0, 1, 0], + ... [0, 0, 1], + ... [sq2, sq2, 0], + ... [sq2, 0, sq2], + ... [0, sq2, sq2]]) + >>> gt = gradient_table_from_gradient_strength_bvecs( + ... gradient_strength, bvecs, big_delta, small_delta) + + Notes + ----- + 1. Often b0s (b-values which correspond to images without diffusion + weighting) have 0 values however in some cases the scanner cannot + provide b0s of an exact 0 value and it gives a bit higher values + e.g. 6 or 12. This is the purpose of the b0_threshold in the __init__. + 2. We assume that the minimum number of b-values is 7. + 3. B-vectors should be unit vectors. + + """ + gradient_strength = np.asarray(gradient_strength) + bvecs = np.asarray(bvecs) + if (bvecs.shape[1] > bvecs.shape[0]) and bvecs.shape[0] > 1: + bvecs = bvecs.T + qvals = gradient_strength * small_delta * WATER_GYROMAGNETIC_RATIO /\ + (2 * np.pi) + bvals = (qvals * 2 * np.pi) ** 2 * (big_delta - small_delta / 3.) + return gradient_table_from_bvals_bvecs(bvals, bvecs, big_delta=big_delta, + small_delta=small_delta, + b0_threshold=b0_threshold, + atol=atol) + + def gradient_table(bvals, bvecs=None, big_delta=None, small_delta=None, - b0_threshold=0, atol=1e-2): + b0_threshold=50, atol=1e-2): """A general function for creating diffusion MR gradients. It reads, loads and prepares scanner parameters like the b-values and @@ -187,10 +370,10 @@ def gradient_table(bvals, bvecs=None, big_delta=None, small_delta=None, 2. a path for the file which contains an array like the previous. big_delta : float - acquisition timing duration (default None) + acquisition pulse separation time in seconds (default None) small_delta : float - acquisition timing duration (default None) + acquisition pulse duration time in seconds (default None) b0_threshold : float All b-values with values less than or equal to `bo_threshold` are @@ -207,16 +390,16 @@ def gradient_table(bvals, bvecs=None, big_delta=None, small_delta=None, Examples -------- >>> from dipy.core.gradients import gradient_table - >>> bvals=1500*np.ones(7) - >>> bvals[0]=0 - >>> sq2=np.sqrt(2)/2 - >>> bvecs=np.array([[0, 0, 0], - ... [1, 0, 0], - ... [0, 1, 0], - ... [0, 0, 1], - ... [sq2, sq2, 0], - ... [sq2, 0, sq2], - ... 
[0, sq2, sq2]]) + >>> bvals = 1500 * np.ones(7) + >>> bvals[0] = 0 + >>> sq2 = np.sqrt(2) / 2 + >>> bvecs = np.array([[0, 0, 0], + ... [1, 0, 0], + ... [0, 1, 0], + ... [0, 0, 1], + ... [sq2, sq2, 0], + ... [sq2, 0, sq2], + ... [0, sq2, sq2]]) >>> gt = gradient_table(bvals, bvecs) >>> gt.bvecs.shape == bvecs.shape True @@ -243,6 +426,7 @@ def gradient_table(bvals, bvecs=None, big_delta=None, small_delta=None, _, bvecs = io.read_bvals_bvecs(None, bvecs) bvals = np.asarray(bvals) + # If bvecs is None we expect bvals to be an (N, 4) or (4, N) array. if bvecs is None: if bvals.shape[-1] == 4: diff --git a/dipy/core/tests/test_gradients.py b/dipy/core/tests/test_gradients.py index c0edde361b..4838b45c88 100644 --- a/dipy/core/tests/test_gradients.py +++ b/dipy/core/tests/test_gradients.py @@ -1,12 +1,15 @@ import warnings -from nose.tools import assert_true, assert_raises +from nose.tools import assert_raises import numpy as np import numpy.testing as npt -from dipy.data import get_data +from dipy.data import get_fnames from dipy.core.gradients import (gradient_table, GradientTable, gradient_table_from_bvals_bvecs, + gradient_table_from_qvals_bvecs, + gradient_table_from_gradient_strength_bvecs, + WATER_GYROMAGNETIC_RATIO, reorient_bvecs, generate_bvecs, check_multi_b) from dipy.io.gradients import read_bvals_bvecs @@ -27,9 +30,8 @@ def test_btable_prepare(): bt = gradient_table(bvals, bvecs) npt.assert_array_equal(bt.bvecs, bvecs) # bt.info - fimg, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + fimg, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) bvecs = np.where(np.isnan(bvecs), 0, bvecs) bt = gradient_table(bvals, bvecs) npt.assert_array_equal(bt.bvecs, bvecs) @@ -69,9 +71,62 @@ def test_GradientTable(): npt.assert_array_equal(gt.bvals, expected_bvals) npt.assert_array_equal(gt.bvecs, expected_bvecs) + # checks negative values in gtab + npt.assert_raises(ValueError, GradientTable, -1) npt.assert_raises(ValueError, GradientTable, np.ones((6, 2))) npt.assert_raises(ValueError, GradientTable, np.ones((6,))) + with warnings.catch_warnings(record=True) as w: + bad_gt = gradient_table(expected_bvals, expected_bvecs, + b0_threshold=200) + assert len(w) == 1 + + +def test_gradient_table_from_qvals_bvecs(): + qvals = 30. * np.ones(7) + big_delta = .03 # pulse separation of 30ms + small_delta = 0.01 # pulse duration of 10ms + qvals[0] = 0 + sq2 = np.sqrt(2) / 2 + bvecs = np.array([[0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [sq2, sq2, 0], + [sq2, 0, sq2], + [0, sq2, sq2]]) + gt = gradient_table_from_qvals_bvecs(qvals, bvecs, + big_delta, small_delta) + + bvals_expected = (qvals * 2 * np.pi) ** 2 * (big_delta - small_delta / 3.) 
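+    # Illustrative cross-check: with the effective diffusion time
+    # tau = big_delta - small_delta / 3 (exposed as GradientTable.tau above),
+    # the expected b-value follows b = (2 * pi * q)**2 * tau.
+    tau = big_delta - small_delta / 3.
+    npt.assert_almost_equal(bvals_expected, (qvals * 2 * np.pi) ** 2 * tau)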
+ gradient_strength_expected = qvals * 2 * np.pi /\ + (small_delta * WATER_GYROMAGNETIC_RATIO) + npt.assert_almost_equal(gt.gradient_strength, gradient_strength_expected) + npt.assert_almost_equal(gt.bvals, bvals_expected) + + +def test_gradient_table_from_gradient_strength_bvecs(): + gradient_strength = .03e-3 * np.ones(7) # clinical strength at 30 mT/m + big_delta = .03 # pulse separation of 30ms + small_delta = 0.01 # pulse duration of 10ms + gradient_strength[0] = 0 + sq2 = np.sqrt(2) / 2 + bvecs = np.array([[0, 0, 0], + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [sq2, sq2, 0], + [sq2, 0, sq2], + [0, sq2, sq2]]) + gt = gradient_table_from_gradient_strength_bvecs(gradient_strength, bvecs, + big_delta, small_delta) + qvals_expected = (gradient_strength * WATER_GYROMAGNETIC_RATIO * + small_delta / (2 * np.pi)) + bvals_expected = (qvals_expected * 2 * np.pi) ** 2 *\ + (big_delta - small_delta / 3.) + npt.assert_almost_equal(gt.qvals, qvals_expected) + npt.assert_almost_equal(gt.bvals, bvals_expected) + def test_gradient_table_from_bvals_bvecs(): @@ -104,6 +159,10 @@ def test_gradient_table_from_bvals_bvecs(): bvecs, b0_threshold=0.) # num_gard inconsistent bvals, bvecs bad_bvals = np.ones(7) + npt.assert_raises(ValueError, gradient_table_from_bvals_bvecs, bad_bvals, + bvecs, b0_threshold=0.) + # negative bvals + bad_bvals = [-1, -1, -1, -5, -6, -10] npt.assert_raises(ValueError, gradient_table_from_bvals_bvecs, bad_bvals, bvecs, b0_threshold=0.) # bvals not 1d @@ -150,7 +209,7 @@ def test_b0s(): def test_gtable_from_files(): - fimg, fbvals, fbvecs = get_data('small_101D') + fimg, fbvals, fbvecs = get_fnames('small_101D') gt = gradient_table(fbvals, fbvecs) bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) npt.assert_array_equal(gt.bvals, bvals) @@ -259,7 +318,7 @@ def test_nan_bvecs(): indicate a 0 b-value, but also raised a warning when testing for the length of these vectors. This checks that it doesn't happen. 
""" - fdata, fbvals, fbvecs = get_data() + fdata, fbvals, fbvecs = get_fnames() with warnings.catch_warnings(record=True) as w: gradient_table(fbvals, fbvecs) npt.assert_(len(w) == 0) diff --git a/dipy/data/__init__.py b/dipy/data/__init__.py index 723ba2a2ce..e801f011db 100644 --- a/dipy/data/__init__.py +++ b/dipy/data/__init__.py @@ -32,8 +32,6 @@ read_stanford_t1, fetch_stanford_pve_maps, read_stanford_pve_maps, - fetch_viz_icons, - read_viz_icons, fetch_bundles_2_subjects, read_bundles_2_subjects, fetch_cenir_multib, @@ -46,7 +44,11 @@ read_tissue_data, fetch_cfin_multib, read_cfin_dwi, - read_cfin_t1) + read_cfin_t1, + fetch_target_tractogram_hcp, + fetch_bundle_atlas_hcp842, + get_bundle_atlas_hcp842, + get_target_tractogram_hcp) from ..utils.arrfuncs import as_native_array from dipy.tracking.streamline import relist_streamlines @@ -200,7 +202,7 @@ def get_sphere(name='symmetric362'): small_sphere = HemiSphere.from_sphere(get_sphere('symmetric362')) -def get_data(name='small_64D'): +def get_fnames(name='small_64D'): """ provides filenames of some test datasets or other useful parametrisations Parameters @@ -228,8 +230,8 @@ def get_data(name='small_64D'): Examples ---------- >>> import numpy as np - >>> from dipy.data import get_data - >>> fimg,fbvals,fbvecs=get_data('small_101D') + >>> from dipy.data import get_fnames + >>> fimg,fbvals,fbvecs=get_fnames('small_101D') >>> bvals=np.loadtxt(fbvals) >>> bvecs=np.loadtxt(fbvecs).T >>> import nibabel as nib @@ -244,8 +246,8 @@ def get_data(name='small_64D'): """ if name == 'small_64D': - fbvals = pjoin(DATA_DIR, 'small_64D.bvals.npy') - fbvecs = pjoin(DATA_DIR, 'small_64D.gradients.npy') + fbvals = pjoin(DATA_DIR, 'small_64D.bval') + fbvecs = pjoin(DATA_DIR, 'small_64D.bvec') fimg = pjoin(DATA_DIR, 'small_64D.nii') return fimg, fbvals, fbvecs if name == '55dir_grad.bvec': @@ -290,6 +292,16 @@ def get_data(name='small_64D'): return pjoin(DATA_DIR, 't1_coronal_slice.npy') +def get_data(name='small_64D'): + """Deprecate function.""" + warnings.warn("The `dipy.data.get_data` function is deprecated as of" + + " version 0.15 of Dipy and will be removed in a future" + + " version. 
Please use `dipy.data.get_fnames` function" + + " instead", + DeprecationWarning) + return get_fnames(name) + + def _gradient_from_file(filename): """Reads a gradient file saved as a text file compatible with np.loadtxt and saved in the dipy data directory""" @@ -307,7 +319,7 @@ def gtab_getter(): def dsi_voxels(): - fimg, fbvals, fbvecs = get_data('small_101D') + fimg, fbvals, fbvecs = get_fnames('small_101D') bvals = np.loadtxt(fbvals) bvecs = np.loadtxt(fbvecs).T img = load(fimg) @@ -317,7 +329,7 @@ def dsi_voxels(): def dsi_deconv_voxels(): - gtab = gradient_table(np.loadtxt(get_data('dsi515btable'))) + gtab = gradient_table(np.loadtxt(get_fnames('dsi515btable'))) data = np.zeros((2, 2, 2, 515)) for ix in range(2): for iy in range(2): @@ -398,7 +410,7 @@ def simple_cmap(v): def two_cingulum_bundles(): - fname = get_data('cb_2') + fname = get_fnames('cb_2') res = np.load(fname) cb1 = relist_streamlines(res['points'], res['offsets']) cb2 = relist_streamlines(res['points2'], res['offsets2']) diff --git a/dipy/data/fetcher.py b/dipy/data/fetcher.py index 139ae41379..659da49f75 100644 --- a/dipy/data/fetcher.py +++ b/dipy/data/fetcher.py @@ -1,5 +1,5 @@ +# -*- coding: utf-8 -*- from __future__ import division, print_function, absolute_import - import os import sys import contextlib @@ -13,7 +13,8 @@ import tarfile import zipfile -from dipy.core.gradients import gradient_table +from dipy.core.gradients import (gradient_table, + gradient_table_from_gradient_strength_bvecs) from dipy.io.gradients import read_bvals_bvecs if sys.version_info[0] < 3: @@ -21,7 +22,6 @@ else: from urllib.request import urlopen - # Set a user-writeable file-system location to put files: if 'DIPY_HOME' in os.environ: dipy_home = os.environ['DIPY_HOME'] @@ -30,7 +30,8 @@ # The URL to the University of Washington Researchworks repository: UW_RW_URL = \ - "https://digital.lib.washington.edu/researchworks/bitstream/handle/" + "https://digital.lib.washington.edu/researchworks/bitstream/handle/" + class FetcherError(Exception): pass @@ -123,15 +124,11 @@ def check_md5(filename, stored_md5=None): def _get_file_data(fname, url): with contextlib.closing(urlopen(url)) as opener: - if sys.version_info[0] < 3: - try: - response_size = opener.headers['content-length'] - except KeyError: - response_size = None - else: - # python3.x - # returns none if header not found - response_size = opener.getheader("Content-Length") + try: + response_size = opener.headers['content-length'] + except KeyError: + response_size = None + with open(fname, 'wb') as data: if(response_size is None): copyfileobj(opener, data) @@ -243,6 +240,7 @@ def fetcher(): raise ValueError('File extension is not recognized') elif split_ext[-1] == '.zip': z = zipfile.ZipFile(pjoin(folder, f), 'r') + files[f] += (tuple(z.namelist()), ) z.extractall(folder) z.close() else: @@ -331,8 +329,8 @@ def fetcher(): 'a95eb1be44748c20214dc7aa654f9e6b', '7fa1d5e272533e832cc7453eeba23f44'], doc="Download a DSI dataset with 203 gradient directions", - msg="See DSI203_license.txt for LICENSE. For the complete datasets please visit : \ - http://dsi-studio.labsolver.org", + msg="See DSI203_license.txt for LICENSE. 
For the complete datasets" + + " please visit http://dsi-studio.labsolver.org", data_size="91MB") fetch_syn_data = _make_fetcher( @@ -371,22 +369,12 @@ def fetcher(): UW_RW_URL + "1773/38479/", ['datasets_multi-site_all_companies.zip'], ['datasets_multi-site_all_companies.zip'], - None, - doc="Download b=0 datasets from multiple MR systems (GE, Philips, Siemens) \ - and different magnetic fields (1.5T and 3T)", + ["e9810fa5bf21b99da786647994d7d5b7"], + doc="Download b=0 datasets from multiple MR systems (GE, Philips, " + + "Siemens) and different magnetic fields (1.5T and 3T)", data_size="9.2MB", unzip=True) -fetch_viz_icons = _make_fetcher("fetch_viz_icons", - pjoin(dipy_home, "icons"), - UW_RW_URL + "1773/38478/", - ['icomoon.tar.gz'], - ['icomoon.tar.gz'], - ['94a07cba06b4136b6687396426f1e380'], - data_size="12KB", - doc="Download icons for dipy.viz", - unzip=True) - fetch_bundles_2_subjects = _make_fetcher( "fetch_bundles_2_subjects", pjoin(dipy_home, 'exp_bundles_and_maps'), @@ -430,6 +418,137 @@ def fetcher(): " More details about the data are available in their paper: " + " https://www.nature.com/articles/sdata201672")) +fetch_bundle_atlas_hcp842 = _make_fetcher( + "fetch_bundle_atlas_hcp842", + pjoin(dipy_home, 'bundle_atlas_hcp842'), + 'https://ndownloader.figshare.com/files/', + ['13638644'], + ['Atlas_80_Bundles.zip'], + ['78331d527a10ec000d4f33bac472e099'], + doc="Download atlas tractogram from the hcp842 dataset with 80 bundles", + data_size="200MB", + unzip=True) + +fetch_target_tractogram_hcp = _make_fetcher( + "fetch_target_tractogram_hcp", + pjoin(dipy_home, 'target_tractogram_hcp'), + 'https://ndownloader.figshare.com/files/', + ['12871127'], + ['hcp_tractogram.zip'], + ['fa25ef19c9d3748929b6423397963b6a'], + doc="Download tractogram of one of the hcp dataset subjects", + data_size="541MB", + unzip=True) + + +fetch_qtdMRI_test_retest_2subjects = _make_fetcher( + "fetch_qtdMRI_test_retest_2subjects", + pjoin(dipy_home, 'qtdMRI_test_retest_2subjects'), + 'https://zenodo.org/record/996889/files/', + ['subject1_dwis_test.nii.gz', 'subject2_dwis_test.nii.gz', + 'subject1_dwis_retest.nii.gz', 'subject2_dwis_retest.nii.gz', + 'subject1_ccmask_test.nii.gz', 'subject2_ccmask_test.nii.gz', + 'subject1_ccmask_retest.nii.gz', 'subject2_ccmask_retest.nii.gz', + 'subject1_scheme_test.txt', 'subject2_scheme_test.txt', + 'subject1_scheme_retest.txt', 'subject2_scheme_retest.txt'], + ['subject1_dwis_test.nii.gz', 'subject2_dwis_test.nii.gz', + 'subject1_dwis_retest.nii.gz', 'subject2_dwis_retest.nii.gz', + 'subject1_ccmask_test.nii.gz', 'subject2_ccmask_test.nii.gz', + 'subject1_ccmask_retest.nii.gz', 'subject2_ccmask_retest.nii.gz', + 'subject1_scheme_test.txt', 'subject2_scheme_test.txt', + 'subject1_scheme_retest.txt', 'subject2_scheme_retest.txt'], + ['ebd7441f32c40e25c28b9e069bd81981', + 'dd6a64dd68c8b321c75b9d5fb42c275a', + '830a7a028a66d1b9812f93309a3f9eae', + 'd7f1951e726c35842f7ea0a15d990814', + 'ddb8dfae908165d5e82c846bcc317cab', + '5630c06c267a0f9f388b07b3e563403c', + '02e9f92b31e8980f658da99e532e14b5', + '6e7ce416e7cfda21cecce3731f81712b', + '957cb969f97d89e06edd7a04ffd61db0', + '5540c0c9bd635c29fc88dd599cbbf5e6', + '5540c0c9bd635c29fc88dd599cbbf5e6', + '5540c0c9bd635c29fc88dd599cbbf5e6'], + doc="Downloads test-retest qt-dMRI acquisitions of two C57Bl6 mice.", + data_size="298.2MB") + + +def read_qtdMRI_test_retest_2subjects(): + """ Load test-retest qt-dMRI acquisitions of two C57Bl6 mice. 
These + datasets were used to study test-retest reproducibility of time-dependent + q-space indices (q$\tau$-indices) in the corpus callosum of two mice [1]. + The data and its details are publicly available and can be cited as + [2]. + + The test-retest diffusion MRI spin echo sequences were acquired from two + C57Bl6 wild-type mice on an 11.7 Tesla Bruker scanner. The test and retest + acquisitions were taken 48 hours apart. The (processed) data + consists of 80x160x5 voxels of size 110x110x500μm. Each data set consists + of 515 Diffusion-Weighted Images (DWIs) spread over 35 acquisition shells. + The shells are spread over 7 gradient strength shells with a maximum + gradient strength of 491 mT/m, 5 pulse separation shells between + [10.8 - 20.0]ms, and a pulse length of 5ms. We manually created a brain + mask and corrected the data for eddy currents and motion artifacts using + FSL's eddy. A region of interest was then drawn in the middle slice in the + corpus callosum, where the tissue is reasonably coherent. + + Returns + ------- + data : list of length 4 + contains the dwi datasets ordered as + (subject1_test, subject1_retest, subject2_test, subject2_retest) + cc_masks : list of length 4 + contains the corpus callosum masks ordered in the same order as data. + gtabs : list of length 4 + contains the qt-dMRI gradient tables of the data sets. + + References + ---------- + .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + .. [2] Wassermann, Demian, et al., "Test-Retest qt-dMRI datasets for + `Non-Parametric GraphNet-Regularized Representation of dMRI in Space + and Time'". doi:10.5281/zenodo.996889, 2017. + """ + data = [] + data_names = [ + 'subject1_dwis_test.nii.gz', 'subject1_dwis_retest.nii.gz', + 'subject2_dwis_test.nii.gz', 'subject2_dwis_retest.nii.gz' + ] + for data_name in data_names: + data_loc = pjoin(dipy_home, 'qtdMRI_test_retest_2subjects', data_name) + data.append(nib.load(data_loc).get_data()) + + cc_masks = [] + mask_names = [ + 'subject1_ccmask_test.nii.gz', 'subject1_ccmask_retest.nii.gz', + 'subject2_ccmask_test.nii.gz', 'subject2_ccmask_retest.nii.gz' + ] + for mask_name in mask_names: + mask_loc = pjoin(dipy_home, 'qtdMRI_test_retest_2subjects', mask_name) + cc_masks.append(nib.load(mask_loc).get_data()) + + gtabs = [] + gtab_txt_names = [ + 'subject1_scheme_test.txt', 'subject1_scheme_retest.txt', + 'subject2_scheme_test.txt', 'subject2_scheme_retest.txt' + ] + for gtab_txt_name in gtab_txt_names: + txt_loc = pjoin(dipy_home, 'qtdMRI_test_retest_2subjects', + gtab_txt_name) + qtdmri_scheme = np.loadtxt(txt_loc, skiprows=1) + bvecs = qtdmri_scheme[:, 1:4] + G = qtdmri_scheme[:, 4] / 1e3 # because dipy takes T/mm not T/m + small_delta = qtdmri_scheme[:, 5] + big_delta = qtdmri_scheme[:, 6] + gtab = gradient_table_from_gradient_strength_bvecs( + G, bvecs, big_delta, small_delta + ) + gtabs.append(gtab) + + return data, cc_masks, gtabs + def read_scil_b0(): """ Load GE 3T b0 image from the scil b0 dataset.
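The qt-dMRI reader above rebuilds each gradient table from the acquisition scheme file rather than from plain bvals/bvecs, because the model also needs the pulse timings. A minimal sketch of that conversion, where a hypothetical two-row scheme array stands in for a loaded scheme file (same column layout as the qt-dMRI scheme text files: bvecs in columns 1-3, gradient strength in T/m in column 4, pulse length and pulse separation in seconds in columns 5-6):

    import numpy as np
    from dipy.core.gradients import (
        gradient_table_from_gradient_strength_bvecs)

    # Hypothetical scheme rows: [index, bx, by, bz, G, delta, Delta],
    # values chosen to match the acquisition described above
    # (G up to 491 mT/m, delta = 5 ms, Delta >= 10.8 ms).
    scheme = np.array([[0., 0., 0., 0., 0.000, 0.005, 0.0108],
                       [1., 1., 0., 0., 0.491, 0.005, 0.0108]])
    bvecs = scheme[:, 1:4]
    G = scheme[:, 4] / 1e3          # DIPY expects T/mm, the scheme stores T/m
    small_delta = scheme[:, 5]      # pulse length
    big_delta = scheme[:, 6]        # pulse separation
    gtab = gradient_table_from_gradient_strength_bvecs(
        G, bvecs, big_delta, small_delta)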
@@ -722,9 +841,9 @@ def read_mni_template(version="a", contrast="T2"): Examples -------- Get only the T1 file for version c: - >>> T1_nifti = read_mni_template("c", contrast = "T1") # doctest: +SKIP + >>> T1 = read_mni_template("c", contrast = "T1") # doctest: +SKIP Get both files in this order for version a: - >>> T1_nifti, T2_nifti = read_mni_template(contrast = ["T1", "T2"]) # doctest: +SKIP + >>> T1, T2 = read_mni_template(contrast = ["T1", "T2"]) # doctest: +SKIP """ files, folder = fetch_mni_template() file_dict_a = {"T1": pjoin(folder, 'mni_icbm152_t1_tal_nlin_asym_09a.nii'), @@ -881,35 +1000,15 @@ def read_cenir_multib(bvals=None): Notes ----- Details of the acquisition and processing, and additional meta-data are - available through `UW researchworks `_ + available through UW researchworks: + + https://digital.lib.washington.edu/researchworks/handle/1773/33311 """ fetch_cenir_multib.__doc__ += CENIR_notes read_cenir_multib.__doc__ += CENIR_notes -def read_viz_icons(style='icomoon', fname='infinity.png'): - """ Read specific icon from specific style - - Parameters - ---------- - style : str - Current icon style. Default is icomoon. - fname : str - Filename of icon. This should be found in folder HOME/.dipy/style/. - Default is infinity.png. - - Returns - -------- - path : str - Complete path of icon. - - """ - - folder = pjoin(dipy_home, 'icons', style) - return pjoin(folder, fname) - - def read_bundles_2_subjects(subj_id='subj_1', metrics=['fa'], bundles=['af.left', 'cst.right', 'cc_1']): r""" Read images and streamlines from 2 subjects of the SNAIL dataset @@ -993,7 +1092,7 @@ def read_ivim(): fbval = pjoin(folder, 'ivim.bval') fbvec = pjoin(folder, 'ivim.bvec') bvals, bvecs = read_bvals_bvecs(fbval, fbvec) - gtab = gradient_table(bvals, bvecs) + gtab = gradient_table(bvals, bvecs, b0_threshold=0) img = nib.load(fraw) return img, gtab @@ -1031,4 +1130,62 @@ def read_cfin_t1(): """ files, folder = fetch_cfin_multib() img = nib.load(pjoin(folder, 'T1.nii')) - return img, gtab + return img # , gtab + + +def get_bundle_atlas_hcp842(): + """ + Returns + ------- + file1 : string + file2 : string + """ + file1 = pjoin(dipy_home, + 'bundle_atlas_hcp842', + 'Atlas_80_Bundles', + 'whole_brain', + 'whole_brain_MNI.trk') + + file2 = pjoin(dipy_home, + 'bundle_atlas_hcp842', + 'Atlas_80_Bundles', + 'bundles', + '*.trk') + + return file1, file2 + + +def get_two_hcp842_bundles(): + """ + Returns + ------- + file1 : string + file2 : string + """ + file1 = pjoin(dipy_home, + 'bundle_atlas_hcp842', + 'Atlas_80_Bundles', + 'bundles', + 'AF_L.trk') + + file2 = pjoin(dipy_home, + 'bundle_atlas_hcp842', + 'Atlas_80_Bundles', + 'bundles', + 'CST_L.trk') + + return file1, file2 + + +def get_target_tractogram_hcp(): + """ + Returns + ------- + file1 : string + """ + file1 = pjoin(dipy_home, + 'target_tractogram_hcp', + 'hcp_tractogram', + 'streamlines.trk') + + return file1 diff --git a/dipy/data/files/evenly_distributed_sphere_642.npz b/dipy/data/files/evenly_distributed_sphere_642.npz index b686b01485..4dc4a14970 100644 Binary files a/dipy/data/files/evenly_distributed_sphere_642.npz and b/dipy/data/files/evenly_distributed_sphere_642.npz differ diff --git a/dipy/data/files/small_64D.bval b/dipy/data/files/small_64D.bval new file mode 100644 index 0000000000..b28675d97d --- /dev/null +++ b/dipy/data/files/small_64D.bval @@ -0,0 +1 @@ +0.000000000000000000e+00 9.928797843126392308e+02 1.001021565029311773e+03 9.909633063747885444e+02 1.000364252774984038e+03 9.942512723242982702e+02 
9.939778505212066193e+02 9.891889881986279534e+02 9.969196834317706362e+02 9.911624731427109509e+02 9.974664035236321524e+02 9.954073405684313229e+02 9.919624279819356616e+02 9.931252799365950068e+02 9.940771219847928251e+02 9.879731333782065121e+02 9.976620483924543805e+02 9.900065516055614125e+02 9.899225600178399418e+02 9.983197906107794779e+02 9.948038487748110583e+02 9.968387096356296979e+02 9.916486546586751274e+02 9.944719806636786643e+02 9.940168126899344543e+02 9.879607569811387293e+02 1.002991244056878372e+03 9.994929363816755767e+02 9.876152811875540465e+02 9.980621612786436572e+02 9.947304108200418113e+02 9.917164860613376050e+02 9.877203533987607216e+02 9.869461881512532955e+02 9.895954288021255252e+02 9.959811626110692941e+02 9.930680889749546623e+02 1.000572361150565143e+03 9.967157942552372560e+02 9.904673272499851464e+02 9.896968074768431052e+02 9.961932041683226089e+02 9.981179783129384759e+02 9.906313722879589250e+02 9.938537721133617424e+02 9.968091554088852035e+02 9.924976723400540095e+02 1.001038001192093816e+03 9.933728499281497761e+02 9.953140425466273200e+02 9.926488343299331518e+02 9.984048920718623776e+02 9.972655377164984429e+02 9.925551656113113950e+02 9.890137636780316370e+02 9.885720623682524320e+02 1.000291601730783896e+03 9.918921471607960711e+02 1.001110504195383328e+03 9.905137412739931051e+02 1.001481457968169707e+03 9.884648285274217869e+02 9.903733083706629259e+02 9.944549794894621755e+02 1.001693658211986531e+03 \ No newline at end of file diff --git a/dipy/data/files/small_64D.bvec b/dipy/data/files/small_64D.bvec new file mode 100644 index 0000000000..5eaf37a38b --- /dev/null +++ b/dipy/data/files/small_64D.bvec @@ -0,0 +1,65 @@ +nan nan nan +4.163478118279527636e-03 9.999827048187632794e-01 -4.153975602799726656e-03 +9.710771441530797743e-01 -9.949625405048536080e-04 2.387638794982227808e-01 +4.484975525965129717e-01 2.497431846103612477e-02 8.934350724772029961e-01 +8.065213506191166726e-01 5.887964846640970640e-01 -5.331051155933171776e-02 +7.115302898994344538e-01 -2.350396452768258593e-01 -6.621789876640383765e-01 +3.454708480709161589e-01 -8.926283878308921560e-01 -2.895936020902123986e-01 +2.289071300630692724e-02 7.977559449542579451e-01 -6.025458219490047451e-01 +8.361017508803255671e-01 -2.326122361001924099e-01 4.968152672687530247e-01 +6.117524509784533909e-02 -9.368291423024520670e-01 3.443962071801469071e-01 +7.798364496140078872e-01 5.044848155988271854e-01 3.706078556690835524e-01 +7.280092759402950753e-01 3.449800194759559679e-01 5.924451707181486171e-01 +4.638900575839798868e-01 4.567629595072385529e-01 7.590610076251582683e-01 +5.725407393389004840e-01 -4.866172062933480924e-01 -6.598490708764559454e-01 +5.581262946147673709e-01 6.171547684827388691e-01 5.546305355807656934e-01 +9.942856509510769603e-02 5.781386146941556170e-01 -8.098578286604697363e-01 +5.603713195256103674e-01 -8.248798185830820140e-01 -7.454709348772665944e-02 +7.413641793937422730e-02 -8.949442633413948744e-01 -4.399756323337084551e-01 +3.373237308808803570e-01 2.898505589550048889e-01 8.956558234378173555e-01 +8.768163366278239890e-01 1.152814573980816548e-01 4.668011326065273914e-01 +5.035104108851373717e-01 7.995291679813333330e-01 -3.274604948346549471e-01 +7.757347240068448446e-01 -5.124427971198213250e-01 3.682906700556473623e-01 +2.976885732092064418e-01 7.892763288277961919e-01 -5.370515712040915268e-01 +2.819504777234194681e-01 9.486619923932395615e-01 -1.433330118989497026e-01 +6.232379823719625955e-01 -2.320164922113039652e-01 
7.468217757074890883e-01 +6.050338085237734476e-02 1.947738415084017405e-02 -9.979779418464482799e-01 +9.759409902906680534e-01 2.151818865167165751e-01 3.515592674894449376e-02 +6.353466709480807273e-01 7.715365788614170217e-01 -3.264835668164086518e-02 +1.244764298570132099e-01 1.599343550651121104e-01 9.792479872228273541e-01 +8.743126749329229730e-01 1.464436261515695559e-01 -4.627435691732692535e-01 +3.617517941136363935e-01 -8.877710659259927528e-01 2.846017813721339884e-01 +4.234504161725230476e-01 5.621192138273652938e-01 -7.104306683198732264e-01 +8.736588532595805645e-02 -3.812120228599760186e-01 -9.203502570805403016e-01 +3.726468191887949422e-02 3.056051540215105056e-01 -9.514288377577031497e-01 +3.580783343848107370e-01 -3.317178379678600852e-01 -8.727789997577440895e-01 +2.722414522447213492e-01 -9.620691694719444298e-01 1.753581567102956151e-02 +1.564914425310713619e-01 9.598269716399268070e-01 2.329004356523862451e-01 +8.806037057829864123e-01 4.508578266987513516e-01 1.458229524654813536e-01 +5.912869227710892961e-01 7.719437645497624345e-01 2.334150794885302138e-01 +2.603772661861981641e-01 -7.094721239216934539e-01 6.548686773937528738e-01 +1.566041352813896115e-01 -6.939788963499021746e-01 -7.027577365164613399e-01 +6.396611894362609352e-01 -6.808897123262511730e-01 -3.566829998433663773e-01 +8.703417501146003543e-01 -1.411334918339134659e-01 -4.717908175136747428e-01 +2.469526034715729401e-01 7.409388219811428034e-01 6.245190739439496763e-01 +6.648409967484210092e-01 1.021779043512370533e-01 7.399635970133635610e-01 +7.157210270916442019e-01 5.831057599687722304e-01 -3.843580154883236011e-01 +5.583121379159108333e-01 -8.739995547229871542e-02 -8.250144268067105546e-01 +8.333840007224894153e-01 -5.500905125207584678e-01 -5.358670893446502298e-02 +3.766549463259035724e-01 8.381553084046976521e-01 3.944955391398701217e-01 +7.293684374934888970e-01 3.615995413482260834e-01 -5.807473237863943760e-01 +6.044737588089437175e-01 1.831503974899794385e-01 -7.752853712089822213e-01 +6.774144314458742100e-01 -7.189693188462814577e-01 1.555403697648201633e-01 +8.081532700735619690e-01 -4.320944987827871620e-01 -4.002282301275862930e-01 +5.470294864281085578e-01 -5.019604419915061344e-01 6.699212309323325787e-01 +2.922710122332284333e-01 -1.700591514095178003e-01 9.410938000167882178e-01 +2.248563410182856936e-01 -4.633438858718583742e-01 8.571768016745640040e-01 +8.948795193930951797e-01 3.834677269608361416e-01 -2.283487423882004097e-01 +4.041415251468313818e-01 -7.132497231793889503e-01 -5.726643519868492849e-01 +9.534686225700469420e-01 -2.589146765750766077e-01 -1.544693368549269752e-01 +3.215216961394217199e-01 -8.201509429384397994e-05 -9.469022083537210754e-01 +9.806514599835323143e-01 3.624096994593548754e-02 -1.923780292277270654e-01 +1.121394846554832070e-01 5.710716476281935128e-01 8.132047154661753430e-01 +3.760475566018071647e-01 2.815636138887444573e-01 -8.827854589353636428e-01 +5.131690040773655426e-01 -7.208576195215492532e-01 4.658560567728727841e-01 +9.530327551768297267e-01 -2.653357783804909942e-01 1.460325041601345242e-01 diff --git a/dipy/data/files/test_button_and_slider_widgets.log.gz b/dipy/data/files/test_button_and_slider_widgets.log.gz deleted file mode 100644 index bb1911a7eb..0000000000 Binary files a/dipy/data/files/test_button_and_slider_widgets.log.gz and /dev/null differ diff --git a/dipy/data/files/test_custom_interactor_style_events.log.gz b/dipy/data/files/test_custom_interactor_style_events.log.gz deleted file mode 100644 index 
5b2072b4cc..0000000000 Binary files a/dipy/data/files/test_custom_interactor_style_events.log.gz and /dev/null differ diff --git a/dipy/data/files/test_ui_button_panel.log.gz b/dipy/data/files/test_ui_button_panel.log.gz deleted file mode 100644 index 5e0c549d95..0000000000 Binary files a/dipy/data/files/test_ui_button_panel.log.gz and /dev/null differ diff --git a/dipy/data/files/test_ui_button_panel.pkl b/dipy/data/files/test_ui_button_panel.pkl deleted file mode 100644 index 8c87d84e58..0000000000 Binary files a/dipy/data/files/test_ui_button_panel.pkl and /dev/null differ diff --git a/dipy/data/files/test_ui_disk_slider_2d.log.gz b/dipy/data/files/test_ui_disk_slider_2d.log.gz deleted file mode 100644 index 963858ccce..0000000000 Binary files a/dipy/data/files/test_ui_disk_slider_2d.log.gz and /dev/null differ diff --git a/dipy/data/files/test_ui_disk_slider_2d.pkl b/dipy/data/files/test_ui_disk_slider_2d.pkl deleted file mode 100644 index 33cb3899a9..0000000000 Binary files a/dipy/data/files/test_ui_disk_slider_2d.pkl and /dev/null differ diff --git a/dipy/data/files/test_ui_file_select_menu_2d.log.gz b/dipy/data/files/test_ui_file_select_menu_2d.log.gz deleted file mode 100644 index 95cb791bf3..0000000000 Binary files a/dipy/data/files/test_ui_file_select_menu_2d.log.gz and /dev/null differ diff --git a/dipy/data/files/test_ui_file_select_menu_2d.pkl b/dipy/data/files/test_ui_file_select_menu_2d.pkl deleted file mode 100644 index 17f695eb3b..0000000000 Binary files a/dipy/data/files/test_ui_file_select_menu_2d.pkl and /dev/null differ diff --git a/dipy/data/files/test_ui_line_slider_2d.log.gz b/dipy/data/files/test_ui_line_slider_2d.log.gz deleted file mode 100644 index 6036e351cc..0000000000 Binary files a/dipy/data/files/test_ui_line_slider_2d.log.gz and /dev/null differ diff --git a/dipy/data/files/test_ui_line_slider_2d.pkl b/dipy/data/files/test_ui_line_slider_2d.pkl deleted file mode 100644 index 9f9ae3b2aa..0000000000 Binary files a/dipy/data/files/test_ui_line_slider_2d.pkl and /dev/null differ diff --git a/dipy/data/files/test_ui_textbox.log.gz b/dipy/data/files/test_ui_textbox.log.gz deleted file mode 100644 index cc526adf55..0000000000 Binary files a/dipy/data/files/test_ui_textbox.log.gz and /dev/null differ diff --git a/dipy/data/files/test_ui_textbox.pkl b/dipy/data/files/test_ui_textbox.pkl deleted file mode 100644 index be74588863..0000000000 Binary files a/dipy/data/files/test_ui_textbox.pkl and /dev/null differ diff --git a/dipy/data/tests/test_data.py b/dipy/data/tests/test_data.py index e69de29bb2..cc410d13ac 100644 --- a/dipy/data/tests/test_data.py +++ b/dipy/data/tests/test_data.py @@ -0,0 +1,8 @@ +import numpy.testing as npt +from dipy.data import SPHERE_FILES +import numpy as np + +def test_sphere_dtypes(): + for sphere_name, sphere_path in SPHERE_FILES.items(): + sphere_data = np.load(sphere_path) + npt.assert_equal(sphere_data['vertices'].dtype, np.float64) diff --git a/dipy/data/tests/test_fetcher.py b/dipy/data/tests/test_fetcher.py index f727917d0f..eabe0ac9b9 100644 --- a/dipy/data/tests/test_fetcher.py +++ b/dipy/data/tests/test_fetcher.py @@ -10,8 +10,10 @@ if sys.version_info[0] < 3: from SimpleHTTPServer import SimpleHTTPRequestHandler # Python 2 from SocketServer import TCPServer as HTTPServer + from urllib import pathname2url else: from http.server import HTTPServer, SimpleHTTPRequestHandler # Python 3 + from urllib.request import pathname2url def test_check_md5(): @@ -31,13 +33,14 @@ def test_make_fetcher(): stored_md5 = 
fetcher._get_file_md5(symmetric362) # create local HTTP Server - testfile_url = op.split(symmetric362)[0] + os.sep + testfile_folder = op.split(symmetric362)[0] + os.sep + testfile_url = 'file:' + pathname2url(testfile_folder) test_server_url = "http://127.0.0.1:8000/" print(testfile_url) print(symmetric362) current_dir = os.getcwd() # change pwd to directory containing testfile. - os.chdir(testfile_url) + os.chdir(testfile_folder) server = HTTPServer(('localhost', 8000), SimpleHTTPRequestHandler) server_thread = Thread(target=server.serve_forever) server_thread.deamon = True @@ -45,12 +48,19 @@ def test_make_fetcher(): # test make_fetcher sphere_fetcher = fetcher._make_fetcher("sphere_fetcher", - tmpdir, test_server_url, - [op.split(symmetric362)[-1]], + tmpdir, testfile_url, + [op.sep + + op.split(symmetric362)[-1]], ["sphere_name"], md5_list=[stored_md5]) - sphere_fetcher() + try: + sphere_fetcher() + except Exception as e: + print(e) + # stop local HTTP Server + server.shutdown() + assert op.isfile(op.join(tmpdir, "sphere_name")) npt.assert_equal(fetcher._get_file_md5(op.join(tmpdir, "sphere_name")), stored_md5) @@ -84,13 +94,23 @@ def test_fetch_data(): server_thread.start() files = {"testfile.txt": (test_server_url, md5)} - fetcher.fetch_data(files, tmpdir) + try: + fetcher.fetch_data(files, tmpdir) + except Exception as e: + print(e) + # stop local HTTP Server + server.shutdown() npt.assert_(op.exists(newfile)) # Test that the file is replaced when the md5 doesn't match with open(newfile, 'a') as f: f.write("some junk") - fetcher.fetch_data(files, tmpdir) + try: + fetcher.fetch_data(files, tmpdir) + except Exception as e: + print(e) + # stop local HTTP Server + server.shutdown() npt.assert_(op.exists(newfile)) npt.assert_equal(fetcher._get_file_md5(newfile), md5) diff --git a/dipy/denoise/tests/test_ascm.py b/dipy/denoise/tests/test_ascm.py index 33c8f165c6..4b06555966 100644 --- a/dipy/denoise/tests/test_ascm.py +++ b/dipy/denoise/tests/test_ascm.py @@ -101,8 +101,8 @@ def test_sharpness(): def test_ascm_accuracy(): - test_ascm_data_ref = nib.load(dpd.get_data("ascm_test")).get_data() - test_data = nib.load(dpd.get_data("aniso_vox")).get_data() + test_ascm_data_ref = nib.load(dpd.get_fnames("ascm_test")).get_data() + test_data = nib.load(dpd.get_fnames("aniso_vox")).get_data() # the test data was constructed in this manner mask = test_data > 50 diff --git a/dipy/denoise/tests/test_denoise.py b/dipy/denoise/tests/test_denoise.py index e8fbfde7b4..669183dad5 100644 --- a/dipy/denoise/tests/test_denoise.py +++ b/dipy/denoise/tests/test_denoise.py @@ -10,7 +10,7 @@ def test_denoise(): """ """ - fdata, fbval, fbvec = dpd.get_data() + fdata, fbval, fbvec = dpd.get_fnames() # Test on 4D image: data = nib.load(fdata).get_data() sigma1 = estimate_sigma(data) diff --git a/dipy/denoise/tests/test_noise_estimate.py b/dipy/denoise/tests/test_noise_estimate.py index 73aaf10079..518d6a987a 100644 --- a/dipy/denoise/tests/test_noise_estimate.py +++ b/dipy/denoise/tests/test_noise_estimate.py @@ -30,7 +30,7 @@ def test_inv_nchi(): def test_piesno(): # Values taken from hispeed.OptimalPIESNO with the test data # in the package computed in matlab - test_piesno_data = nib.load(dpd.get_data("test_piesno")).get_data() + test_piesno_data = nib.load(dpd.get_fnames("test_piesno")).get_data() sigma = piesno(test_piesno_data, N=8, alpha=0.01, l=1, eps=1e-10, return_mask=False) assert_almost_equal(sigma, 0.010749458025559) diff --git a/dipy/direction/closest_peak_direction_getter.pyx 
b/dipy/direction/closest_peak_direction_getter.pyx index b303486169..535e752c6d 100644 --- a/dipy/direction/closest_peak_direction_getter.pyx +++ b/dipy/direction/closest_peak_direction_getter.pyx @@ -98,11 +98,14 @@ cdef class BaseDirectionGetter(DirectionGetter): cdef: size_t _len, i double[:] pmf + double absolute_pmf_threshold pmf = self.pmf_gen.get_pmf_c(point) _len = pmf.shape[0] + + absolute_pmf_threshold = self.pmf_threshold*np.max(pmf) for i in range(_len): - if pmf[i] < self.pmf_threshold: + if pmf[i] < absolute_pmf_threshold: pmf[i] = 0.0 return pmf diff --git a/dipy/direction/peaks.py b/dipy/direction/peaks.py index 5725309d65..ffe8ef537e 100644 --- a/dipy/direction/peaks.py +++ b/dipy/direction/peaks.py @@ -423,10 +423,11 @@ def peaks_from_model(model, data, sphere, relative_peak_threshold, sh_order : int, optional Maximum SH order in the SH fit. For `sh_order`, there will be ``(sh_order + 1) * (sh_order + 2) / 2`` SH coefficients (default 8). - sh_basis_type : {None, 'mrtrix', 'fibernav'} - ``None`` for the default dipy basis which is the fibernav basis, - ``mrtrix`` for the MRtrix basis, and - ``fibernav`` for the FiberNavigator basis + sh_basis_type : {None, 'tournier07', 'descoteaux07'} + ``None`` for the default DIPY basis, + ``tournier07`` for the Tournier 2007 [2]_ basis, and + ``descoteaux07`` for the Descoteaux 2007 [1]_ basis + (``None`` defaults to ``descoteaux07``). sh_smooth : float, optional Lambda-regularization in the SH fit (default 0.0). npeaks : int @@ -450,6 +451,17 @@ def peaks_from_model(model, data, sphere, relative_peak_threshold, pam : PeaksAndMetrics An object with ``gfa``, ``peak_directions``, ``peak_values``, ``peak_indices``, ``odf``, ``shm_coeffs`` as attributes + + References + ---------- + .. [1] Descoteaux, M., Angelino, E., Fitzgibbons, S. and Deriche, R. + Regularized, Fast, and Robust Analytical Q-ball Imaging. + Magn. Reson. Med. 2007;58:497-510. + .. [2] Tournier J.D., Calamante F. and Connelly A. Robust determination + of the fibre orientation distribution in diffusion MRI: + Non-negativity constrained super-resolved spherical deconvolution. + NeuroImage. 2007;35(4):1459-1472. 
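+
+    Examples
+    --------
+    A minimal sketch, illustrative only (``small_64D`` and ``CsaOdfModel``
+    here stand in for any data/model combination):
+
+    >>> import nibabel as nib
+    >>> from dipy.core.gradients import gradient_table
+    >>> from dipy.data import get_fnames, get_sphere
+    >>> from dipy.io.gradients import read_bvals_bvecs
+    >>> from dipy.reconst.shm import CsaOdfModel
+    >>> fimg, fbvals, fbvecs = get_fnames('small_64D')  # doctest: +SKIP
+    >>> bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)  # doctest: +SKIP
+    >>> model = CsaOdfModel(gradient_table(bvals, bvecs), sh_order=6)  # doctest: +SKIP
+    >>> data = nib.load(fimg).get_data()  # doctest: +SKIP
+    >>> pam = peaks_from_model(model, data, get_sphere('repulsion724'),
+    ...                        .5, 25,
+    ...                        sh_basis_type='descoteaux07')  # doctest: +SKIP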
+ """ if return_sh and (B is None or invB is None): B, invB = sh_to_sf_matrix( diff --git a/dipy/direction/probabilistic_direction_getter.pyx b/dipy/direction/probabilistic_direction_getter.pyx index feac505e7c..7568bfa61a 100644 --- a/dipy/direction/probabilistic_direction_getter.pyx +++ b/dipy/direction/probabilistic_direction_getter.pyx @@ -180,7 +180,7 @@ cdef class DeterministicMaximumDirectionGetter(ProbabilisticDirectionGetter): max_idx = i max_value = pmf[i] - if pmf[max_idx] == 0: + if max_value <= 0: return 1 newdir = self.vertices[max_idx] diff --git a/dipy/direction/tests/test_peaks.py b/dipy/direction/tests/test_peaks.py index eed6d00970..6b2a68ece8 100644 --- a/dipy/direction/tests/test_peaks.py +++ b/dipy/direction/tests/test_peaks.py @@ -16,10 +16,11 @@ from dipy.core.subdivide_octahedron import create_unit_hemisphere from dipy.core.sphere import unit_icosahedron from dipy.sims.voxel import multi_tensor, multi_tensor_odf -from dipy.data import get_data, get_sphere +from dipy.data import get_fnames, get_sphere from dipy.core.gradients import gradient_table, GradientTable from dipy.core.sphere_stats import angular_similarity from dipy.core.sphere import HemiSphere +from dipy.io.gradients import read_bvals_bvecs def test_peak_directions_nl(): @@ -151,10 +152,9 @@ def test_peak_directions(): def _create_mt_sim(mevals, angles, fractions, S0, SNR, half_sphere=False): - _, fbvals, fbvecs = get_data('small_64D') + _, fbvals, fbvecs = get_fnames('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) @@ -451,77 +451,79 @@ def test_degenerative_cases(): def test_peaksFromModel(): data = np.zeros((10, 2)) - # Test basic case - model = SimpleOdfModel(_gtab) - odf_argmax = _odf.argmax() - pam = peaks_from_model(model, data, _sphere, .5, 45, normalize_peaks=True) - - assert_array_equal(pam.gfa, gfa(_odf)) - assert_array_equal(pam.peak_values[:, 0], 1.) - assert_array_equal(pam.peak_values[:, 1:], 0.) - mn, mx = _odf.min(), _odf.max() - assert_array_equal(pam.qa[:, 0], (mx - mn) / mx) - assert_array_equal(pam.qa[:, 1:], 0.) - assert_array_equal(pam.peak_indices[:, 0], odf_argmax) - assert_array_equal(pam.peak_indices[:, 1:], -1) - - # Test that odf array matches and is right shape - pam = peaks_from_model(model, data, _sphere, .5, 45, return_odf=True) - expected_shape = (len(data), len(_odf)) - assert_equal(pam.odf.shape, expected_shape) - assert_((_odf == pam.odf).all()) - assert_array_equal(pam.peak_values[:, 0], _odf.max()) - - # Test mask - mask = (np.arange(10) % 2) == 1 - - pam = peaks_from_model(model, data, _sphere, .5, 45, mask=mask, - normalize_peaks=True) - assert_array_equal(pam.gfa[~mask], 0) - assert_array_equal(pam.qa[~mask], 0) - assert_array_equal(pam.peak_values[~mask], 0) - assert_array_equal(pam.peak_indices[~mask], -1) - - assert_array_equal(pam.gfa[mask], gfa(_odf)) - assert_array_equal(pam.peak_values[mask, 0], 1.) - assert_array_equal(pam.peak_values[mask, 1:], 0.) - mn, mx = _odf.min(), _odf.max() - assert_array_equal(pam.qa[mask, 0], (mx - mn) / mx) - assert_array_equal(pam.qa[mask, 1:], 0.) 
- assert_array_equal(pam.peak_indices[mask, 0], odf_argmax) - assert_array_equal(pam.peak_indices[mask, 1:], -1) - - # Test serialization and deserialization: - for normalize_peaks in [True, False]: - for return_odf in [True, False]: - for return_sh in [True, False]: - pam = peaks_from_model(model, data, _sphere, .5, 45, - normalize_peaks=normalize_peaks, - return_odf=return_odf, - return_sh=return_sh) - - b = BytesIO() - pickle.dump(pam, b) - b.seek(0) - new_pam = pickle.load(b) - b.close() - - for attr in ['peak_dirs', 'peak_values', 'peak_indices', - 'gfa', 'qa', 'shm_coeff', 'B', 'odf']: - assert_array_equal(getattr(pam, attr), - getattr(new_pam, attr)) - assert_array_equal(pam.sphere.vertices, - new_pam.sphere.vertices) + for sphere in [_sphere, get_sphere('symmetric642')]: + # Test basic case + model = SimpleOdfModel(_gtab) + _odf = (sphere.vertices * [1, 2, 3]).sum(-1) + odf_argmax = _odf.argmax() + pam = peaks_from_model(model, data, sphere, .5, 45, + normalize_peaks=True) + + assert_array_equal(pam.gfa, gfa(_odf)) + assert_array_equal(pam.peak_values[:, 0], 1.) + assert_array_equal(pam.peak_values[:, 1:], 0.) + mn, mx = _odf.min(), _odf.max() + assert_array_equal(pam.qa[:, 0], (mx - mn) / mx) + assert_array_equal(pam.qa[:, 1:], 0.) + assert_array_equal(pam.peak_indices[:, 0], odf_argmax) + assert_array_equal(pam.peak_indices[:, 1:], -1) + + # Test that odf array matches and is right shape + pam = peaks_from_model(model, data, sphere, .5, 45, return_odf=True) + expected_shape = (len(data), len(_odf)) + assert_equal(pam.odf.shape, expected_shape) + assert_((_odf == pam.odf).all()) + assert_array_equal(pam.peak_values[:, 0], _odf.max()) + + # Test mask + mask = (np.arange(10) % 2) == 1 + + pam = peaks_from_model(model, data, sphere, .5, 45, mask=mask, + normalize_peaks=True) + assert_array_equal(pam.gfa[~mask], 0) + assert_array_equal(pam.qa[~mask], 0) + assert_array_equal(pam.peak_values[~mask], 0) + assert_array_equal(pam.peak_indices[~mask], -1) + + assert_array_equal(pam.gfa[mask], gfa(_odf)) + assert_array_equal(pam.peak_values[mask, 0], 1.) + assert_array_equal(pam.peak_values[mask, 1:], 0.) + mn, mx = _odf.min(), _odf.max() + assert_array_equal(pam.qa[mask, 0], (mx - mn) / mx) + assert_array_equal(pam.qa[mask, 1:], 0.) 
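+        # qa is scaled by the ODF maximum: this toy ODF has a single
+        # peak, so only the first peak slot holds (mx - mn) / mx, the
+        # remaining slots stay 0., and masked-out voxels keep the
+        # defaults (gfa == 0, qa == 0, peak_indices == -1).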
+ assert_array_equal(pam.peak_indices[mask, 0], odf_argmax) + assert_array_equal(pam.peak_indices[mask, 1:], -1) + + # Test serialization and deserialization: + for normalize_peaks in [True, False]: + for return_odf in [True, False]: + for return_sh in [True, False]: + pam = peaks_from_model(model, data, sphere, .5, 45, + normalize_peaks=normalize_peaks, + return_odf=return_odf, + return_sh=return_sh) + + b = BytesIO() + pickle.dump(pam, b) + b.seek(0) + new_pam = pickle.load(b) + b.close() + + for attr in ['peak_dirs', 'peak_values', 'peak_indices', + 'gfa', 'qa', 'shm_coeff', 'B', 'odf']: + assert_array_equal(getattr(pam, attr), + getattr(new_pam, attr)) + assert_array_equal(pam.sphere.vertices, + new_pam.sphere.vertices) def test_peaksFromModelParallel(): SNR = 100 S0 = 100 - _, fbvals, fbvecs = get_data('small_64D') + _, fbvals, fbvecs = get_fnames('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], @@ -530,54 +532,56 @@ def test_peaksFromModelParallel(): data, _ = multi_tensor(gtab, mevals, S0, angles=[(0, 0), (60, 0)], fractions=[50, 50], snr=SNR) - # test equality with/without multiprocessing - model = SimpleOdfModel(gtab) - pam_multi = peaks_from_model(model, data, _sphere, .5, 45, - normalize_peaks=True, return_odf=True, - return_sh=True, parallel=True) + for sphere in [_sphere, get_sphere('symmetric724')]: + + # test equality with/without multiprocessing + model = SimpleOdfModel(gtab) + pam_multi = peaks_from_model(model, data, sphere, .5, 45, + normalize_peaks=True, return_odf=True, + return_sh=True, parallel=True) - pam_single = peaks_from_model(model, data, _sphere, .5, 45, - normalize_peaks=True, return_odf=True, - return_sh=True, parallel=False) + pam_single = peaks_from_model(model, data, sphere, .5, 45, + normalize_peaks=True, return_odf=True, + return_sh=True, parallel=False) - pam_multi_inv1 = peaks_from_model(model, data, _sphere, .5, 45, - normalize_peaks=True, return_odf=True, - return_sh=True, parallel=True, - nbr_processes=0) + pam_multi_inv1 = peaks_from_model(model, data, sphere, .5, 45, + normalize_peaks=True, return_odf=True, + return_sh=True, parallel=True, + nbr_processes=0) - pam_multi_inv2 = peaks_from_model(model, data, _sphere, .5, 45, - normalize_peaks=True, return_odf=True, - return_sh=True, parallel=True, - nbr_processes=-2) + pam_multi_inv2 = peaks_from_model(model, data, sphere, .5, 45, + normalize_peaks=True, return_odf=True, + return_sh=True, parallel=True, + nbr_processes=-2) - for pam in [pam_multi, pam_multi_inv1, pam_multi_inv2]: - assert_equal(pam.gfa.dtype, pam_single.gfa.dtype) - assert_equal(pam.gfa.shape, pam_single.gfa.shape) - assert_array_almost_equal(pam.gfa, pam_single.gfa) + for pam in [pam_multi, pam_multi_inv1, pam_multi_inv2]: + assert_equal(pam.gfa.dtype, pam_single.gfa.dtype) + assert_equal(pam.gfa.shape, pam_single.gfa.shape) + assert_array_almost_equal(pam.gfa, pam_single.gfa) - assert_equal(pam.qa.dtype, pam_single.qa.dtype) - assert_equal(pam.qa.shape, pam_single.qa.shape) - assert_array_almost_equal(pam.qa, pam_single.qa) + assert_equal(pam.qa.dtype, pam_single.qa.dtype) + assert_equal(pam.qa.shape, pam_single.qa.shape) + assert_array_almost_equal(pam.qa, pam_single.qa) - assert_equal(pam.peak_values.dtype, pam_single.peak_values.dtype) - assert_equal(pam.peak_values.shape, pam_single.peak_values.shape) - assert_array_almost_equal(pam.peak_values, pam_single.peak_values) + 
assert_equal(pam.peak_values.dtype, pam_single.peak_values.dtype) + assert_equal(pam.peak_values.shape, pam_single.peak_values.shape) + assert_array_almost_equal(pam.peak_values, pam_single.peak_values) - assert_equal(pam.peak_indices.dtype, pam_single.peak_indices.dtype) - assert_equal(pam.peak_indices.shape, pam_single.peak_indices.shape) - assert_array_equal(pam.peak_indices, pam_single.peak_indices) + assert_equal(pam.peak_indices.dtype, pam_single.peak_indices.dtype) + assert_equal(pam.peak_indices.shape, pam_single.peak_indices.shape) + assert_array_equal(pam.peak_indices, pam_single.peak_indices) - assert_equal(pam.peak_dirs.dtype, pam_single.peak_dirs.dtype) - assert_equal(pam.peak_dirs.shape, pam_single.peak_dirs.shape) - assert_array_almost_equal(pam.peak_dirs, pam_single.peak_dirs) + assert_equal(pam.peak_dirs.dtype, pam_single.peak_dirs.dtype) + assert_equal(pam.peak_dirs.shape, pam_single.peak_dirs.shape) + assert_array_almost_equal(pam.peak_dirs, pam_single.peak_dirs) - assert_equal(pam.shm_coeff.dtype, pam_single.shm_coeff.dtype) - assert_equal(pam.shm_coeff.shape, pam_single.shm_coeff.shape) - assert_array_almost_equal(pam.shm_coeff, pam_single.shm_coeff) + assert_equal(pam.shm_coeff.dtype, pam_single.shm_coeff.dtype) + assert_equal(pam.shm_coeff.shape, pam_single.shm_coeff.shape) + assert_array_almost_equal(pam.shm_coeff, pam_single.shm_coeff) - assert_equal(pam.odf.dtype, pam_single.odf.dtype) - assert_equal(pam.odf.shape, pam_single.odf.shape) - assert_array_almost_equal(pam.odf, pam_single.odf) + assert_equal(pam.odf.dtype, pam_single.odf.dtype) + assert_equal(pam.odf.shape, pam_single.odf.shape) + assert_array_almost_equal(pam.odf, pam_single.odf) def test_peaks_shm_coeff(): @@ -585,14 +589,13 @@ def test_peaks_shm_coeff(): SNR = 100 S0 = 100 - _, fbvals, fbvecs = get_data('small_64D') + _, fbvals, fbvecs = get_fnames('small_64D') from dipy.data import get_sphere sphere = get_sphere('repulsion724') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], diff --git a/dipy/direction/tests/test_prob_direction_getter.py b/dipy/direction/tests/test_prob_direction_getter.py index 7c2af8c4a7..84e4776008 100644 --- a/dipy/direction/tests/test_prob_direction_getter.py +++ b/dipy/direction/tests/test_prob_direction_getter.py @@ -3,7 +3,8 @@ from dipy.core.sphere import unit_octahedron from dipy.reconst.shm import SphHarmFit, SphHarmModel -from dipy.direction import ProbabilisticDirectionGetter +from dipy.direction import (DeterministicMaximumDirectionGetter, + ProbabilisticDirectionGetter) def test_ProbabilisticDirectionGetter(): @@ -62,3 +63,26 @@ def fit(self, data, mask=None): fit.shm_coeff, 90, unit_octahedron, pmf_threshold=0.1, basis_type="not a basis") + + +def test_DeterministicMaximumDirectionGetter(): + # Test the DeterministicMaximumDirectionGetter + + dir = unit_octahedron.vertices[-1].copy() + point = np.zeros(3) + N = unit_octahedron.theta.shape[0] + + # No valid direction + pmf = np.zeros((3, 3, 3, N)) + dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 90, + unit_octahedron) + state = dg.get_direction(point, dir) + npt.assert_equal(state, 1) + + # Test BF #1566 - bad condition in DeterministicMaximumDirectionGetter + pmf = np.zeros((3, 3, 3, N)) + pmf[0, 0, 0, 0] = 1 + dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 0, + unit_octahedron) + state = dg.get_direction(point, dir) + npt.assert_equal(state, 1) diff --git 
a/dipy/fixes/argparse.py b/dipy/fixes/argparse.py deleted file mode 100644 index 4821d3ba74..0000000000 --- a/dipy/fixes/argparse.py +++ /dev/null @@ -1,2283 +0,0 @@ -# emacs: -*- coding: utf-8; mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- -# vi: set ft=python sts=4 ts=4 sw=4 et: - -# Copyright 2006-2009 Steven J. Bethard . -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, -# this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE -# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER -# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -# POSSIBILITY OF SUCH DAMAGE. - -import copy as _copy -import os as _os -import re as _re -import sys as _sys -import textwrap as _textwrap - -from gettext import gettext as _ - -"""Command-line parsing library - -This module is an optparse-inspired command-line parsing library that: - - - handles both optional and positional arguments - - produces highly informative usage messages - - supports parsers that dispatch to sub-parsers - -The following is a simple usage example that sums integers from the -command-line and writes the result to a file:: - - parser = argparse.ArgumentParser( - description='sum the integers at the command line') - parser.add_argument( - 'integers', metavar='int', nargs='+', type=int, - help='an integer to be summed') - parser.add_argument( - '--log', default=sys.stdout, type=argparse.FileType('w'), - help='the file where the sum should be written') - args = parser.parse_args() - args.log.write('%s' % sum(args.integers)) - args.log.close() - -The module contains the following public classes: - - - ArgumentParser -- The main entry point for command-line parsing. As the - example above shows, the add_argument() method is used to populate - the parser with actions for optional and positional arguments. Then - the parse_args() method is invoked to convert the args at the - command-line into an object with attributes. - - - ArgumentError -- The exception raised by ArgumentParser objects when - there are errors with the parser's actions. Errors raised while - parsing the command-line are caught by ArgumentParser and emitted - as command-line messages. - - - FileType -- A factory for defining types of files to be created. As the - example above shows, instances of FileType are typically passed as - the type= argument of add_argument() calls. - - - Action -- The base class for parser actions. 
Typically actions are - selected by passing strings like 'store_true' or 'append_const' to - the action= argument of add_argument(). However, for greater - customization of ArgumentParser actions, subclasses of Action may - be defined and passed as the action= argument. - - - HelpFormatter, RawDescriptionHelpFormatter, RawTextHelpFormatter, - ArgumentDefaultsHelpFormatter -- Formatter classes which - may be passed as the formatter_class= argument to the - ArgumentParser constructor. HelpFormatter is the default, - RawDescriptionHelpFormatter and RawTextHelpFormatter tell the parser - not to change the formatting for help text, and - ArgumentDefaultsHelpFormatter adds information about argument defaults - to the help. - -All other classes in this module are considered implementation details. -(Also note that HelpFormatter and RawDescriptionHelpFormatter are only -considered public as object names -- the API of the formatter objects is -still considered an implementation detail.) -""" - -__version__ = '1.0.1' -__all__ = [ - 'ArgumentParser', - 'ArgumentError', - 'Namespace', - 'Action', - 'FileType', - 'HelpFormatter', - 'RawDescriptionHelpFormatter', - 'RawTextHelpFormatter' - 'ArgumentDefaultsHelpFormatter', -] - -try: - _set = set -except NameError: - from sets import Set as _set - -try: - _basestring = basestring -except NameError: - _basestring = str - -try: - _sorted = sorted -except NameError: - - def _sorted(iterable, reverse=False): - result = list(iterable) - result.sort() - if reverse: - result.reverse() - return result - -# silence Python 2.6 buggy warnings about Exception.message -if _sys.version_info[:2] == (2, 6): - import warnings - warnings.filterwarnings( - action='ignore', - message='BaseException.message has been deprecated as of Python 2.6', - category=DeprecationWarning, - module='argparse') - - -SUPPRESS = '==SUPPRESS==' - -OPTIONAL = '?' -ZERO_OR_MORE = '*' -ONE_OR_MORE = '+' -PARSER = '==PARSER==' - -# ============================= -# Utility functions and classes -# ============================= - - -class _AttributeHolder(object): - """Abstract base class that provides __repr__. - - The __repr__ method returns a string in the format:: - ClassName(attr=name, attr=name, ...) - The attributes are determined either by a class-level attribute, - '_kwarg_names', or by inspecting the instance __dict__. - """ - - def __repr__(self): - type_name = type(self).__name__ - arg_strings = [] - for arg in self._get_args(): - arg_strings.append(repr(arg)) - for name, value in self._get_kwargs(): - arg_strings.append('%s=%r' % (name, value)) - return '%s(%s)' % (type_name, ', '.join(arg_strings)) - - def _get_kwargs(self): - return _sorted(self.__dict__.items()) - - def _get_args(self): - return [] - - -def _ensure_value(namespace, name, value): - if getattr(namespace, name, None) is None: - setattr(namespace, name, value) - return getattr(namespace, name) - - -# =============== -# Formatting Help -# =============== - -class HelpFormatter(object): - """Formatter for generating usage messages and argument help strings. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. 
- """ - - def __init__(self, - prog, - indent_increment=2, - max_help_position=24, - width=None): - - # default setting for width - if width is None: - try: - width = int(_os.environ['COLUMNS']) - except (KeyError, ValueError): - width = 80 - width -= 2 - - self._prog = prog - self._indent_increment = indent_increment - self._max_help_position = max_help_position - self._width = width - - self._current_indent = 0 - self._level = 0 - self._action_max_length = 0 - - self._root_section = self._Section(self, None) - self._current_section = self._root_section - - self._whitespace_matcher = _re.compile(r'\s+') - self._long_break_matcher = _re.compile(r'\n\n\n+') - - # =============================== - # Section and indentation methods - # =============================== - def _indent(self): - self._current_indent += self._indent_increment - self._level += 1 - - def _dedent(self): - self._current_indent -= self._indent_increment - assert self._current_indent >= 0, 'Indent decreased below 0.' - self._level -= 1 - - class _Section(object): - - def __init__(self, formatter, parent, heading=None): - self.formatter = formatter - self.parent = parent - self.heading = heading - self.items = [] - - def format_help(self): - # format the indented section - if self.parent is not None: - self.formatter._indent() - join = self.formatter._join_parts - for func, args in self.items: - func(*args) - item_help = join([func(*args) for func, args in self.items]) - if self.parent is not None: - self.formatter._dedent() - - # return nothing if the section was empty - if not item_help: - return '' - - # add the heading if the section was non-empty - if self.heading is not SUPPRESS and self.heading is not None: - current_indent = self.formatter._current_indent - heading = '%*s%s:\n' % (current_indent, '', self.heading) - else: - heading = '' - - # join the section-initial newline, the heading and the help - return join(['\n', heading, item_help, '\n']) - - def _add_item(self, func, args): - self._current_section.items.append((func, args)) - - # ======================== - # Message building methods - # ======================== - def start_section(self, heading): - self._indent() - section = self._Section(self, self._current_section, heading) - self._add_item(section.format_help, []) - self._current_section = section - - def end_section(self): - self._current_section = self._current_section.parent - self._dedent() - - def add_text(self, text): - if text is not SUPPRESS and text is not None: - self._add_item(self._format_text, [text]) - - def add_usage(self, usage, actions, groups, prefix=None): - if usage is not SUPPRESS: - args = usage, actions, groups, prefix - self._add_item(self._format_usage, args) - - def add_argument(self, action): - if action.help is not SUPPRESS: - - # find all invocations - get_invocation = self._format_action_invocation - invocations = [get_invocation(action)] - for subaction in self._iter_indented_subactions(action): - invocations.append(get_invocation(subaction)) - - # update the maximum item length - invocation_length = max([len(s) for s in invocations]) - action_length = invocation_length + self._current_indent - self._action_max_length = max(self._action_max_length, - action_length) - - # add the item to the list - self._add_item(self._format_action, [action]) - - def add_arguments(self, actions): - for action in actions: - self.add_argument(action) - - # ======================= - # Help-formatting methods - # ======================= - def format_help(self): - help = 
self._root_section.format_help() - if help: - help = self._long_break_matcher.sub('\n\n', help) - help = help.strip('\n') + '\n' - return help - - def _join_parts(self, part_strings): - return ''.join([part - for part in part_strings - if part and part is not SUPPRESS]) - - def _format_usage(self, usage, actions, groups, prefix): - if prefix is None: - prefix = _('usage: ') - - # if usage is specified, use that - if usage is not None: - usage = usage % dict(prog=self._prog) - - # if no optionals or positionals are available, usage is just prog - elif usage is None and not actions: - usage = '%(prog)s' % dict(prog=self._prog) - - # if optionals and positionals are available, calculate usage - elif usage is None: - prog = '%(prog)s' % dict(prog=self._prog) - - # split optionals from positionals - optionals = [] - positionals = [] - for action in actions: - if action.option_strings: - optionals.append(action) - else: - positionals.append(action) - - # build full usage string - format = self._format_actions_usage - action_usage = format(optionals + positionals, groups) - usage = ' '.join([s for s in [prog, action_usage] if s]) - - # wrap the usage parts if it's too long - text_width = self._width - self._current_indent - if len(prefix) + len(usage) > text_width: - - # break usage into wrappable parts - part_regexp = r'\(.*?\)+|\[.*?\]+|\S+' - opt_usage = format(optionals, groups) - pos_usage = format(positionals, groups) - opt_parts = _re.findall(part_regexp, opt_usage) - pos_parts = _re.findall(part_regexp, pos_usage) - assert ' '.join(opt_parts) == opt_usage - assert ' '.join(pos_parts) == pos_usage - - # helper for wrapping lines - def get_lines(parts, indent, prefix=None): - lines = [] - line = [] - if prefix is not None: - line_len = len(prefix) - 1 - else: - line_len = len(indent) - 1 - for part in parts: - if line_len + 1 + len(part) > text_width: - lines.append(indent + ' '.join(line)) - line = [] - line_len = len(indent) - 1 - line.append(part) - line_len += len(part) + 1 - if line: - lines.append(indent + ' '.join(line)) - if prefix is not None: - lines[0] = lines[0][len(indent):] - return lines - - # if prog is short, follow it with optionals or positionals - if len(prefix) + len(prog) <= 0.75 * text_width: - indent = ' ' * (len(prefix) + len(prog) + 1) - if opt_parts: - lines = get_lines([prog] + opt_parts, indent, prefix) - lines.extend(get_lines(pos_parts, indent)) - elif pos_parts: - lines = get_lines([prog] + pos_parts, indent, prefix) - else: - lines = [prog] - - # if prog is long, put it on its own line - else: - indent = ' ' * len(prefix) - parts = opt_parts + pos_parts - lines = get_lines(parts, indent) - if len(lines) > 1: - lines = [] - lines.extend(get_lines(opt_parts, indent)) - lines.extend(get_lines(pos_parts, indent)) - lines = [prog] + lines - - # join lines into usage - usage = '\n'.join(lines) - - # prefix with 'usage:' - return '%s%s\n\n' % (prefix, usage) - - def _format_actions_usage(self, actions, groups): - # find group indices and identify actions in groups - group_actions = _set() - inserts = {} - for group in groups: - try: - start = actions.index(group._group_actions[0]) - except ValueError: - continue - else: - end = start + len(group._group_actions) - if actions[start:end] == group._group_actions: - for action in group._group_actions: - group_actions.add(action) - if not group.required: - inserts[start] = '[' - inserts[end] = ']' - else: - inserts[start] = '(' - inserts[end] = ')' - for i in range(start + 1, end): - inserts[i] = '|' - - # collect all 
actions format strings - parts = [] - for i, action in enumerate(actions): - - # suppressed arguments are marked with None - # remove | separators for suppressed arguments - if action.help is SUPPRESS: - parts.append(None) - if inserts.get(i) == '|': - inserts.pop(i) - elif inserts.get(i + 1) == '|': - inserts.pop(i + 1) - - # produce all arg strings - elif not action.option_strings: - part = self._format_args(action, action.dest) - - # if it's in a group, strip the outer [] - if action in group_actions: - if part[0] == '[' and part[-1] == ']': - part = part[1:-1] - - # add the action string to the list - parts.append(part) - - # produce the first way to invoke the option in brackets - else: - option_string = action.option_strings[0] - - # if the Optional doesn't take a value, format is: - # -s or --long - if action.nargs == 0: - part = '%s' % option_string - - # if the Optional takes a value, format is: - # -s ARGS or --long ARGS - else: - default = action.dest.upper() - args_string = self._format_args(action, default) - part = '%s %s' % (option_string, args_string) - - # make it look optional if it's not required or in a group - if not action.required and action not in group_actions: - part = '[%s]' % part - - # add the action string to the list - parts.append(part) - - # insert things at the necessary indices - for i in _sorted(inserts, reverse=True): - parts[i:i] = [inserts[i]] - - # join all the action items with spaces - text = ' '.join([item for item in parts if item is not None]) - - # clean up separators for mutually exclusive groups - open = r'[\[(]' - close = r'[\])]' - text = _re.sub(r'(%s) ' % open, r'\1', text) - text = _re.sub(r' (%s)' % close, r'\1', text) - text = _re.sub(r'%s *%s' % (open, close), r'', text) - text = _re.sub(r'\(([^|]*)\)', r'\1', text) - text = text.strip() - - # return the text - return text - - def _format_text(self, text): - text_width = self._width - self._current_indent - indent = ' ' * self._current_indent - return self._fill_text(text, text_width, indent) + '\n\n' - - def _format_action(self, action): - # determine the required width and the entry label - help_position = min(self._action_max_length + 2, - self._max_help_position) - help_width = self._width - help_position - action_width = help_position - self._current_indent - 2 - action_header = self._format_action_invocation(action) - - # ho nelp; start on same line and add a final newline - if not action.help: - tup = self._current_indent, '', action_header - action_header = '%*s%s\n' % tup - - # short action name; start on the same line and pad two spaces - elif len(action_header) <= action_width: - tup = self._current_indent, '', action_width, action_header - action_header = '%*s%-*s ' % tup - indent_first = 0 - - # long action name; start on the next line - else: - tup = self._current_indent, '', action_header - action_header = '%*s%s\n' % tup - indent_first = help_position - - # collect the pieces of the action help - parts = [action_header] - - # if there was help for the action, add lines of help text - if action.help: - help_text = self._expand_help(action) - help_lines = self._split_lines(help_text, help_width) - parts.append('%*s%s\n' % (indent_first, '', help_lines[0])) - for line in help_lines[1:]: - parts.append('%*s%s\n' % (help_position, '', line)) - - # or add a newline if the description doesn't end with one - elif not action_header.endswith('\n'): - parts.append('\n') - - # if there are any sub-actions, add their help as well - for subaction in 
self._iter_indented_subactions(action): - parts.append(self._format_action(subaction)) - - # return a single string - return self._join_parts(parts) - - def _format_action_invocation(self, action): - if not action.option_strings: - metavar, = self._metavar_formatter(action, action.dest)(1) - return metavar - - else: - parts = [] - - # if the Optional doesn't take a value, format is: - # -s, --long - if action.nargs == 0: - parts.extend(action.option_strings) - - # if the Optional takes a value, format is: - # -s ARGS, --long ARGS - else: - default = action.dest.upper() - args_string = self._format_args(action, default) - for option_string in action.option_strings: - parts.append('%s %s' % (option_string, args_string)) - - return ', '.join(parts) - - def _metavar_formatter(self, action, default_metavar): - if action.metavar is not None: - result = action.metavar - elif action.choices is not None: - choice_strs = [str(choice) for choice in action.choices] - result = '{%s}' % ','.join(choice_strs) - else: - result = default_metavar - - def format(tuple_size): - if isinstance(result, tuple): - return result - else: - return (result, ) * tuple_size - return format - - def _format_args(self, action, default_metavar): - get_metavar = self._metavar_formatter(action, default_metavar) - if action.nargs is None: - result = '%s' % get_metavar(1) - elif action.nargs == OPTIONAL: - result = '[%s]' % get_metavar(1) - elif action.nargs == ZERO_OR_MORE: - result = '[%s [%s ...]]' % get_metavar(2) - elif action.nargs == ONE_OR_MORE: - result = '%s [%s ...]' % get_metavar(2) - elif action.nargs is PARSER: - result = '%s ...' % get_metavar(1) - else: - formats = ['%s' for _ in range(action.nargs)] - result = ' '.join(formats) % get_metavar(action.nargs) - return result - - def _expand_help(self, action): - params = dict(vars(action), prog=self._prog) - for name in list(params): - if params[name] is SUPPRESS: - del params[name] - if params.get('choices') is not None: - choices_str = ', '.join([str(c) for c in params['choices']]) - params['choices'] = choices_str - return self._get_help_string(action) % params - - def _iter_indented_subactions(self, action): - try: - get_subactions = action._get_subactions - except AttributeError: - pass - else: - self._indent() - for subaction in get_subactions(): - yield subaction - self._dedent() - - def _split_lines(self, text, width): - text = self._whitespace_matcher.sub(' ', text).strip() - return _textwrap.wrap(text, width) - - def _fill_text(self, text, width, indent): - text = self._whitespace_matcher.sub(' ', text).strip() - return _textwrap.fill(text, width, initial_indent=indent, - subsequent_indent=indent) - - def _get_help_string(self, action): - return action.help - - -class RawDescriptionHelpFormatter(HelpFormatter): - """Help message formatter which retains any formatting in descriptions. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. - """ - - def _fill_text(self, text, width, indent): - return ''.join([indent + line for line in text.splitlines(True)]) - - -class RawTextHelpFormatter(RawDescriptionHelpFormatter): - """Help message formatter which retains formatting of all help text. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. 
- """ - - def _split_lines(self, text, width): - return text.splitlines() - - -class ArgumentDefaultsHelpFormatter(HelpFormatter): - """Help message formatter which adds default values to argument help. - - Only the name of this class is considered a public API. All the methods - provided by the class are considered an implementation detail. - """ - - def _get_help_string(self, action): - help = action.help - if '%(default)' not in action.help: - if action.default is not SUPPRESS: - defaulting_nargs = [OPTIONAL, ZERO_OR_MORE] - if action.option_strings or action.nargs in defaulting_nargs: - help += ' (default: %(default)s)' - return help - - -# ===================== -# Options and Arguments -# ===================== - -def _get_action_name(argument): - if argument is None: - return None - elif argument.option_strings: - return '/'.join(argument.option_strings) - elif argument.metavar not in (None, SUPPRESS): - return argument.metavar - elif argument.dest not in (None, SUPPRESS): - return argument.dest - else: - return None - - -class ArgumentError(Exception): - """An error from creating or using an argument (optional or positional). - - The string value of this exception is the message, augmented with - information about the argument that caused it. - """ - - def __init__(self, argument, message): - self.argument_name = _get_action_name(argument) - self.message = message - - def __str__(self): - if self.argument_name is None: - format = '%(message)s' - else: - format = 'argument %(argument_name)s: %(message)s' - return format % dict(message=self.message, - argument_name=self.argument_name) - -# ============== -# Action classes -# ============== - - -class Action(_AttributeHolder): - """Information about how to convert command line strings to Python objects. - - Action objects are used by an ArgumentParser to represent the information - needed to parse a single argument from one or more strings from the - command line. The keyword arguments to the Action constructor are also - all attributes of Action instances. - - Keyword Arguments: - - - option_strings -- A list of command-line option strings which - should be associated with this action. - - - dest -- The name of the attribute to hold the created object(s) - - - nargs -- The number of command-line arguments that should be - consumed. By default, one argument will be consumed and a single - value will be produced. Other values include: - - N (an integer) consumes N arguments (and produces a list) - - '?' consumes zero or one arguments - - '*' consumes zero or more arguments (and produces a list) - - '+' consumes one or more arguments (and produces a list) - Note that the difference between the default and nargs=1 is that - with the default, a single value will be produced, while with - nargs=1, a list containing a single value will be produced. - - - const -- The value to be produced if the option is specified and the - option uses an action that takes no values. - - - default -- The value to be produced if the option is not specified. - - - type -- The type which the command-line arguments should be converted - to, should be one of 'string', 'int', 'float', 'complex' or a - callable object that accepts a single string argument. If None, - 'string' is assumed. - - - choices -- A container of values that should be allowed. If not None, - after a command-line argument has been converted to the appropriate - type, an exception will be raised if it is not a member of this - collection. 
- - - required -- True if the action must always be specified at the - command line. This is only meaningful for optional command-line - arguments. - - - help -- The help string describing the argument. - - - metavar -- The name to be used for the option's argument with the - help string. If None, the 'dest' value will be used as the name. - """ - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None): - self.option_strings = option_strings - self.dest = dest - self.nargs = nargs - self.const = const - self.default = default - self.type = type - self.choices = choices - self.required = required - self.help = help - self.metavar = metavar - - def _get_kwargs(self): - names = [ - 'option_strings', - 'dest', - 'nargs', - 'const', - 'default', - 'type', - 'choices', - 'help', - 'metavar', - ] - return [(name, getattr(self, name)) for name in names] - - def __call__(self, parser, namespace, values, option_string=None): - raise NotImplementedError(_('.__call__() not defined')) - - -class _StoreAction(Action): - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None): - if nargs == 0: - raise ValueError('nargs for store actions must be > 0; if you ' - 'have nothing to store, actions such as store ' - 'true or store const may be more appropriate') - if const is not None and nargs != OPTIONAL: - raise ValueError('nargs must be %r to supply const' % OPTIONAL) - super(_StoreAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=nargs, - const=const, - default=default, - type=type, - choices=choices, - required=required, - help=help, - metavar=metavar) - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, values) - - -class _StoreConstAction(Action): - - def __init__(self, - option_strings, - dest, - const, - default=None, - required=False, - help=None, - metavar=None): - super(_StoreConstAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=const, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, self.const) - - -class _StoreTrueAction(_StoreConstAction): - - def __init__(self, - option_strings, - dest, - default=False, - required=False, - help=None): - super(_StoreTrueAction, self).__init__( - option_strings=option_strings, - dest=dest, - const=True, - default=default, - required=required, - help=help) - - -class _StoreFalseAction(_StoreConstAction): - - def __init__(self, - option_strings, - dest, - default=True, - required=False, - help=None): - super(_StoreFalseAction, self).__init__( - option_strings=option_strings, - dest=dest, - const=False, - default=default, - required=required, - help=help) - - -class _AppendAction(Action): - - def __init__(self, - option_strings, - dest, - nargs=None, - const=None, - default=None, - type=None, - choices=None, - required=False, - help=None, - metavar=None): - if nargs == 0: - raise ValueError('nargs for append actions must be > 0; if arg ' - 'strings are not supplying the value to append, ' - 'the append const action may be more appropriate') - if const is not None and nargs != OPTIONAL: - raise ValueError('nargs must be %r to supply const' % OPTIONAL) - super(_AppendAction, self).__init__( - option_strings=option_strings, - 
dest=dest, - nargs=nargs, - const=const, - default=default, - type=type, - choices=choices, - required=required, - help=help, - metavar=metavar) - - def __call__(self, parser, namespace, values, option_string=None): - items = _copy.copy(_ensure_value(namespace, self.dest, [])) - items.append(values) - setattr(namespace, self.dest, items) - - -class _AppendConstAction(Action): - - def __init__(self, - option_strings, - dest, - const, - default=None, - required=False, - help=None, - metavar=None): - super(_AppendConstAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - const=const, - default=default, - required=required, - help=help, - metavar=metavar) - - def __call__(self, parser, namespace, values, option_string=None): - items = _copy.copy(_ensure_value(namespace, self.dest, [])) - items.append(self.const) - setattr(namespace, self.dest, items) - - -class _CountAction(Action): - - def __init__(self, - option_strings, - dest, - default=None, - required=False, - help=None): - super(_CountAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=0, - default=default, - required=required, - help=help) - - def __call__(self, parser, namespace, values, option_string=None): - new_count = _ensure_value(namespace, self.dest, 0) + 1 - setattr(namespace, self.dest, new_count) - - -class _HelpAction(Action): - - def __init__(self, - option_strings, - dest=SUPPRESS, - default=SUPPRESS, - help=None): - super(_HelpAction, self).__init__( - option_strings=option_strings, - dest=dest, - default=default, - nargs=0, - help=help) - - def __call__(self, parser, namespace, values, option_string=None): - parser.print_help() - parser.exit() - - -class _VersionAction(Action): - - def __init__(self, - option_strings, - dest=SUPPRESS, - default=SUPPRESS, - help=None): - super(_VersionAction, self).__init__( - option_strings=option_strings, - dest=dest, - default=default, - nargs=0, - help=help) - - def __call__(self, parser, namespace, values, option_string=None): - parser.print_version() - parser.exit() - - -class _SubParsersAction(Action): - - class _ChoicesPseudoAction(Action): - - def __init__(self, name, help): - sup = super(_SubParsersAction._ChoicesPseudoAction, self) - sup.__init__(option_strings=[], dest=name, help=help) - - def __init__(self, - option_strings, - prog, - parser_class, - dest=SUPPRESS, - help=None, - metavar=None): - - self._prog_prefix = prog - self._parser_class = parser_class - self._name_parser_map = {} - self._choices_actions = [] - - super(_SubParsersAction, self).__init__( - option_strings=option_strings, - dest=dest, - nargs=PARSER, - choices=self._name_parser_map, - help=help, - metavar=metavar) - - def add_parser(self, name, **kwargs): - # set prog from the existing prefix - if kwargs.get('prog') is None: - kwargs['prog'] = '%s %s' % (self._prog_prefix, name) - - # create a pseudo-action to hold the choice help - if 'help' in kwargs: - help = kwargs.pop('help') - choice_action = self._ChoicesPseudoAction(name, help) - self._choices_actions.append(choice_action) - - # create the parser and add it to the map - parser = self._parser_class(**kwargs) - self._name_parser_map[name] = parser - return parser - - def _get_subactions(self): - return self._choices_actions - - def __call__(self, parser, namespace, values, option_string=None): - parser_name = values[0] - arg_strings = values[1:] - - # set the parser name if requested - if self.dest is not SUPPRESS: - setattr(namespace, self.dest, parser_name) - - # select the parser - try: - 
parser = self._name_parser_map[parser_name] - except KeyError: - tup = parser_name, ', '.join(self._name_parser_map) - msg = _('unknown parser %r (choices: %s)' % tup) - raise ArgumentError(self, msg) - - # parse all the remaining options into the namespace - parser.parse_args(arg_strings, namespace) - - -# ============== -# Type classes -# ============== - -class FileType(object): - """Factory for creating file object types - - Instances of FileType are typically passed as type= arguments to the - ArgumentParser add_argument() method. - - Keyword Arguments: - - mode -- A string indicating how the file is to be opened. Accepts the - same values as the builtin open() function. - - bufsize -- The file's desired buffer size. Accepts the same values as - the builtin open() function. - """ - - def __init__(self, mode='r', bufsize=None): - self._mode = mode - self._bufsize = bufsize - - def __call__(self, string): - # the special argument "-" means sys.std{in,out} - if string == '-': - if 'r' in self._mode: - return _sys.stdin - elif 'w' in self._mode: - return _sys.stdout - else: - msg = _('argument "-" with mode %r' % self._mode) - raise ValueError(msg) - - # all other arguments are used as file names - if self._bufsize: - return open(string, self._mode, self._bufsize) - else: - return open(string, self._mode) - - def __repr__(self): - args = [self._mode, self._bufsize] - args_str = ', '.join([repr(arg) for arg in args if arg is not None]) - return '%s(%s)' % (type(self).__name__, args_str) - -# =========================== -# Optional and Positional Parsing -# =========================== - - -class Namespace(_AttributeHolder): - """Simple object for storing attributes. - - Implements equality by attribute names and values, and provides a simple - string representation. 
- """ - - def __init__(self, **kwargs): - for name in kwargs: - setattr(self, name, kwargs[name]) - - def __eq__(self, other): - return vars(self) == vars(other) - - def __ne__(self, other): - return not (self == other) - - -class _ActionsContainer(object): - - def __init__(self, - description, - prefix_chars, - argument_default, - conflict_handler): - super(_ActionsContainer, self).__init__() - - self.description = description - self.argument_default = argument_default - self.prefix_chars = prefix_chars - self.conflict_handler = conflict_handler - - # set up registries - self._registries = {} - - # register actions - self.register('action', None, _StoreAction) - self.register('action', 'store', _StoreAction) - self.register('action', 'store_const', _StoreConstAction) - self.register('action', 'store_true', _StoreTrueAction) - self.register('action', 'store_false', _StoreFalseAction) - self.register('action', 'append', _AppendAction) - self.register('action', 'append_const', _AppendConstAction) - self.register('action', 'count', _CountAction) - self.register('action', 'help', _HelpAction) - self.register('action', 'version', _VersionAction) - self.register('action', 'parsers', _SubParsersAction) - - # raise an exception if the conflict handler is invalid - self._get_handler() - - # action storage - self._actions = [] - self._option_string_actions = {} - - # groups - self._action_groups = [] - self._mutually_exclusive_groups = [] - - # defaults storage - self._defaults = {} - - # determines whether an "option" looks like a negative number - self._negative_number_matcher = _re.compile(r'^-\d+|-\d*.\d+$') - - # whether or not there are any optionals that look like negative - # numbers -- uses a list so it can be shared and edited - self._has_negative_number_optionals = [] - - # ==================== - # Registration methods - # ==================== - def register(self, registry_name, value, object): - registry = self._registries.setdefault(registry_name, {}) - registry[value] = object - - def _registry_get(self, registry_name, value, default=None): - return self._registries[registry_name].get(value, default) - - # ================================== - # Namespace default settings methods - # ================================== - def set_defaults(self, **kwargs): - self._defaults.update(kwargs) - - # if these defaults match any existing arguments, replace - # the previous default on the object with the new one - for action in self._actions: - if action.dest in kwargs: - action.default = kwargs[action.dest] - - # ======================= - # Adding argument actions - # ======================= - def add_argument(self, *args, **kwargs): - """ - add_argument(dest, ..., name=value, ...) - add_argument(option_string, option_string, ..., name=value, ...) 
- """ - - # if no positional args are supplied or only one is supplied and - # it doesn't look like an option string, parse a positional - # argument - chars = self.prefix_chars - if not args or len(args) == 1 and args[0][0] not in chars: - kwargs = self._get_positional_kwargs(*args, **kwargs) - - # otherwise, we're adding an optional argument - else: - kwargs = self._get_optional_kwargs(*args, **kwargs) - - # if no default was supplied, use the parser-level default - if 'default' not in kwargs: - dest = kwargs['dest'] - if dest in self._defaults: - kwargs['default'] = self._defaults[dest] - elif self.argument_default is not None: - kwargs['default'] = self.argument_default - - # create the action object, and add it to the parser - action_class = self._pop_action_class(kwargs) - action = action_class(**kwargs) - return self._add_action(action) - - def add_argument_group(self, *args, **kwargs): - group = _ArgumentGroup(self, *args, **kwargs) - self._action_groups.append(group) - return group - - def add_mutually_exclusive_group(self, **kwargs): - group = _MutuallyExclusiveGroup(self, **kwargs) - self._mutually_exclusive_groups.append(group) - return group - - def _add_action(self, action): - # resolve any conflicts - self._check_conflict(action) - - # add to actions list - self._actions.append(action) - action.container = self - - # index the action by any option strings it has - for option_string in action.option_strings: - self._option_string_actions[option_string] = action - - # set the flag if any option strings look like negative numbers - for option_string in action.option_strings: - if self._negative_number_matcher.match(option_string): - if not self._has_negative_number_optionals: - self._has_negative_number_optionals.append(True) - - # return the created action - return action - - def _remove_action(self, action): - self._actions.remove(action) - - def _add_container_actions(self, container): - # collect groups by titles - title_group_map = {} - for group in self._action_groups: - if group.title in title_group_map: - msg = _('cannot merge actions - two groups are named %r') - raise ValueError(msg % (group.title)) - title_group_map[group.title] = group - - # map each action to its group - group_map = {} - for group in container._action_groups: - - # if a group with the title exists, use that, otherwise - # create a new group matching the container's group - if group.title not in title_group_map: - title_group_map[group.title] = self.add_argument_group( - title=group.title, - description=group.description, - conflict_handler=group.conflict_handler) - - # map the actions to their new group - for action in group._group_actions: - group_map[action] = title_group_map[group.title] - - # add container's mutually exclusive groups - # NOTE: if add_mutually_exclusive_group ever gains title= and - # description= then this code will need to be expanded as above - for group in container._mutually_exclusive_groups: - mutex_group = self.add_mutually_exclusive_group( - required=group.required) - - # map the actions to their new mutex group - for action in group._group_actions: - group_map[action] = mutex_group - - # add all actions to this container or their group - for action in container._actions: - group_map.get(action, self)._add_action(action) - - def _get_positional_kwargs(self, dest, **kwargs): - # make sure required is not specified - if 'required' in kwargs: - msg = _("'required' is an invalid argument for positionals") - raise TypeError(msg) - - # mark positional arguments as required if 
at least one is - # always required - if kwargs.get('nargs') not in [OPTIONAL, ZERO_OR_MORE]: - kwargs['required'] = True - if kwargs.get('nargs') == ZERO_OR_MORE and 'default' not in kwargs: - kwargs['required'] = True - - # return the keyword arguments with no option strings - return dict(kwargs, dest=dest, option_strings=[]) - - def _get_optional_kwargs(self, *args, **kwargs): - # determine short and long option strings - option_strings = [] - long_option_strings = [] - for option_string in args: - # error on one-or-fewer-character option strings - if len(option_string) < 2: - msg = _('invalid option string %r: ' - 'must be at least two characters long') - raise ValueError(msg % option_string) - - # error on strings that don't start with an appropriate prefix - if not option_string[0] in self.prefix_chars: - msg = _('invalid option string %r: ' - 'must start with a character %r') - tup = option_string, self.prefix_chars - raise ValueError(msg % tup) - - # error on strings that are all prefix characters - if not (_set(option_string) - _set(self.prefix_chars)): - msg = _('invalid option string %r: ' - 'must contain characters other than %r') - tup = option_string, self.prefix_chars - raise ValueError(msg % tup) - - # strings starting with two prefix characters are long options - option_strings.append(option_string) - if option_string[0] in self.prefix_chars: - if option_string[1] in self.prefix_chars: - long_option_strings.append(option_string) - - # infer destination, '--foo-bar' -> 'foo_bar' and '-x' -> 'x' - dest = kwargs.pop('dest', None) - if dest is None: - if long_option_strings: - dest_option_string = long_option_strings[0] - else: - dest_option_string = option_strings[0] - dest = dest_option_string.lstrip(self.prefix_chars) - dest = dest.replace('-', '_') - - # return the updated keyword arguments - return dict(kwargs, dest=dest, option_strings=option_strings) - - def _pop_action_class(self, kwargs, default=None): - action = kwargs.pop('action', default) - return self._registry_get('action', action, action) - - def _get_handler(self): - # determine function from conflict handler string - handler_func_name = '_handle_conflict_%s' % self.conflict_handler - try: - return getattr(self, handler_func_name) - except AttributeError: - msg = _('invalid conflict_resolution value: %r') - raise ValueError(msg % self.conflict_handler) - - def _check_conflict(self, action): - - # find all options that conflict with this option - confl_optionals = [] - for option_string in action.option_strings: - if option_string in self._option_string_actions: - confl_optional = self._option_string_actions[option_string] - confl_optionals.append((option_string, confl_optional)) - - # resolve any conflicts - if confl_optionals: - conflict_handler = self._get_handler() - conflict_handler(action, confl_optionals) - - def _handle_conflict_error(self, action, conflicting_actions): - message = _('conflicting option string(s): %s') - conflict_string = ', '.join([option_string - for option_string, action - in conflicting_actions]) - raise ArgumentError(action, message % conflict_string) - - def _handle_conflict_resolve(self, action, conflicting_actions): - - # remove all conflicting options - for option_string, action in conflicting_actions: - - # remove the conflicting option - action.option_strings.remove(option_string) - self._option_string_actions.pop(option_string, None) - - # if the option now has no option string, remove it from the - # container holding it - if not action.option_strings: - 
action.container._remove_action(action) - - -class _ArgumentGroup(_ActionsContainer): - - def __init__(self, container, title=None, description=None, **kwargs): - # add any missing keyword arguments by checking the container - update = kwargs.setdefault - update('conflict_handler', container.conflict_handler) - update('prefix_chars', container.prefix_chars) - update('argument_default', container.argument_default) - super_init = super(_ArgumentGroup, self).__init__ - super_init(description=description, **kwargs) - - # group attributes - self.title = title - self._group_actions = [] - - # share most attributes with the container - self._registries = container._registries - self._actions = container._actions - self._option_string_actions = container._option_string_actions - self._defaults = container._defaults - self._has_negative_number_optionals = \ - container._has_negative_number_optionals - - def _add_action(self, action): - action = super(_ArgumentGroup, self)._add_action(action) - self._group_actions.append(action) - return action - - def _remove_action(self, action): - super(_ArgumentGroup, self)._remove_action(action) - self._group_actions.remove(action) - - -class _MutuallyExclusiveGroup(_ArgumentGroup): - - def __init__(self, container, required=False): - super(_MutuallyExclusiveGroup, self).__init__(container) - self.required = required - self._container = container - - def _add_action(self, action): - if action.required: - msg = _('mutually exclusive arguments must be optional') - raise ValueError(msg) - action = self._container._add_action(action) - self._group_actions.append(action) - return action - - def _remove_action(self, action): - self._container._remove_action(action) - self._group_actions.remove(action) - - -class ArgumentParser(_AttributeHolder, _ActionsContainer): - """Object for parsing command line strings into Python objects. 
- - Keyword Arguments: - - prog -- The name of the program (default: sys.argv[0]) - - usage -- A usage message (default: auto-generated from arguments) - - description -- A description of what the program does - - epilog -- Text following the argument descriptions - - version -- Add a -v/--version option with the given version string - - parents -- Parsers whose arguments should be copied into this one - - formatter_class -- HelpFormatter class for printing help messages - - prefix_chars -- Characters that prefix optional arguments - - fromfile_prefix_chars -- Characters that prefix files containing - additional arguments - - argument_default -- The default value for all arguments - - conflict_handler -- String indicating how to handle conflicts - - add_help -- Add a -h/-help option - """ - - def __init__(self, - prog=None, - usage=None, - description=None, - epilog=None, - version=None, - parents=[], - formatter_class=HelpFormatter, - prefix_chars='-', - fromfile_prefix_chars=None, - argument_default=None, - conflict_handler='error', - add_help=True): - - superinit = super(ArgumentParser, self).__init__ - superinit(description=description, - prefix_chars=prefix_chars, - argument_default=argument_default, - conflict_handler=conflict_handler) - - # default setting for prog - if prog is None: - prog = _os.path.basename(_sys.argv[0]) - - self.prog = prog - self.usage = usage - self.epilog = epilog - self.version = version - self.formatter_class = formatter_class - self.fromfile_prefix_chars = fromfile_prefix_chars - self.add_help = add_help - - add_group = self.add_argument_group - self._positionals = add_group(_('positional arguments')) - self._optionals = add_group(_('optional arguments')) - self._subparsers = None - - # register types - def identity(string): - return string - self.register('type', None, identity) - - # add help and version arguments if necessary - # (using explicit default to override global argument_default) - if self.add_help: - self.add_argument( - '-h', '--help', action='help', default=SUPPRESS, - help=_('show this help message and exit')) - if self.version: - self.add_argument( - '-v', '--version', action='version', default=SUPPRESS, - help=_("show program's version number and exit")) - - # add parent arguments and defaults - for parent in parents: - self._add_container_actions(parent) - try: - defaults = parent._defaults - except AttributeError: - pass - else: - self._defaults.update(defaults) - - # ======================= - # Pretty __repr__ methods - # ======================= - def _get_kwargs(self): - names = [ - 'prog', - 'usage', - 'description', - 'version', - 'formatter_class', - 'conflict_handler', - 'add_help', - ] - return [(name, getattr(self, name)) for name in names] - - # ================================== - # Optional/Positional adding methods - # ================================== - def add_subparsers(self, **kwargs): - if self._subparsers is not None: - self.error(_('cannot have multiple subparser arguments')) - - # add the parser class to the arguments if it's not present - kwargs.setdefault('parser_class', type(self)) - - if 'title' in kwargs or 'description' in kwargs: - title = _(kwargs.pop('title', 'subcommands')) - description = _(kwargs.pop('description', None)) - self._subparsers = self.add_argument_group(title, description) - else: - self._subparsers = self._positionals - - # prog defaults to the usage message of this parser, skipping - # optional arguments and with no "usage:" prefix - if kwargs.get('prog') is None: - formatter = 
self._get_formatter() - positionals = self._get_positional_actions() - groups = self._mutually_exclusive_groups - formatter.add_usage(self.usage, positionals, groups, '') - kwargs['prog'] = formatter.format_help().strip() - - # create the parsers action and add it to the positionals list - parsers_class = self._pop_action_class(kwargs, 'parsers') - action = parsers_class(option_strings=[], **kwargs) - self._subparsers._add_action(action) - - # return the created parsers action - return action - - def _add_action(self, action): - if action.option_strings: - self._optionals._add_action(action) - else: - self._positionals._add_action(action) - return action - - def _get_optional_actions(self): - return [action - for action in self._actions - if action.option_strings] - - def _get_positional_actions(self): - return [action - for action in self._actions - if not action.option_strings] - - # ===================================== - # Command line argument parsing methods - # ===================================== - def parse_args(self, args=None, namespace=None): - args, argv = self.parse_known_args(args, namespace) - if argv: - msg = _('unrecognized arguments: %s') - self.error(msg % ' '.join(argv)) - return args - - def parse_known_args(self, args=None, namespace=None): - # args default to the system args - if args is None: - args = _sys.argv[1:] - - # default Namespace built from parser defaults - if namespace is None: - namespace = Namespace() - - # add any action defaults that aren't present - for action in self._actions: - if action.dest is not SUPPRESS: - if not hasattr(namespace, action.dest): - if action.default is not SUPPRESS: - default = action.default - if isinstance(action.default, _basestring): - default = self._get_value(action, default) - setattr(namespace, action.dest, default) - - # add any parser defaults that aren't present - for dest in self._defaults: - if not hasattr(namespace, dest): - setattr(namespace, dest, self._defaults[dest]) - - # parse the arguments and exit if there are any errors - try: - return self._parse_known_args(args, namespace) - except ArgumentError: - err = _sys.exc_info()[1] - self.error(str(err)) - - def _parse_known_args(self, arg_strings, namespace): - # replace arg strings that are file references - if self.fromfile_prefix_chars is not None: - arg_strings = self._read_args_from_files(arg_strings) - - # map all mutually exclusive arguments to the other arguments - # they can't occur with - action_conflicts = {} - for mutex_group in self._mutually_exclusive_groups: - group_actions = mutex_group._group_actions - for i, mutex_action in enumerate(mutex_group._group_actions): - conflicts = action_conflicts.setdefault(mutex_action, []) - conflicts.extend(group_actions[:i]) - conflicts.extend(group_actions[i + 1:]) - - # find all option indices, and determine the arg_string_pattern - # which has an 'O' if there is an option at an index, - # an 'A' if there is an argument, or a '-' if there is a '--' - option_string_indices = {} - arg_string_pattern_parts = [] - arg_strings_iter = iter(arg_strings) - for i, arg_string in enumerate(arg_strings_iter): - - # all args after -- are non-options - if arg_string == '--': - arg_string_pattern_parts.append('-') - for _ in arg_strings_iter: - arg_string_pattern_parts.append('A') - - # otherwise, add the arg to the arg strings - # and note the index if it was an option - else: - option_tuple = self._parse_optional(arg_string) - if option_tuple is None: - pattern = 'A' - else: - option_string_indices[i] = option_tuple - 
pattern = 'O' - arg_string_pattern_parts.append(pattern) - - # join the pieces together to form the pattern - arg_strings_pattern = ''.join(arg_string_pattern_parts) - - # converts arg strings to the appropriate and then takes the action - seen_actions = _set() - seen_non_default_actions = _set() - - def take_action(action, argument_strings, option_string=None): - seen_actions.add(action) - argument_values = self._get_values(action, argument_strings) - - # error if this argument is not allowed with other previously - # seen arguments, assuming that actions that use the default - # value don't really count as "present" - if argument_values is not action.default: - seen_non_default_actions.add(action) - for conflict_action in action_conflicts.get(action, []): - if conflict_action in seen_non_default_actions: - msg = _('not allowed with argument %s') - action_name = _get_action_name(conflict_action) - raise ArgumentError(action, msg % action_name) - - # take the action if we didn't receive a SUPPRESS value - # (e.g. from a default) - if argument_values is not SUPPRESS: - action(self, namespace, argument_values, option_string) - - # function to convert arg_strings into an optional action - def consume_optional(start_index): - - # get the optional identified at this index - option_tuple = option_string_indices[start_index] - action, option_string, explicit_arg = option_tuple - - # identify additional optionals in the same arg string - # (e.g. -xyz is the same as -x -y -z if no args are required) - match_argument = self._match_argument - action_tuples = [] - while True: - - # if we found no optional action, skip it - if action is None: - extras.append(arg_strings[start_index]) - return start_index + 1 - - # if there is an explicit argument, try to match the - # optional's string arguments to only this - if explicit_arg is not None: - arg_count = match_argument(action, 'A') - - # if the action is a single-dash option and takes no - # arguments, try to parse more single-dash options out - # of the tail of the option string - chars = self.prefix_chars - if arg_count == 0 and option_string[1] not in chars: - action_tuples.append((action, [], option_string)) - for char in self.prefix_chars: - option_string = char + explicit_arg[0] - explicit_arg = explicit_arg[1:] or None - optionals_map = self._option_string_actions - if option_string in optionals_map: - action = optionals_map[option_string] - break - else: - msg = _('ignored explicit argument %r') - raise ArgumentError(action, msg % explicit_arg) - - # if the action expect exactly one argument, we've - # successfully matched the option; exit the loop - elif arg_count == 1: - stop = start_index + 1 - args = [explicit_arg] - action_tuples.append((action, args, option_string)) - break - - # error if a double-dash option did not use the - # explicit argument - else: - msg = _('ignored explicit argument %r') - raise ArgumentError(action, msg % explicit_arg) - - # if there is no explicit argument, try to match the - # optional's string arguments with the following strings - # if successful, exit the loop - else: - start = start_index + 1 - selected_patterns = arg_strings_pattern[start:] - arg_count = match_argument(action, selected_patterns) - stop = start + arg_count - args = arg_strings[start:stop] - action_tuples.append((action, args, option_string)) - break - - # add the Optional to the list and return the index at which - # the Optional's string args stopped - assert action_tuples - for action, args, option_string in action_tuples: - 
take_action(action, args, option_string) - return stop - - # the list of Positionals left to be parsed; this is modified - # by consume_positionals() - positionals = self._get_positional_actions() - - # function to convert arg_strings into positional actions - def consume_positionals(start_index): - # match as many Positionals as possible - match_partial = self._match_arguments_partial - selected_pattern = arg_strings_pattern[start_index:] - arg_counts = match_partial(positionals, selected_pattern) - - # slice off the appropriate arg strings for each Positional - # and add the Positional and its args to the list - for action, arg_count in zip(positionals, arg_counts): - args = arg_strings[start_index: start_index + arg_count] - start_index += arg_count - take_action(action, args) - - # slice off the Positionals that we just parsed and return the - # index at which the Positionals' string args stopped - positionals[:] = positionals[len(arg_counts):] - return start_index - - # consume Positionals and Optionals alternately, until we have - # passed the last option string - extras = [] - start_index = 0 - if option_string_indices: - max_option_string_index = max(option_string_indices) - else: - max_option_string_index = -1 - while start_index <= max_option_string_index: - - # consume any Positionals preceding the next option - next_option_string_index = min([ - index - for index in option_string_indices - if index >= start_index]) - if start_index != next_option_string_index: - positionals_end_index = consume_positionals(start_index) - - # only try to parse the next optional if we didn't consume - # the option string during the positionals parsing - if positionals_end_index > start_index: - start_index = positionals_end_index - continue - else: - start_index = positionals_end_index - - # if we consumed all the positionals we could and we're not - # at the index of an option string, there were extra arguments - if start_index not in option_string_indices: - strings = arg_strings[start_index:next_option_string_index] - extras.extend(strings) - start_index = next_option_string_index - - # consume the next optional and any arguments for it - start_index = consume_optional(start_index) - - # consume any positionals following the last Optional - stop_index = consume_positionals(start_index) - - # if we didn't consume all the argument strings, there were extras - extras.extend(arg_strings[stop_index:]) - - # if we didn't use all the Positional objects, there were too few - # arg strings supplied. 
- if positionals: - self.error(_('too few arguments')) - - # make sure all required actions were present - for action in self._actions: - if action.required: - if action not in seen_actions: - name = _get_action_name(action) - self.error(_('argument %s is required') % name) - - # make sure all required groups had one option present - for group in self._mutually_exclusive_groups: - if group.required: - for action in group._group_actions: - if action in seen_non_default_actions: - break - - # if no actions were used, report the error - else: - names = [_get_action_name(action) - for action in group._group_actions - if action.help is not SUPPRESS] - msg = _('one of the arguments %s is required') - self.error(msg % ' '.join(names)) - - # return the updated namespace and the extra arguments - return namespace, extras - - def _read_args_from_files(self, arg_strings): - # expand arguments referencing files - new_arg_strings = [] - for arg_string in arg_strings: - - # for regular arguments, just add them back into the list - if arg_string[0] not in self.fromfile_prefix_chars: - new_arg_strings.append(arg_string) - - # replace arguments referencing files with the file content - else: - try: - args_file = open(arg_string[1:]) - try: - arg_strings = args_file.read().splitlines() - arg_strings = self._read_args_from_files(arg_strings) - new_arg_strings.extend(arg_strings) - finally: - args_file.close() - except IOError: - err = _sys.exc_info()[1] - self.error(str(err)) - - # return the modified argument list - return new_arg_strings - - def _match_argument(self, action, arg_strings_pattern): - # match the pattern for this action to the arg strings - nargs_pattern = self._get_nargs_pattern(action) - match = _re.match(nargs_pattern, arg_strings_pattern) - - # raise an exception if we weren't able to find a match - if match is None: - nargs_errors = { - None: _('expected one argument'), - OPTIONAL: _('expected at most one argument'), - ONE_OR_MORE: _('expected at least one argument'), - } - default = _('expected %s argument(s)') % action.nargs - msg = nargs_errors.get(action.nargs, default) - raise ArgumentError(action, msg) - - # return the number of arguments matched - return len(match.group(1)) - - def _match_arguments_partial(self, actions, arg_strings_pattern): - # progressively shorten the actions list by slicing off the - # final actions until we find a match - result = [] - for i in range(len(actions), 0, -1): - actions_slice = actions[:i] - pattern = ''.join([self._get_nargs_pattern(action) - for action in actions_slice]) - match = _re.match(pattern, arg_strings_pattern) - if match is not None: - result.extend([len(string) for string in match.groups()]) - break - - # return the list of arg string counts - return result - - def _parse_optional(self, arg_string): - # if it's an empty string, it was meant to be a positional - if not arg_string: - return None - - # if it doesn't start with a prefix, it was meant to be positional - if not arg_string[0] in self.prefix_chars: - return None - - # if it's just dashes, it was meant to be positional - if not arg_string.strip('-'): - return None - - # if the option string is present in the parser, return the action - if arg_string in self._option_string_actions: - action = self._option_string_actions[arg_string] - return action, arg_string, None - - # search through all possible prefixes of the option string - # and all actions in the parser for possible interpretations - option_tuples = self._get_option_tuples(arg_string) - - # if multiple actions match, 
the option string was ambiguous - if len(option_tuples) > 1: - options = ', '.join([option_string - for action, option_string, explicit_arg in - option_tuples]) - tup = arg_string, options - self.error(_('ambiguous option: %s could match %s') % tup) - - # if exactly one action matched, this segmentation is good, - # so return the parsed action - elif len(option_tuples) == 1: - option_tuple, = option_tuples - return option_tuple - - # if it was not found as an option, but it looks like a negative - # number, it was meant to be positional - # unless there are negative-number-like options - if self._negative_number_matcher.match(arg_string): - if not self._has_negative_number_optionals: - return None - - # if it contains a space, it was meant to be a positional - if ' ' in arg_string: - return None - - # it was meant to be an optional but there is no such option - # in this parser (though it might be a valid option in a subparser) - return None, arg_string, None - - def _get_option_tuples(self, option_string): - result = [] - - # option strings starting with two prefix characters are only - # split at the '=' - chars = self.prefix_chars - if option_string[0] in chars and option_string[1] in chars: - if '=' in option_string: - option_prefix, explicit_arg = option_string.split('=', 1) - else: - option_prefix = option_string - explicit_arg = None - for option_string in self._option_string_actions: - if option_string.startswith(option_prefix): - action = self._option_string_actions[option_string] - tup = action, option_string, explicit_arg - result.append(tup) - - # single character options can be concatenated with their arguments - # but multiple character options always have to have their argument - # separate - elif option_string[0] in chars and option_string[1] not in chars: - option_prefix = option_string - explicit_arg = None - short_option_prefix = option_string[:2] - short_explicit_arg = option_string[2:] - - for option_string in self._option_string_actions: - if option_string == short_option_prefix: - action = self._option_string_actions[option_string] - tup = action, option_string, short_explicit_arg - result.append(tup) - elif option_string.startswith(option_prefix): - action = self._option_string_actions[option_string] - tup = action, option_string, explicit_arg - result.append(tup) - - # shouldn't ever get here - else: - self.error(_('unexpected option string: %s') % option_string) - - # return the collected option tuples - return result - - def _get_nargs_pattern(self, action): - # in all examples below, we have to allow for '--' args - # which are represented as '-' in the pattern - nargs = action.nargs - - # the default (None) is assumed to be a single argument - if nargs is None: - nargs_pattern = '(-*A-*)' - - # allow zero or one arguments - elif nargs == OPTIONAL: - nargs_pattern = '(-*A?-*)' - - # allow zero or more arguments - elif nargs == ZERO_OR_MORE: - nargs_pattern = '(-*[A-]*)' - - # allow one or more arguments - elif nargs == ONE_OR_MORE: - nargs_pattern = '(-*A[A-]*)' - - # allow one argument followed by any number of options or arguments - elif nargs is PARSER: - nargs_pattern = '(-*A[-AO]*)' - - # all others should be integers - else: - nargs_pattern = '(-*%s-*)' % '-*'.join('A' * nargs) - - # if this is an optional action, -- is not allowed - if action.option_strings: - nargs_pattern = nargs_pattern.replace('-*', '') - nargs_pattern = nargs_pattern.replace('-', '') - - # return the pattern - return nargs_pattern - - # ======================== - # Value conversion 
methods - # ======================== - def _get_values(self, action, arg_strings): - # for everything but PARSER args, strip out '--' - if action.nargs is not PARSER: - arg_strings = [s for s in arg_strings if s != '--'] - - # optional argument produces a default when not present - if not arg_strings and action.nargs == OPTIONAL: - if action.option_strings: - value = action.const - else: - value = action.default - if isinstance(value, _basestring): - value = self._get_value(action, value) - self._check_value(action, value) - - # when nargs='*' on a positional, if there were no command-line - # args, use the default if it is anything other than None - elif (not arg_strings and action.nargs == ZERO_OR_MORE and - not action.option_strings): - if action.default is not None: - value = action.default - else: - value = arg_strings - self._check_value(action, value) - - # single argument or optional argument produces a single value - elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]: - arg_string, = arg_strings - value = self._get_value(action, arg_string) - self._check_value(action, value) - - # PARSER arguments convert all values, but check only the first - elif action.nargs is PARSER: - value = [self._get_value(action, v) for v in arg_strings] - self._check_value(action, value[0]) - - # all other types of nargs produce a list - else: - value = [self._get_value(action, v) for v in arg_strings] - for v in value: - self._check_value(action, v) - - # return the converted value - return value - - def _get_value(self, action, arg_string): - type_func = self._registry_get('type', action.type, action.type) - if not hasattr(type_func, '__call__'): - if not hasattr(type_func, '__bases__'): # classic classes - msg = _('%r is not callable') - raise ArgumentError(action, msg % type_func) - - # convert the value to the appropriate type - try: - result = type_func(arg_string) - - # TypeErrors or ValueErrors indicate errors - except (TypeError, ValueError): - name = getattr(action.type, '__name__', repr(action.type)) - msg = _('invalid %s value: %r') - raise ArgumentError(action, msg % (name, arg_string)) - - # return the converted value - return result - - def _check_value(self, action, value): - # converted value must be one of the choices (if specified) - if action.choices is not None and value not in action.choices: - tup = value, ', '.join(map(repr, action.choices)) - msg = _('invalid choice: %r (choose from %s)') % tup - raise ArgumentError(action, msg) - - # ======================= - # Help-formatting methods - # ======================= - def format_usage(self): - formatter = self._get_formatter() - formatter.add_usage(self.usage, self._actions, - self._mutually_exclusive_groups) - return formatter.format_help() - - def format_help(self): - formatter = self._get_formatter() - - # usage - formatter.add_usage(self.usage, self._actions, - self._mutually_exclusive_groups) - - # description - formatter.add_text(self.description) - - # positionals, optionals and user-defined groups - for action_group in self._action_groups: - formatter.start_section(action_group.title) - formatter.add_text(action_group.description) - formatter.add_arguments(action_group._group_actions) - formatter.end_section() - - # epilog - formatter.add_text(self.epilog) - - # determine help from format above - return formatter.format_help() - - def format_version(self): - formatter = self._get_formatter() - formatter.add_text(self.version) - return formatter.format_help() - - def _get_formatter(self): - return 
self.formatter_class(prog=self.prog)
-
-    # =====================
-    # Help-printing methods
-    # =====================
-    def print_usage(self, file=None):
-        self._print_message(self.format_usage(), file)
-
-    def print_help(self, file=None):
-        self._print_message(self.format_help(), file)
-
-    def print_version(self, file=None):
-        self._print_message(self.format_version(), file)
-
-    def _print_message(self, message, file=None):
-        if message:
-            if file is None:
-                file = _sys.stderr
-            file.write(message)
-
-    # ===============
-    # Exiting methods
-    # ===============
-    def exit(self, status=0, message=None):
-        if message:
-            _sys.stderr.write(message)
-        _sys.exit(status)
-
-    def error(self, message):
-        """error(message: string)
-
-        Prints a usage message incorporating the message to stderr and
-        exits.
-
-        If you override this in a subclass, it should not return -- it
-        should either exit or raise an exception.
-        """
-        self.print_usage(_sys.stderr)
-        self.exit(2, _('%s: error: %s\n') % (self.prog, message))
diff --git a/dipy/info.py b/dipy/info.py
index d8764a7d4e..2582c5a0b4 100644
--- a/dipy/info.py
+++ b/dipy/info.py
@@ -7,10 +7,10 @@
 # full release. '.dev' as a _version_extra string means this is a development
 # version
 _version_major = 0
-_version_minor = 14
+_version_minor = 15
 _version_micro = 0
+# _version_extra = 'dev'
 _version_extra = ''
-#_version_extra = ''
 
 # Format expected by setup.py and doc/source/conf.py: string of form "X.Y.Z"
 __version__ = "%s.%s.%s%s" % (_version_major,
@@ -72,18 +72,16 @@
 Please see the LICENSE file in the dipy distribution.
 
 DIPY uses other libraries also licensed under the BSD or the
-MIT licenses, with the only exception of the SHORE module which
-optionally uses the cvxopt library. Cvxopt is licensed
-under the GPL license.
+MIT licenses.
 """
 
 # versions for dependencies
 # Check these versions against .travis.yml and requirements.txt
-CYTHON_MIN_VERSION='0.25.1'
-NUMPY_MIN_VERSION='1.7.1'
-SCIPY_MIN_VERSION='0.9'
-NIBABEL_MIN_VERSION='2.1.0'
-H5PY_MIN_VERSION='2.4.0'
+CYTHON_MIN_VERSION = '0.25.1'
+NUMPY_MIN_VERSION = '1.7.1'
+SCIPY_MIN_VERSION = '0.9'
+NIBABEL_MIN_VERSION = '2.3.0'
+H5PY_MIN_VERSION = '2.4.0'
 
 # Main setup parameters
 NAME = 'dipy'
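The dipy/info.py hunk above bumps the minor version for the 0.15 cycle and keeps the development tag commented out. As a quick illustration of how the four components assemble into `__version__` (a sketch; the rendered strings follow from the format string itself, not from any quoted release output):

# Release build: _version_extra is empty, so no suffix is appended.
print("%s.%s.%s%s" % (0, 15, 0, ''))     # -> 0.15.0

# Development build: re-enabling _version_extra = 'dev' appends the tag
# directly, with no separator.
print("%s.%s.%s%s" % (0, 15, 0, 'dev'))  # -> 0.15.0dev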
""" # versions for dependencies # Check these versions against .travis.yml and requirements.txt -CYTHON_MIN_VERSION='0.25.1' -NUMPY_MIN_VERSION='1.7.1' -SCIPY_MIN_VERSION='0.9' -NIBABEL_MIN_VERSION='2.1.0' -H5PY_MIN_VERSION='2.4.0' +CYTHON_MIN_VERSION = '0.25.1' +NUMPY_MIN_VERSION = '1.7.1' +SCIPY_MIN_VERSION = '0.9' +NIBABEL_MIN_VERSION = '2.3.0' +H5PY_MIN_VERSION = '2.4.0' # Main setup parameters NAME = 'dipy' diff --git a/dipy/io/image.py b/dipy/io/image.py index 85ba58a016..d14fd1176e 100644 --- a/dipy/io/image.py +++ b/dipy/io/image.py @@ -8,7 +8,7 @@ def load_nifti(fname, return_img=False, return_voxsize=False, img = nib.load(fname) data = img.get_data() vox_size = img.header.get_zooms()[:3] - + ret_val = [data, img.affine] if return_img: diff --git a/dipy/io/streamline.py b/dipy/io/streamline.py index 61634c84cc..786e769957 100644 --- a/dipy/io/streamline.py +++ b/dipy/io/streamline.py @@ -1,10 +1,17 @@ +import os +from functools import partial import nibabel as nib -from nibabel.streamlines import Field +from nibabel.streamlines import (Field, TrkFile, TckFile, + Tractogram, LazyTractogram, + detect_format) from nibabel.orientations import aff2axcodes +from dipy.io.dpy import Dpy, Streamlines -def save_trk(fname, streamlines, affine, vox_size=None, shape=None, header=None): - """ Saves tractogram files (*.trk) +def save_tractogram(fname, streamlines, affine, vox_size=None, shape=None, + header=None, reduce_memory_usage=False, + tractogram_file=None): + """ Saves tractogram files (*.trk or *.tck or *.dpy) Parameters ---------- @@ -20,7 +27,23 @@ def save_trk(fname, streamlines, affine, vox_size=None, shape=None, header=None) The shape of the reference image (default: None) header : dict, optional Metadata associated to the tractogram file(*.trk). (default: None) + reduce_memory_usage : {False, True}, optional + If True, save streamlines in a lazy manner i.e. they will not be kept + in memory. Otherwise, keep all streamlines in memory until saving. 
+    tractogram_file : class TractogramFile, optional
+        Define the tractogram class type (TrkFile vs. TckFile).
+        Default is None, which means the format is auto-detected from
+        the filename.
     """
+    if 'dpy' in os.path.splitext(fname)[1].lower():
+        dpw = Dpy(fname, 'w')
+        dpw.write_tracks(Streamlines(streamlines))
+        dpw.close()
+        return
+
+    tractogram_file = tractogram_file or detect_format(fname)
+    if tractogram_file is None:
+        raise ValueError("Unknown format for 'fname': {}".format(fname))
+
     if vox_size is not None and shape is not None:
         if not isinstance(header, dict):
             header = {}
@@ -29,19 +52,29 @@ def save_trk(fname, streamlines, affine, vox_size=None, shape=None, header=None)
         header[Field.DIMENSIONS] = shape
         header[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine))
 
-    tractogram = nib.streamlines.Tractogram(streamlines)
+    if reduce_memory_usage and not callable(streamlines):
+        sg = lambda: (s for s in streamlines)
+    else:
+        sg = streamlines
+
+    tractogram_loader = LazyTractogram if reduce_memory_usage else Tractogram
+    tractogram = tractogram_loader(sg)
     tractogram.affine_to_rasmm = affine
-    trk_file = nib.streamlines.TrkFile(tractogram, header=header)
-    nib.streamlines.save(trk_file, fname)
+    track_file = tractogram_file(tractogram, header=header)
+    nib.streamlines.save(track_file, fname)
 
 
-def load_trk(filename):
-    """ Loads tractogram files(*.trk)
+def load_tractogram(filename, lazy_load=False):
+    """ Loads tractogram files (*.trk or *.tck or *.dpy)
 
     Parameters
     ----------
     filename : str
         input trk filename
+    lazy_load : {False, True}, optional
+        If True, load streamlines in a lazy manner, i.e. they will not be
+        kept in memory and will only be loaded when needed.
+        Otherwise, load all streamlines in memory.
 
     Returns
     -------
@@ -50,5 +83,38 @@ def load_trk(filename):
     hdr : dict
         header from a trk file
     """
-    trk_file = nib.streamlines.load(filename)
+    if 'dpy' in os.path.splitext(filename)[1].lower():
+        dpw = Dpy(filename, 'r')
+        streamlines = dpw.read_tracks()
+        dpw.close()
+        return streamlines, {}
+
+    trk_file = nib.streamlines.load(filename, lazy_load)
     return trk_file.streamlines, trk_file.header
+
+
+load_tck = load_tractogram
+load_tck.__doc__ = load_tractogram.__doc__.replace("(*.trk or *.tck or *.dpy)",
+                                                   "(*.tck)")
+
+
+load_trk = load_tractogram
+load_trk.__doc__ = load_tractogram.__doc__.replace("(*.trk or *.tck or *.dpy)",
+                                                   "(*.trk)")
+
+load_dpy = load_tractogram
+load_dpy.__doc__ = load_tractogram.__doc__.replace("(*.trk or *.tck or *.dpy)",
+                                                   "(*.dpy)")
+
+save_tck = partial(save_tractogram, tractogram_file=TckFile)
+save_tck.__doc__ = save_tractogram.__doc__.replace("(*.trk or *.tck or *.dpy)",
+                                                   "(*.tck)")
+
+
+save_trk = partial(save_tractogram, tractogram_file=TrkFile)
+save_trk.__doc__ = save_tractogram.__doc__.replace("(*.trk or *.tck or *.dpy)",
+                                                   "(*.trk)")
+
+save_dpy = partial(save_tractogram, affine=None)
+save_dpy.__doc__ = save_tractogram.__doc__.replace("(*.trk or *.tck or *.dpy)",
+                                                   "(*.dpy)")
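With the aliases above, save_tractogram and load_tractogram act as a single entry point that dispatches on the file extension. A minimal usage sketch (the file name and coordinates are invented for illustration; everything else follows the signatures defined in this diff):

import numpy as np
from dipy.io.streamline import save_tractogram, load_tractogram

# Two tiny made-up streamlines, in RAS+ mm coordinates.
toy = [np.array([[0., 0., 0.], [1., 1., 1.]], dtype=np.float32),
       np.array([[2., 2., 2.], [3., 3., 3.]], dtype=np.float32)]

# The extension picks the writer (.trk here); pass tractogram_file
# explicitly to force a specific format.
save_tractogram("toy.trk", toy, np.eye(4),
                vox_size=np.array([2., 1.5, 1.5]),
                shape=np.array([50, 50, 50]))

# lazy_load=True defers reading streamlines until they are iterated.
streamlines, header = load_tractogram("toy.trk", lazy_load=False)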
diff --git a/dipy/io/tests/test_dpy.py b/dipy/io/tests/test_dpy.py
index eebad3c903..804e01978c 100644
--- a/dipy/io/tests/test_dpy.py
+++ b/dipy/io/tests/test_dpy.py
@@ -1,12 +1,10 @@
-import os
 import numpy as np
 
 from nibabel.tmpdirs import InTemporaryDirectory
 
-from dipy.io.dpy import Dpy
+from dipy.io.dpy import Dpy, Streamlines
 import numpy.testing as npt
-from dipy.tracking.streamline import Streamlines
 
 
 def test_dpy():
@@ -36,5 +34,4 @@ def test_dpy():
 
 
 if __name__ == '__main__':
-
-    npt.run_module_suite()
\ No newline at end of file
+    npt.run_module_suite()
diff --git a/dipy/io/tests/test_io_gradients.py b/dipy/io/tests/test_io_gradients.py
index e271606c9e..089cf4b4df 100644
--- a/dipy/io/tests/test_io_gradients.py
+++ b/dipy/io/tests/test_io_gradients.py
@@ -6,13 +6,13 @@
 import numpy as np
 import numpy.testing as npt
 
-from dipy.data import get_data
+from dipy.data import get_fnames
 from dipy.io.gradients import read_bvals_bvecs
 from dipy.core.gradients import gradient_table
 
 
 def test_read_bvals_bvecs():
-    fimg, fbvals, fbvecs = get_data('small_101D')
+    fimg, fbvals, fbvecs = get_fnames('small_101D')
     bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs)
     gt = gradient_table(bvals, bvecs)
     npt.assert_array_equal(bvals, gt.bvals)
diff --git a/dipy/io/tests/test_streamline.py b/dipy/io/tests/test_streamline.py
index ef88d32fb7..e9f1edbfa9 100644
--- a/dipy/io/tests/test_streamline.py
+++ b/dipy/io/tests/test_streamline.py
@@ -1,12 +1,14 @@
 from __future__ import division, print_function, absolute_import
-import os
 
 import numpy as np
 import numpy.testing as npt
 import nibabel as nib
 from nibabel.tmpdirs import InTemporaryDirectory
-from dipy.io.streamline import save_trk, load_trk
+from dipy.io.streamline import (save_trk, load_trk, save_tractogram,
+                                save_tck, load_tck, load_tractogram,
+                                save_dpy, load_dpy)
 from dipy.io.trackvis import save_trk as trackvis_save_trk
+from dipy.tracking.streamline import Streamlines
 
 streamline = np.array([[82.20181274, 91.36505891, 43.15737152],
                        [82.38442231, 91.79336548, 43.87036514],
@@ -128,9 +130,9 @@
                        [68.25946808, 90.94654083, 130.92756653]],
                       dtype=np.float32)
 
-streamlines = [streamline[[0, 10]], streamline,
-               streamline[::2], streamline[::3],
-               streamline[::5], streamline[::6]]
+streamlines = Streamlines([streamline[[0, 10]], streamline,
+                           streamline[::2], streamline[::3],
+                           streamline[::5], streamline[::6]])
 
 
 def test_io_streamline():
@@ -139,28 +141,107 @@ def test_io_streamline():
         affine = np.eye(4)
 
         # Test save
-        save_trk(fname, streamlines, affine, vox_size=np.array([2, 1.5, 1.5]), shape=np.array([50, 50, 50]))
+        save_tractogram(fname, streamlines, affine,
+                        vox_size=np.array([2, 1.5, 1.5]),
+                        shape=np.array([50, 50, 50]))
         tfile = nib.streamlines.load(fname)
         npt.assert_array_equal(affine, tfile.affine)
-        npt.assert_array_equal(np.array([2, 1.5, 1.5]), tfile.header.get('voxel_sizes'))
-        npt.assert_array_equal(np.array([50, 50, 50]), tfile.header.get('dimensions'))
+        npt.assert_array_equal(np.array([2, 1.5, 1.5]),
+                               tfile.header.get('voxel_sizes'))
+        npt.assert_array_equal(np.array([50, 50, 50]),
+                               tfile.header.get('dimensions'))
         npt.assert_equal(len(tfile.streamlines), len(streamlines))
-        npt.assert_array_almost_equal(tfile.streamlines[1], streamline, decimal=4)
+        npt.assert_array_almost_equal(tfile.streamlines[1], streamline,
+                                      decimal=4)
 
         # Test basic save
-        save_trk(fname, streamlines, affine)
+        save_tractogram(fname, streamlines, affine)
         tfile = nib.streamlines.load(fname)
         npt.assert_array_equal(affine, tfile.affine)
         npt.assert_equal(len(tfile.streamlines), len(streamlines))
-        npt.assert_array_almost_equal(tfile.streamlines[1], streamline, decimal=5)
+        npt.assert_array_almost_equal(tfile.streamlines[1], streamline,
+                                      decimal=5)
 
         # Test Load
-        local_streamlines, hdr = load_trk(fname)
+        local_streamlines, hdr = load_tractogram(fname)
         npt.assert_equal(len(local_streamlines), len(streamlines))
         for arr1, arr2 in zip(local_streamlines, streamlines):
             npt.assert_allclose(arr1, arr2)
 
 
+def io_tractogram(load_fn, save_fn, extension):
+    with InTemporaryDirectory():
+        fname = 'test.{}'.format(extension)
+        affine = np.eye(4)
+
+        # Test save
+        save_fn(fname, streamlines, affine, vox_size=np.array([2, 1.5, 1.5]),
+                shape=np.array([50, 50, 50]))
+        tfile = nib.streamlines.load(fname)
+        npt.assert_array_equal(affine, tfile.affine)
+        vox_size = tfile.header.get('voxel_sizes')
+        dims = tfile.header.get('dimensions')
+        if isinstance(vox_size, str):
+            vox_size = vox_size.replace('[', '').replace(']', '')
+            vox_size = np.fromstring(vox_size, sep=" ", dtype=np.float)
+        if isinstance(dims, str):
+            dims = dims.replace('[', '').replace(']', '')
+            dims = np.fromstring(dims, sep=" ", dtype=np.int)
+        npt.assert_array_equal(np.array([2, 1.5, 1.5]), vox_size)
+        npt.assert_array_equal(np.array([50, 50, 50]), dims)
+        npt.assert_equal(len(tfile.streamlines), len(streamlines))
+        npt.assert_array_almost_equal(tfile.streamlines[1], streamline,
+                                      decimal=4)
+
+        # Test basic save
+        save_fn(fname, streamlines, affine)
+        tfile = nib.streamlines.load(fname)
+        npt.assert_array_equal(affine, tfile.affine)
+        npt.assert_equal(len(tfile.streamlines), len(streamlines))
+        npt.assert_array_almost_equal(tfile.streamlines[1], streamline,
+                                      decimal=5)
+
+        # Test lazy save
+        save_fn(fname, streamlines, affine, vox_size=np.array([2, 1.5, 1.5]),
+                shape=np.array([50, 50, 50]), reduce_memory_usage=True)
+        tfile = nib.streamlines.load(fname)
+        npt.assert_array_equal(affine, tfile.affine)
+        npt.assert_equal(len(tfile.streamlines), len(streamlines))
+        npt.assert_array_almost_equal(tfile.streamlines[1], streamline,
+                                      decimal=4)
+
+        # Test Load
+        local_streamlines, hdr = load_fn(fname)
+        npt.assert_equal(len(local_streamlines), len(streamlines))
+        for arr1, arr2 in zip(local_streamlines, streamlines):
+            npt.assert_allclose(arr1, arr2, rtol=1e4)
+
+        # Test lazy Load
+        local_streamlines, hdr = load_fn(fname, lazy_load=True)
+        for arr1, arr2 in zip(local_streamlines, streamlines):
+            npt.assert_allclose(arr1, arr2, rtol=1e4)
+
+
+def test_io_trk():
+    io_tractogram(load_trk, save_trk, "trk")
+
+
+def test_io_tck():
+    io_tractogram(load_tck, save_tck, "tck")
+
+
+def test_io_dpy():
+    with InTemporaryDirectory():
+        fname = 'test.dpy'
+
+        # Test save
+        save_dpy(fname, streamlines)
+        tracks, _ = load_dpy(fname)
+        npt.assert_equal(len(tracks), len(streamlines))
+        npt.assert_array_almost_equal(tracks[1], streamline,
+                                      decimal=4)
+
+
 def test_trackvis():
     with InTemporaryDirectory():
         fname = 'trackvis_test.trk'
@@ -170,13 +251,17 @@ def test_trackvis():
         trackvis_save_trk(fname, streamlines, affine, np.array([50, 50, 50]))
         tfile = nib.streamlines.load(fname)
         npt.assert_array_equal(affine, tfile.affine)
-        npt.assert_array_equal(np.array([1., 1., 1.]), tfile.header.get('voxel_sizes'))
-        npt.assert_array_equal(np.array([50, 50, 50]), tfile.header.get('dimensions'))
+        npt.assert_array_equal(np.array([1., 1., 1.]),
+                               tfile.header.get('voxel_sizes'))
+        npt.assert_array_equal(np.array([50, 50, 50]),
+                               tfile.header.get('dimensions'))
         npt.assert_equal(len(tfile.streamlines), len(streamlines))
-        npt.assert_array_almost_equal(tfile.streamlines[1], streamline, decimal=4)
+        npt.assert_array_almost_equal(tfile.streamlines[1], streamline,
+                                      decimal=4)
 
         # Test Deprecations
-        npt.assert_warns(DeprecationWarning, trackvis_save_trk, fname, streamlines, affine, np.array([50, 50, 50]))
+        npt.assert_warns(DeprecationWarning, trackvis_save_trk, fname,
+                         streamlines, affine, np.array([50, 50, 50]))
 
 
 if __name__ == '__main__':
diff --git a/dipy/io/tests/test_utils.py b/dipy/io/tests/test_utils.py
new file mode 100644
index 0000000000..da5f4489f9
--- /dev/null
+++ b/dipy/io/tests/test_utils.py
@@ -0,0 +1,18 @@
+from dipy.io.utils import decfa
+from nibabel import Nifti1Image
+import numpy as np
+
+
+def test_decfa():
+    data_orig = np.zeros((4, 4, 4, 3))
+    data_orig[0, 0, 0] = np.array([1, 0, 0])
+    img_orig = Nifti1Image(data_orig, np.eye(4))
+    img_new = decfa(img_orig)
+    data_new = img_new.get_data()
+    assert data_new[0, 0, 0] == np.array((1, 0, 0),
+                                         dtype=np.dtype([('R', 'uint8'),
+                                                         ('G', 'uint8'),
+                                                         ('B', 'uint8')]))
+    assert data_new.dtype == np.dtype([('R', 'uint8'),
+                                       ('G', 'uint8'),
+                                       ('B', 'uint8')])
diff --git a/dipy/io/utils.py b/dipy/io/utils.py
index 26679e806e..b681ae80da 100644
--- a/dipy/io/utils.py
+++ b/dipy/io/utils.py
@@ -44,3 +44,42 @@ def make5d(input):
     shape = input.shape
     shape = shape[:-1] + (1,)*(5-len(shape)) + shape[-1:]
     return input.reshape(shape)
+
+
+def decfa(img_orig):
+    """
+    Create a nifti-compliant directional-encoded color FA file.
+
+    Parameters
+    ----------
+    img_orig : Nifti1Image class instance.
+        Contains encoding of the DEC FA image with a 4D volume of data, where
+        the elements on the last dimension represent R, G and B components.
+
+    Returns
+    -------
+    img : Nifti1Image class instance.
+
+    Notes
+    -----
+    For a description of this format, see:
+
+    https://nifti.nimh.nih.gov/nifti-1/documentation/nifti1fields/nifti1fields_pages/datatype.html
+    """
+
+    dest_dtype = np.dtype([('R', 'uint8'), ('G', 'uint8'), ('B', 'uint8')])
+    out_data = np.zeros(img_orig.shape[:3], dtype=dest_dtype)
+
+    data_orig = img_orig.get_data()
+
+    for ii in np.ndindex(img_orig.shape[:3]):
+        val = data_orig[ii]
+        out_data[ii] = (val[0], val[1], val[2])
+
+    new_hdr = img_orig.get_header()
+    new_hdr['dim'][4] = 1
+    new_hdr.set_intent(1001, name='Color FA')
+    new_hdr.set_data_dtype(dest_dtype)
+
+    return Nifti1Image(out_data, affine=img_orig.affine, header=new_hdr)
diff --git a/dipy/io/vtk.py b/dipy/io/vtk.py
index a44058b852..0d998291c3 100644
--- a/dipy/io/vtk.py
+++ b/dipy/io/vtk.py
@@ -1,22 +1,17 @@
 from __future__ import division, print_function, absolute_import
 
-from dipy.viz.utils import set_input
-
-# Conditional import machinery for vtk
+# Conditional import machinery for fury
 from dipy.utils.optpkg import optional_package
 
-# Allow import, but disable doctests if we don't have vtk
-vtk, have_vtk, setup_module = optional_package('vtk')
-colors, have_vtk_colors, _ = optional_package('vtk.util.colors')
-ns, have_numpy_support, _ = optional_package('vtk.util.numpy_support')
+# Allow import, but disable doctests if we don't have fury
+fury, have_fury, setup_module = optional_package('fury')
 
-if have_vtk:
-    version = vtk.vtkVersion.GetVTKSourceVersion().split(' ')[-1]
-    major_version = vtk.vtkVersion.GetVTKMajorVersion()
+if have_fury:
+    from dipy.viz import utils, vtk
 
 
 def load_polydata(file_name):
-    """ Load a vtk polydata to a supported format file
+    """Load a vtk polydata from a file in a supported format.
 
     Supported file formats are OBJ, VTK, FIB, PLY, STL and XML
 
@@ -27,6 +22,7 @@ def load_polydata(file_name):
     Returns
     -------
     output : vtkPolyData
+
     """
     # get file extension (type) lower case
     file_extension = file_name.split(".")[-1].lower()
@@ -56,7 +52,7 @@ def load_polydata(file_name):
 
 
 def save_polydata(polydata, file_name, binary=False, color_array_name=None):
-    """ Save a vtk polydata to a supported format file
+    """Save a vtk polydata to a file in a supported format.
 
     Save formats can be VTK, FIB, PLY, STL and XML.
@@ -64,6 +60,7 @@ def save_polydata(polydata, file_name, binary=False, color_array_name=None): ---------- polydata : vtkPolyData file_name : string + """ # get file extension (type) file_extension = file_name.split(".")[-1].lower() @@ -80,10 +77,10 @@ def save_polydata(polydata, file_name, binary=False, color_array_name=None): writer = vtk.vtkXMLPolyDataWriter() elif file_extension == "obj": raise Exception("mni obj or Wavefront obj ?") - # writer = set_input(vtk.vtkMNIObjectWriter(), polydata) + # writer = utils.set_input(vtk.vtkMNIObjectWriter(), polydata) writer.SetFileName(file_name) - writer = set_input(writer, polydata) + writer = utils.set_input(writer, polydata) if color_array_name is not None: writer.SetArrayName(color_array_name) diff --git a/dipy/reconst/csdeconv.py b/dipy/reconst/csdeconv.py index 31b145247f..08d68276aa 100644 --- a/dipy/reconst/csdeconv.py +++ b/dipy/reconst/csdeconv.py @@ -702,8 +702,12 @@ def odf_sh_to_sharp(odfs_sh, sphere, basis=None, ratio=3 / 15., sh_order=8, array of odfs expressed as spherical harmonics coefficients sphere : Sphere sphere used to build the regularization matrix - basis : {None, 'mrtrix', 'fibernav'} - different spherical harmonic basis. None is the fibernav basis as well. + basis : {None, 'tournier07', 'descoteaux07'} + different spherical harmonic basis: + ``None`` for the default DIPY basis, + ``tournier07`` for the Tournier 2007 [4]_ basis, and + ``descoteaux07`` for the Descoteaux 2007 [3]_ basis + (``None`` defaults to ``descoteaux07``). ratio : float, ratio of the smallest vs the largest eigenvalue of the single prolate tensor response function (:math:`\frac{\lambda_2}{\lambda_1}`) @@ -737,8 +741,14 @@ def odf_sh_to_sharp(odfs_sh, sphere, basis=None, ratio=3 / 15., sh_order=8, .. [2] Descoteaux, M., et al. IEEE TMI 2009. Deterministic and Probabilistic Tractography Based on Complex Fibre Orientation Distributions - .. [3] Descoteaux, M, et al. MRM 2007. Fast, Regularized and Analytical - Q-Ball Imaging + .. [3] Descoteaux, M., Angelino, E., Fitzgibbons, S. and Deriche, R. + Regularized, Fast, and Robust Analytical Q-ball Imaging. + Magn. Reson. Med. 2007;58:497-510. + .. [4] Tournier J.D., Calamante F. and Connelly A. Robust determination + of the fibre orientation distribution in diffusion MRI: + Non-negativity constrained super-resolved spherical deconvolution. + NeuroImage. 2007;35(4):1459-1472. + """ r, theta, phi = cart2sphere(sphere.x, sphere.y, sphere.z) real_sym_sh = sph_harm_lookup[basis] @@ -811,9 +821,9 @@ def auto_response(gtab, data, roi_center=None, roi_radius=10, fa_thr=0.7, fa_thr : float FA threshold fa_callable : callable - A callable that defines an operation that compares FA with the fa_thr. The operator - should have two positional arguments (e.g., `fa_operator(FA, fa_thr)`) and it should - return a bool array. + A callable that defines an operation that compares FA with the fa_thr. + The operator should have two positional arguments + (e.g., `fa_operator(FA, fa_thr)`) and it should return a bool array. return_number_of_voxels : bool If True, returns the number of voxels used for estimating the response function. diff --git a/dipy/reconst/dki.py b/dipy/reconst/dki.py index cec7ece4f4..57cf10978a 100644 --- a/dipy/reconst/dki.py +++ b/dipy/reconst/dki.py @@ -121,6 +121,7 @@ def carlson_rd(x, y, z, errtol=1e-4): defined as: .. math:: + R_D = \frac{3}{2} \int_{0}^{\infty} (t+x)^{-\frac{1}{2}} (t+y)^{-\frac{1}{2}}(t+z) ^{-\frac{3}{2}} @@ -302,6 +303,7 @@ def _F2m(a, b, c): Function $F_2$ is defined as [1]_: .. 
math:: + F_2(\lambda_1,\lambda_2,\lambda_3)= \frac{(\lambda_1+\lambda_2+\lambda_3)^2} {3(\lambda_2-\lambda_3)^2} @@ -574,6 +576,7 @@ def apparent_kurtosis_coef(dki_params, sphere, min_diffusivity=0, calculation of AKC is done using formula [1]_: .. math :: + AKC(n)=\frac{MD^{2}}{ADC(n)^{2}}\sum_{i=1}^{3}\sum_{j=1}^{3} \sum_{k=1}^{3}\sum_{l=1}^{3}n_{i}n_{j}n_{k}n_{l}W_{ijkl} @@ -662,36 +665,39 @@ def mean_kurtosis(dki_params, min_kurtosis=-3./7, max_kurtosis=3): Notes -------- The MK analytical solution is calculated using the following equation [1]_: + .. math:: - MK=F_1(\lambda_1,\lambda_2,\lambda_3)\hat{W}_{1111}+ - F_1(\lambda_2,\lambda_1,\lambda_3)\hat{W}_{2222}+ - F_1(\lambda_3,\lambda_2,\lambda_1)\hat{W}_{3333}+ \\ - F_2(\lambda_1,\lambda_2,\lambda_3)\hat{W}_{2233}+ - F_2(\lambda_2,\lambda_1,\lambda_3)\hat{W}_{1133}+ - F_2(\lambda_3,\lambda_2,\lambda_1)\hat{W}_{1122} + MK=F_1(\lambda_1,\lambda_2,\lambda_3)\hat{W}_{1111}+ + F_1(\lambda_2,\lambda_1,\lambda_3)\hat{W}_{2222}+ + F_1(\lambda_3,\lambda_2,\lambda_1)\hat{W}_{3333}+ \\ + F_2(\lambda_1,\lambda_2,\lambda_3)\hat{W}_{2233}+ + F_2(\lambda_2,\lambda_1,\lambda_3)\hat{W}_{1133}+ + F_2(\lambda_3,\lambda_2,\lambda_1)\hat{W}_{1122} where $\hat{W}_{ijkl}$ are the components of the $W$ tensor in the coordinates system defined by the eigenvectors of the diffusion tensor $\mathbf{D}$ and - F_1(\lambda_1,\lambda_2,\lambda_3)= - \frac{(\lambda_1+\lambda_2+\lambda_3)^2} - {18(\lambda_1-\lambda_2)(\lambda_1-\lambda_3)} - [\frac{\sqrt{\lambda_2\lambda_3}}{\lambda_1} - R_F(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)+\\ - \frac{3\lambda_1^2-\lambda_1\lambda_2-\lambda_2\lambda_3- - \lambda_1\lambda_3} - {3\lambda_1 \sqrt{\lambda_2 \lambda_3}} - R_D(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)-1 ] - - F_2(\lambda_1,\lambda_2,\lambda_3)= - \frac{(\lambda_1+\lambda_2+\lambda_3)^2} - {3(\lambda_2-\lambda_3)^2} - [\frac{\lambda_2+\lambda_3}{\sqrt{\lambda_2\lambda_3}} - R_F(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)+\\ - \frac{2\lambda_1-\lambda_2-\lambda_3}{3\sqrt{\lambda_2 \lambda_3}} - R_D(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)-2] + .. math:: + + F_1(\lambda_1,\lambda_2,\lambda_3)= + \frac{(\lambda_1+\lambda_2+\lambda_3)^2} + {18(\lambda_1-\lambda_2)(\lambda_1-\lambda_3)} + [\frac{\sqrt{\lambda_2\lambda_3}}{\lambda_1} + R_F(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)+\\ + \frac{3\lambda_1^2-\lambda_1\lambda_2-\lambda_2\lambda_3- + \lambda_1\lambda_3} + {3\lambda_1 \sqrt{\lambda_2 \lambda_3}} + R_D(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)-1 ] + + F_2(\lambda_1,\lambda_2,\lambda_3)= + \frac{(\lambda_1+\lambda_2+\lambda_3)^2} + {3(\lambda_2-\lambda_3)^2} + [\frac{\lambda_2+\lambda_3}{\sqrt{\lambda_2\lambda_3}} + R_F(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)+\\ + \frac{2\lambda_1-\lambda_2-\lambda_3}{3\sqrt{\lambda_2 \lambda_3}} + R_D(\frac{\lambda_1}{\lambda_2},\frac{\lambda_1}{\lambda_3},1)-2] where $R_f$ and $R_d$ are the Carlson's elliptic integrals. @@ -764,6 +770,7 @@ def _G1m(a, b, c): Notes -------- Function $G_1$ is defined as [1]_: + .. math:: G_1(\lambda_1,\lambda_2,\lambda_3)= @@ -829,7 +836,9 @@ def _G2m(a, b, c): Notes -------- Function $G_2$ is defined as [1]_: + .. 
math:: + G_2(\lambda_1,\lambda_2,\lambda_3)= \frac{(\lambda_1+\lambda_2+\lambda_3)^2}{(\lambda_2-\lambda_3)^2} \left ( \frac{\lambda_2+\lambda_3}{\sqrt{\lambda_2\lambda_3}}-2\right ) @@ -899,14 +908,18 @@ def radial_kurtosis(dki_params, min_kurtosis=-3./7, max_kurtosis=10): Notes -------- - RK is calculated with the following equation [1]_:: + RK is calculated with the following equation [1]_: + .. math:: + K_{\bot} = G_1(\lambda_1,\lambda_2,\lambda_3)\hat{W}_{2222} + G_1(\lambda_1,\lambda_3,\lambda_2)\hat{W}_{3333} + G_2(\lambda_1,\lambda_2,\lambda_3)\hat{W}_{2233} where: + .. math:: + G_1(\lambda_1,\lambda_2,\lambda_3)= \frac{(\lambda_1+\lambda_2+\lambda_3)^2}{18\lambda_2(\lambda_2- \lambda_3)} \left (2\lambda_2 + @@ -916,6 +929,7 @@ def radial_kurtosis(dki_params, min_kurtosis=-3./7, max_kurtosis=10): and .. math:: + G_2(\lambda_1,\lambda_2,\lambda_3)= \frac{(\lambda_1+\lambda_2+\lambda_3)^2}{(\lambda_2-\lambda_3)^2} \left ( \frac{\lambda_2+\lambda_3}{\sqrt{\lambda_2\lambda_3}}-2\right ) @@ -1220,6 +1234,7 @@ def dki_prediction(dki_params, gtab, S0=1.): .. math:: S=S_{0}e^{-bD+\frac{1}{6}b^{2}D^{2}K} + """ evals, evecs, kt = split_dki_param(dki_params) @@ -1417,6 +1432,7 @@ def akc(self, sphere): calculation of AKC is done using formula: .. math :: + AKC(n)=\frac{MD^{2}}{ADC(n)^{2}}\sum_{i=1}^{3}\sum_{j=1}^{3} \sum_{k=1}^{3}\sum_{l=1}^{3}n_{i}n_{j}n_{k}n_{l}W_{ijkl} @@ -1424,6 +1440,7 @@ def akc(self, sphere): diffusivity and ADC the apparent diffusion coefficent computed as: .. math :: + ADC(n)=\sum_{i=1}^{3}\sum_{j=1}^{3}n_{i}n_{j}D_{ij} where $D_{ij}$ are the elements of the diffusion tensor. @@ -2014,7 +2031,7 @@ def Wcons(k_elements): k_elements : (15,) elements of the kurtosis tensor in the following order: - .. math:: + .. math:: \begin{matrix} ( & W_{xxxx} & W_{yyyy} & W_{zzzz} & W_{xxxy} & W_{xxxz} & ... \\ diff --git a/dipy/reconst/dsi.py b/dipy/reconst/dsi.py index 928cdadf34..09c285041e 100644 --- a/dipy/reconst/dsi.py +++ b/dipy/reconst/dsi.py @@ -5,6 +5,7 @@ from dipy.reconst.cache import Cache from dipy.reconst.multi_voxel import multi_voxel_fit +from dipy.testing import setup_test class DiffusionSpectrumModel(OdfModel, Cache): @@ -73,6 +74,8 @@ def __init__(self, and a reconstruction sphere, we calculate generalized FA for the first voxel in the data with the reconstruction performed using DSI. + >>> import warnings + >>> warnings.simplefilter("default") >>> from dipy.data import dsi_voxels, get_sphere >>> data, gtab = dsi_voxels() >>> sphere = get_sphere('symmetric724') diff --git a/dipy/reconst/forecast.py b/dipy/reconst/forecast.py index 89e5cf7075..28e59ddc7e 100644 --- a/dipy/reconst/forecast.py +++ b/dipy/reconst/forecast.py @@ -16,10 +16,11 @@ class ForecastModel(OdfModel, Cache): - r"""Fiber ORientation Estimated using Continuous Axially Symmetric Tensors - (FORECAST) [1,2,3]_. FORECAST is a Spherical Deconvolution reconstruction model - for multi-shell diffusion data which enables the calculation of a voxel - adaptive response function using the Spherical Mean Tecnique (SMT) [2,3]_. + r"""Fiber ORientation Estimated using Continuous Axially Symmetric Tensors + (FORECAST) [1,2,3]_. FORECAST is a Spherical Deconvolution reconstruction + model for multi-shell diffusion data which enables the calculation of a + voxel adaptive response function using the Spherical Mean Tecnique (SMT) + [2,3]_. 
With FORECAST it is possible to calculate crossing invariant parallel diffusivity, perpendicular diffusivity, mean diffusivity, and fractional @@ -31,8 +32,8 @@ class ForecastModel(OdfModel, Cache): Using High Angular Resolution Diffusion Imaging", Magnetic Resonance in Medicine, 2005. - .. [2] Kaden E. et al., "Quantitative Mapping of the Per-Axon Diffusion - Coefficients in Brain White Matter", Magnetic Resonance in + .. [2] Kaden E. et al., "Quantitative Mapping of the Per-Axon Diffusion + Coefficients in Brain White Matter", Magnetic Resonance in Medicine, 2016. .. [3] Zucchelli E. et al., "A generalized SMT-based framework for @@ -52,11 +53,11 @@ def __init__(self, lambda_csd=1.0): r""" Analytical and continuous modeling of the diffusion signal with respect to the FORECAST basis [1,2,3]_. - This implementation is a modification of the original FORECAST + This implementation is a modification of the original FORECAST model presented in [1]_ adapted for multi-shell data as in [2,3]_ . The main idea is to model the diffusion signal as the combination of a - single fiber response function $F(\mathbf{b})$ times the fODF + single fiber response function $F(\mathbf{b})$ times the fODF $\rho(\mathbf{v})$ ..math:: @@ -82,7 +83,7 @@ def __init__(self, Laplace-Beltrami regularization weight. dec_alg : str, Spherical deconvolution algorithm. The possible values are Weighted Least Squares ('WLS'), - Positivity Constraints using CVXPY ('POS') and the Constraint + Positivity Constraints using CVXPY ('POS') and the Constraint Spherical Deconvolution algorithm ('CSD'). Default is 'CSD'. sphere : array, shape (N,3), sphere points where to enforce positivity when 'POS' or 'CSD' @@ -96,8 +97,8 @@ def __init__(self, Using High Angular Resolution Diffusion Imaging", Magnetic Resonance in Medicine, 2005. - .. [2] Kaden E. et al., "Quantitative Mapping of the Per-Axon Diffusion - Coefficients in Brain White Matter", Magnetic Resonance in + .. [2] Kaden E. et al., "Quantitative Mapping of the Per-Axon Diffusion + Coefficients in Brain White Matter", Magnetic Resonance in Medicine, 2016. .. [3] Zucchelli M. et al., "A generalized SMT-based framework for @@ -108,7 +109,7 @@ def __init__(self, -------- In this example, where the data, gradient table and sphere tessellation used for reconstruction are provided, we model the diffusion signal - with respect to the FORECAST and compute the fODF, parallel and + with respect to the FORECAST and compute the fODF, parallel and perpendicular diffusivity. 
>>> from dipy.data import get_sphere, get_3shell_gtab @@ -243,9 +244,10 @@ def fit(self, data): coef = np.r_[c0, coef] if self.csd: - coef, num_it = csdeconv(data_single_b0, M, self.fod, tau=0.1, convergence=50) + coef, _ = csdeconv(data_single_b0, M, self.fod, tau=0.1, + convergence=50) coef = coef / coef[0] * c0 - + if self.pos: c = cvxpy.Variable(M.shape[1]) design_matrix = cvxpy.Constant(M) @@ -256,7 +258,7 @@ def fit(self, data): constraints = [c[0] == c0, self.fod * c >= 0] prob = cvxpy.Problem(objective, constraints) try: - prob.solve() + prob.solve(solver=cvxpy.OSQP, eps_abs=1e-05, eps_rel=1e-05) coef = np.asarray(c.value).squeeze() except Exception: warn('Optimization did not find a solution') @@ -304,7 +306,6 @@ def odf(self, sphere, clip_negative=True): clip_negative : boolean, optional if True clip the negative odf values to 0, default True """ - if self.rho is None: self.rho = rho_matrix(self.sh_order, sphere.vertices) @@ -337,11 +338,11 @@ def predict(self, gtab=None, S0=1.0): gradient directions and bvalues container class. S0 : float, optional the signal at b-value=0 - + """ if gtab is None: gtab = self.gtab - + M_diff = forecast_matrix(self.sh_order, self.d_par, self.d_perp, @@ -373,7 +374,7 @@ def dperp(self): def find_signal_means(b_unique, data_norm, bvals, rho, lb_matrix, w=1e-03): - r"""Calculates the mean signal for each shell + r"""Calculate the mean signal for each shell. Parameters ---------- @@ -388,7 +389,7 @@ def find_signal_means(b_unique, data_norm, bvals, rho, lb_matrix, w=1e-03): lb_matrix : 2d ndarray, Laplace-Beltrami regularization matrix w : float, - weight for the Laplace-Beltrami regularization + weight for the Laplace-Beltrami regularization Returns ------- @@ -433,7 +434,7 @@ def forecast_error_func(x, b_unique, E): return v -def psi_l(l,b): +def psi_l(l, b): n = l//2 v = (-b)**n v *= gamma(n + 1./2) / gamma(2*n + 3./2) @@ -479,8 +480,8 @@ def lb_forecast(sh_order): diag_lb = np.zeros(n_c) counter = 0 for l in range(0, sh_order + 1, 2): - for m in range(-l, l + 1): - diag_lb[counter] = (l * (l + 1)) ** 2 - counter += 1 + stop = 2 * l + 1 + counter + diag_lb[counter:stop] = (l * (l + 1)) ** 2 + counter = stop return np.diag(diag_lb) diff --git a/dipy/reconst/ivim.py b/dipy/reconst/ivim.py index 91b170de68..0fb6259659 100644 --- a/dipy/reconst/ivim.py +++ b/dipy/reconst/ivim.py @@ -135,7 +135,7 @@ def __init__(self, gtab, split_b_D=400.0, split_b_S0=200., bounds=None, x_scale=[1000., 0.1, 0.001, 0.0001], options={'gtol': 1e-15, 'ftol': 1e-15, 'eps': 1e-15, 'maxiter': 1000}): - """ + r""" Initialize an IVIM model. The IVIM model assumes that biological tissue includes a volume @@ -215,6 +215,16 @@ def __init__(self, gtab, split_b_D=400.0, split_b_S0=200., bounds=None, e_s += "The IVIM model requires signal measured at 0 bvalue" raise ValueError(e_s) + if gtab.b0_threshold > 0: + b0_s = "The IVIM model requires a measurement at b==0. As of " + b0_s += "version 0.15, the default b0_threshold for the " + b0_s += "GradientTable object is set to 50, so if you used the " + b0_s += "default settings to initialize the gtab input to the " + b0_s += "IVIM model, you may have provided a gtab with " + b0_s += "b0_threshold larger than 0. Please initialize the gtab " + b0_s += "input with b0_threshold=0" + raise ValueError(b0_s) + ReconstModel.__init__(self, gtab) self.split_b_D = split_b_D self.split_b_S0 = split_b_S0 @@ -381,7 +391,7 @@ def estimate_f_D_star(self, params_f_D_star, data, S0, D): warningMsg += " as initial guess for leastsq. 
Parameters are" warningMsg += " returned only from the linear fit." warnings.warn(warningMsg, UserWarning) - f, D_star = params_f_D + f, D_star = params_f_D_star return f, D_star else: try: diff --git a/dipy/reconst/mapmri.py b/dipy/reconst/mapmri.py index b465cdc96e..33c765cb91 100644 --- a/dipy/reconst/mapmri.py +++ b/dipy/reconst/mapmri.py @@ -401,8 +401,8 @@ def fit(self, data): lopt * cvxpy.quad_form(c, laplacian_matrix) ) M0 = M[self.gtab.b0s_mask, :] - constraints = [M0[0] * c == 1, - K * c > -.1] + constraints = [(M0[0] * c) == 1, + (K * c) >= -0.1] prob = cvxpy.Problem(objective, constraints) try: prob.solve(solver=self.cvxpy_solver) @@ -1682,6 +1682,33 @@ def mapmri_isotropic_laplacian_reg_matrix(radial_order, mu): NeuroImage (2016). ''' ind_mat = mapmri_isotropic_index_matrix(radial_order) + return mapmri_isotropic_laplacian_reg_matrix_from_index_matrix( + ind_mat, mu + ) + + +def mapmri_isotropic_laplacian_reg_matrix_from_index_matrix(ind_mat, mu): + r''' Computes the Laplacian regularization matrix for MAP-MRI's isotropic + implementation [1]_ eq. (C7). + + Parameters + ---------- + ind_mat : matrix (N_coef, 3), + Basis order matrix + mu : float, + isotropic scale factor of the isotropic MAP-MRI basis + + Returns + ------- + LR : Matrix, shape (N_coef, N_coef) + Laplacian regularization matrix + + References + ---------- + .. [1]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP data." + NeuroImage (2016). + ''' n_elem = ind_mat.shape[0] LR = np.zeros((n_elem, n_elem)) @@ -1984,7 +2011,7 @@ def generalized_crossvalidation_array(data, M, LR, weights_array=None): gcvold = gcvnew i = i + 1 S = np.dot(np.dot(M, np.linalg.pinv(MMt + lrange[i] * LR)), M.T) - trS = np.matrix.trace(S) + trS = np.trace(S) normyytilde = np.linalg.norm(data - np.dot(S, data), 2) gcvnew = normyytilde / (K - trS) lopt = lrange[i - 1] @@ -2033,7 +2060,7 @@ def gcv_cost_function(weight, args): """ data, M, MMt, K, LR = args S = np.dot(np.dot(M, np.linalg.pinv(MMt + weight * LR)), M.T) - trS = np.matrix.trace(S) + trS = np.trace(S) normyytilde = np.linalg.norm(data - np.dot(S, data), 2) gcv_value = normyytilde / (K - trS) return gcv_value diff --git a/dipy/reconst/qtdmri.py b/dipy/reconst/qtdmri.py new file mode 100644 index 0000000000..05782c77db --- /dev/null +++ b/dipy/reconst/qtdmri.py @@ -0,0 +1,2118 @@ +# -*- coding: utf-8 -*- +import numpy as np +from dipy.reconst.cache import Cache +from dipy.core.geometry import cart2sphere +from dipy.reconst.multi_voxel import multi_voxel_fit +from scipy.special import genlaguerre, gamma +from dipy.core.gradients import gradient_table_from_gradient_strength_bvecs +from scipy import special +from warnings import warn +from dipy.reconst import mapmri +try: # preferred scipy >= 0.14, required scipy >= 1.0 + from scipy.special import factorial, factorial2 +except ImportError: + from scipy.misc import factorial, factorial2 +from scipy.optimize import fmin_l_bfgs_b +from dipy.reconst.shm import real_sph_harm +import dipy.reconst.dti as dti +from dipy.utils.optpkg import optional_package +import random + +cvxpy, have_cvxpy, _ = optional_package("cvxpy") +plt, have_plt, _ = optional_package("matplotlib.pyplot") + + +class QtdmriModel(Cache): + r"""The q$\tau$-dMRI model [1] to analytically and continuously represent + the q$\tau$ diffusion signal attenuation over diffusion sensitization + q and diffusion time $\tau$. 
The model can be seen as an extension of
+    the MAP-MRI basis [2] towards different diffusion times.
+
+    The main idea is to model the diffusion signal over time and space as
+    a linear combination of continuous functions,
+
+    .. math::
+        :nowrap:
+
+        \begin{equation}
+            \hat{E}(\textbf{q},\tau;\textbf{c}) =
+            \sum_i^{N_{\textbf{q}}}\sum_k^{N_\tau} \textbf{c}_{ik}
+            \,\Phi_i(\textbf{q})\,T_k(\tau),
+        \end{equation}
+
+    where $\Phi$ and $T$ are the spatial and temporal basis functions,
+    $N_{\textbf{q}}$ and $N_\tau$ are the maximum spatial and temporal
+    order, and $i,k$ are basis order iterators.
+
+    The estimation of the coefficients $c_i$ can be regularized using
+    either analytic Laplacian regularization, sparsity regularization using
+    the l1-norm, or both to do a type of elastic net regularization.
+
+    From the coefficients, there exists an analytical formula to estimate
+    the ODF, RTOP, RTAP, RTPP, QIV and MSD, for any diffusion time.
+
+    Parameters
+    ----------
+    gtab : GradientTable,
+        gradient directions and bvalues container class. The bvalues
+        should be in the normal s/mm^2 units. big_delta and small_delta
+        need to be given in seconds.
+    radial_order : unsigned int,
+        an even integer representing the spatial/radial order of the basis.
+    time_order : unsigned int,
+        an integer greater than or equal to zero representing the time
+        order of the basis.
+    laplacian_regularization : bool,
+        Regularize using the Laplacian of the qt-dMRI basis.
+    laplacian_weighting: string or scalar,
+        The string 'GCV' makes it use generalized cross-validation to find
+        the regularization weight [3]. A scalar sets the regularization
+        weight to that value.
+    l1_regularization : bool,
+        Regularize by imposing sparsity in the coefficients using the
+        l1-norm.
+    l1_weighting : 'CV' or scalar,
+        The string 'CV' makes it use five-fold cross-validation to find
+        the regularization weight. A scalar sets the regularization weight
+        to that value.
+    cartesian : bool
+        Whether to use the Cartesian or spherical implementation of the
+        qt-dMRI basis, which we first explored in [4].
+    anisotropic_scaling : bool
+        Whether to use anisotropic scaling or isotropic scaling. This
+        option can be used to test if the Cartesian implementation is
+        equivalent with the spherical one when using the same scaling.
+    normalization : bool
+        Whether to normalize the basis functions such that their inner
+        product is equal to one. Normalization is only necessary when
+        imposing sparsity in the spherical basis if cartesian=False.
+    constrain_q0 : bool
+        whether to constrain the q0 point to unity along the tau-space.
+        This is necessary to ensure that $E(0,\tau)=1$.
+    bval_threshold : float
+        the threshold b-value to be used, such that only data points below
+        that threshold are used when estimating the scale factors.
+    eigenvalue_threshold : float,
+        Sets the minimum of the tensor eigenvalues in order to avoid
+        stability problems.
+    cvxpy_solver : str, optional
+        cvxpy solver name. Optionally optimize the positivity constraint
+        with a particular cvxpy solver. See http://www.cvxpy.org/ for
+        details. Default: ECOS.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+       Representation of dMRI in Space and Time", Medical Image Analysis,
+       2017.
+
+    .. [2] Ozarslan E. et al., "Mean apparent propagator (MAP) MRI: A novel
+       diffusion imaging method for mapping tissue microstructure",
+       NeuroImage, 2013.
+
+    .. [3] Craven et al. "Smoothing Noisy Data with Spline Functions."
+       NUMER MATH 31.4 (1978): 377-403.
+
+    .. [4] Fick, Rutger HJ, et al. "A unifying framework for spatial and
+       temporal diffusion in diffusion mri." International Conference on
+       Information Processing in Medical Imaging. Springer, Cham, 2015.
+    """
+
+    def __init__(self,
+                 gtab,
+                 radial_order=6,
+                 time_order=2,
+                 laplacian_regularization=False,
+                 laplacian_weighting=0.2,
+                 l1_regularization=False,
+                 l1_weighting=0.1,
+                 cartesian=True,
+                 anisotropic_scaling=True,
+                 normalization=False,
+                 constrain_q0=True,
+                 bval_threshold=1e10,
+                 eigenvalue_threshold=1e-04,
+                 cvxpy_solver="ECOS"
+                 ):
+
+        if radial_order % 2 or radial_order < 0:
+            msg = "radial_order must be zero or an even positive integer."
+            msg += " radial_order %s was given." % radial_order
+            raise ValueError(msg)
+
+        if time_order < 0:
+            msg = "time_order must be an integer greater than or equal to"
+            msg += " zero. time_order %s was given." % time_order
+            raise ValueError(msg)
+
+        if not isinstance(laplacian_regularization, bool):
+            msg = "laplacian_regularization must be True or False."
+            msg += " Input value was %s." % laplacian_regularization
+            raise ValueError(msg)
+
+        if laplacian_regularization:
+            msg = "laplacian_weighting must be 'GCV' "
+            msg += "or a float greater than or equal to zero."
+            msg += " Input value was %s." % laplacian_weighting
+            if isinstance(laplacian_weighting, str):
+                if laplacian_weighting != 'GCV':
+                    raise ValueError(msg)
+            elif isinstance(laplacian_weighting, float):
+                if laplacian_weighting < 0:
+                    raise ValueError(msg)
+            else:
+                raise ValueError(msg)
+
+        if not isinstance(l1_regularization, bool):
+            msg = "l1_regularization must be True or False."
+            msg += " Input value was %s." % l1_regularization
+            raise ValueError(msg)
+
+        if l1_regularization:
+            msg = "l1_weighting must be 'CV' "
+            msg += "or a float greater than or equal to zero."
+            msg += " Input value was %s." % l1_weighting
+            if isinstance(l1_weighting, str):
+                if l1_weighting != 'CV':
+                    raise ValueError(msg)
+            elif isinstance(l1_weighting, float):
+                if l1_weighting < 0:
+                    raise ValueError(msg)
+            else:
+                raise ValueError(msg)
+
+        if not isinstance(cartesian, bool):
+            msg = "cartesian must be True or False."
+            msg += " Input value was %s." % cartesian
+            raise ValueError(msg)
+
+        if not isinstance(anisotropic_scaling, bool):
+            msg = "anisotropic_scaling must be True or False."
+            msg += " Input value was %s." % anisotropic_scaling
+            raise ValueError(msg)
+
+        if not isinstance(constrain_q0, bool):
+            msg = "constrain_q0 must be True or False."
+            msg += " Input value was %s." % constrain_q0
+            raise ValueError(msg)
+
+        if (not isinstance(bval_threshold, float) or
+                bval_threshold < 0):
+            msg = "bval_threshold must be a positive float."
+            msg += " Input value was %s." % bval_threshold
+            raise ValueError(msg)
+
+        if (not isinstance(eigenvalue_threshold, float) or
+                eigenvalue_threshold < 0):
+            msg = "eigenvalue_threshold must be a positive float."
+            msg += " Input value was %s." % eigenvalue_threshold
+            raise ValueError(msg)
+
+        if laplacian_regularization or l1_regularization:
+            if not have_cvxpy:
+                msg = "cvxpy must be installed for Laplacian or l1 "
+                msg += "regularization."
+                raise ValueError(msg)
+            if cvxpy_solver is not None:
+                if cvxpy_solver not in cvxpy.installed_solvers():
+                    msg = "Input `cvxpy_solver` was set to %s." % cvxpy_solver
+                    msg += " One of %s" % ', '.join(cvxpy.installed_solvers())
+                    msg += " was expected."
+ raise ValueError(msg) + + if l1_regularization and not cartesian and not normalization: + msg = "The non-Cartesian implementation must be normalized for the" + msg += " l1-norm sparsity regularization to work. Set " + msg += "normalization=True to proceed." + raise ValueError(msg) + + self.gtab = gtab + self.radial_order = radial_order + self.time_order = time_order + self.laplacian_regularization = laplacian_regularization + self.laplacian_weighting = laplacian_weighting + self.l1_regularization = l1_regularization + self.l1_weighting = l1_weighting + self.cartesian = cartesian + self.anisotropic_scaling = anisotropic_scaling + self.normalization = normalization + self.constrain_q0 = constrain_q0 + self.bval_threshold = bval_threshold + self.eigenvalue_threshold = eigenvalue_threshold + self.cvxpy_solver = cvxpy_solver + + if self.cartesian: + self.ind_mat = qtdmri_index_matrix(radial_order, time_order) + else: + self.ind_mat = qtdmri_isotropic_index_matrix(radial_order, + time_order) + + # precompute parts of laplacian regularization matrices + self.part4_reg_mat_tau = part4_reg_matrix_tau(self.ind_mat, 1.) + self.part23_reg_mat_tau = part23_reg_matrix_tau(self.ind_mat, 1.) + self.part1_reg_mat_tau = part1_reg_matrix_tau(self.ind_mat, 1.) + if self.cartesian: + self.S_mat, self.T_mat, self.U_mat = ( + mapmri.mapmri_STU_reg_matrices(radial_order) + ) + else: + self.part1_uq_iso_precomp = ( + mapmri.mapmri_isotropic_laplacian_reg_matrix_from_index_matrix( + self.ind_mat[:, :3], 1. + ) + ) + + self.tenmodel = dti.TensorModel(gtab) + + @multi_voxel_fit + def fit(self, data): + bval_mask = self.gtab.bvals < self.bval_threshold + data_norm = data / data[self.gtab.b0s_mask].mean() + tau = self.gtab.tau + bvecs = self.gtab.bvecs + qvals = self.gtab.qvals + b0s_mask = self.gtab.b0s_mask + + if self.cartesian: + if self.anisotropic_scaling: + us, ut, R = qtdmri_anisotropic_scaling(data_norm[bval_mask], + qvals[bval_mask], + bvecs[bval_mask], + tau[bval_mask]) + tau_scaling = ut / us.mean() + tau_scaled = tau * tau_scaling + ut /= tau_scaling + us = np.clip(us, self.eigenvalue_threshold, np.inf) + q = np.dot(bvecs, R) * qvals[:, None] + M = qtdmri_signal_matrix_( + self.radial_order, self.time_order, us, ut, q, tau_scaled, + self.normalization + ) + else: + us, ut = qtdmri_isotropic_scaling(data_norm, qvals, tau) + tau_scaling = ut / us + tau_scaled = tau * tau_scaling + ut /= tau_scaling + R = np.eye(3) + us = np.tile(us, 3) + q = bvecs * qvals[:, None] + M = qtdmri_signal_matrix_( + self.radial_order, self.time_order, us, ut, q, tau_scaled, + self.normalization + ) + else: + us, ut = qtdmri_isotropic_scaling(data_norm, qvals, tau) + tau_scaling = ut / us + tau_scaled = tau * tau_scaling + ut /= tau_scaling + R = np.eye(3) + us = np.tile(us, 3) + q = bvecs * qvals[:, None] + M = qtdmri_isotropic_signal_matrix_( + self.radial_order, self.time_order, us[0], ut, q, tau_scaled, + normalization=self.normalization + ) + + b0_indices = np.arange(self.gtab.tau.shape[0])[self.gtab.b0s_mask] + tau0_ordered = self.gtab.tau[b0_indices] + unique_taus = np.unique(self.gtab.tau) + first_tau_pos = [] + for unique_tau in unique_taus: + first_tau_pos.append(np.where(tau0_ordered == unique_tau)[0][0]) + M0 = M[b0_indices[first_tau_pos]] + + lopt = 0. + alpha = 0. 
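+        # A minimal map of the four estimation paths below: Laplacian-only
+        # and l1-only regularization each solve a single-penalty problem
+        # with cvxpy, combining both gives an elastic-net-like problem, and
+        # with no regularization the coefficients come from a pseudo-inverse
+        # least-squares fit. The weights lopt/alpha come either from
+        # (generalized) cross-validation or from user-supplied scalars.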
+        if self.laplacian_regularization and not self.l1_regularization:
+            if self.cartesian:
+                laplacian_matrix = qtdmri_laplacian_reg_matrix(
+                    self.ind_mat, us, ut, self.S_mat, self.T_mat, self.U_mat,
+                    self.part1_reg_mat_tau,
+                    self.part23_reg_mat_tau,
+                    self.part4_reg_mat_tau,
+                    normalization=self.normalization
+                )
+            else:
+                laplacian_matrix = qtdmri_isotropic_laplacian_reg_matrix(
+                    self.ind_mat, us, ut, self.part1_uq_iso_precomp,
+                    self.part1_reg_mat_tau, self.part23_reg_mat_tau,
+                    self.part4_reg_mat_tau,
+                    normalization=self.normalization
+                )
+            if self.laplacian_weighting == 'GCV':
+                try:
+                    lopt = generalized_crossvalidation(data_norm, M,
+                                                       laplacian_matrix)
+                except BaseException:
+                    msg = "Laplacian GCV failed. lopt defaulted to 2e-4."
+                    warn(msg)
+                    lopt = 2e-4
+            elif np.isscalar(self.laplacian_weighting):
+                lopt = self.laplacian_weighting
+            c = cvxpy.Variable(M.shape[1])
+            design_matrix = cvxpy.Constant(M)
+            objective = cvxpy.Minimize(
+                cvxpy.sum_squares(design_matrix * c - data_norm) +
+                lopt * cvxpy.quad_form(c, laplacian_matrix)
+            )
+            if self.constrain_q0:
+                # only constrain the first and last, otherwise the solver fails
+                constraints = [M0[0] * c == 1,
+                               M0[-1] * c == 1]
+            else:
+                constraints = []
+            prob = cvxpy.Problem(objective, constraints)
+            try:
+                prob.solve(solver=self.cvxpy_solver, verbose=False)
+                cvxpy_solution_optimal = prob.status == 'optimal'
+                qtdmri_coef = np.asarray(c.value).squeeze()
+            except BaseException:
+                qtdmri_coef = np.zeros(M.shape[1])
+                cvxpy_solution_optimal = False
+        elif self.l1_regularization and not self.laplacian_regularization:
+            if self.l1_weighting == 'CV':
+                alpha = l1_crossvalidation(b0s_mask, data_norm, M)
+            elif np.isscalar(self.l1_weighting):
+                alpha = self.l1_weighting
+            c = cvxpy.Variable(M.shape[1])
+            design_matrix = cvxpy.Constant(M)
+            objective = cvxpy.Minimize(
+                cvxpy.sum_squares(design_matrix * c - data_norm) +
+                alpha * cvxpy.norm1(c)
+            )
+            if self.constrain_q0:
+                # only constrain the first and last, otherwise the solver fails
+                constraints = [M0[0] * c == 1,
+                               M0[-1] * c == 1]
+            else:
+                constraints = []
+            prob = cvxpy.Problem(objective, constraints)
+            try:
+                prob.solve(solver=self.cvxpy_solver, verbose=False)
+                cvxpy_solution_optimal = prob.status == 'optimal'
+                qtdmri_coef = np.asarray(c.value).squeeze()
+            except BaseException:
+                qtdmri_coef = np.zeros(M.shape[1])
+                cvxpy_solution_optimal = False
+        elif self.l1_regularization and self.laplacian_regularization:
+            if self.cartesian:
+                laplacian_matrix = qtdmri_laplacian_reg_matrix(
+                    self.ind_mat, us, ut, self.S_mat, self.T_mat, self.U_mat,
+                    self.part1_reg_mat_tau,
+                    self.part23_reg_mat_tau,
+                    self.part4_reg_mat_tau,
+                    normalization=self.normalization
+                )
+            else:
+                laplacian_matrix = qtdmri_isotropic_laplacian_reg_matrix(
+                    self.ind_mat, us, ut, self.part1_uq_iso_precomp,
+                    self.part1_reg_mat_tau, self.part23_reg_mat_tau,
+                    self.part4_reg_mat_tau,
+                    normalization=self.normalization
+                )
+            if self.laplacian_weighting == 'GCV':
+                lopt = generalized_crossvalidation(data_norm, M,
+                                                   laplacian_matrix)
+            elif np.isscalar(self.laplacian_weighting):
+                lopt = self.laplacian_weighting
+            if self.l1_weighting == 'CV':
+                alpha = elastic_crossvalidation(b0s_mask, data_norm, M,
+                                                laplacian_matrix, lopt)
+            elif np.isscalar(self.l1_weighting):
+                alpha = self.l1_weighting
+            c = cvxpy.Variable(M.shape[1])
+            design_matrix = cvxpy.Constant(M)
+            objective = cvxpy.Minimize(
+                cvxpy.sum_squares(design_matrix * c - data_norm) +
+                alpha * cvxpy.norm1(c) +
+                lopt * cvxpy.quad_form(c, laplacian_matrix)
+            )
+            if self.constrain_q0:
+                # only constrain the first and last, otherwise the solver fails
+                constraints = [M0[0] * c == 1,
+                               M0[-1] * c == 1]
+            else:
+                constraints = []
+            prob = cvxpy.Problem(objective, constraints)
+            try:
+                prob.solve(solver=self.cvxpy_solver, verbose=False)
+                cvxpy_solution_optimal = prob.status == 'optimal'
+                qtdmri_coef = np.asarray(c.value).squeeze()
+            except BaseException:
+                qtdmri_coef = np.zeros(M.shape[1])
+                cvxpy_solution_optimal = False
+        elif not self.l1_regularization and not self.laplacian_regularization:
+            # just use least squares with the observation matrix
+            pseudoInv = np.linalg.pinv(M)
+            qtdmri_coef = np.dot(pseudoInv, data_norm)
+            # if cvxpy is used to constrain q0 without regularization the
+            # solver often fails, so only the first tau-position is manually
+            # normalized.
+            qtdmri_coef /= np.dot(M0[0], qtdmri_coef)
+            cvxpy_solution_optimal = None
+
+        if cvxpy_solution_optimal is False:
+            msg = "cvxpy optimization resulted in non-optimal solution. Check "
+            msg += "cvxpy_solution_optimal attribute in fitted object to see "
+            msg += "which voxels are affected."
+            warn(msg)
+        return QtdmriFit(
+            self, qtdmri_coef, us, ut, tau_scaling, R, lopt, alpha,
+            cvxpy_solution_optimal)
+
+
+class QtdmriFit():
+
+    def __init__(self, model, qtdmri_coef, us, ut, tau_scaling, R, lopt,
+                 alpha, cvxpy_solution_optimal):
+        """ Calculates diffusion properties for a single voxel
+
+        Parameters
+        ----------
+        model : object,
+            AnalyticalModel
+        qtdmri_coef : 1d ndarray,
+            qtdmri coefficients
+        us : array, 3 x 1
+            spatial scaling factors
+        ut : float
+            temporal scaling factor
+        tau_scaling : float,
+            the temporal scaling used to scale tau to the size of us
+        R : 3x3 numpy array,
+            tensor eigenvectors
+        lopt : float,
+            laplacian regularization weight
+        alpha : float,
+            the l1 regularization weight
+        cvxpy_solution_optimal: bool,
+            indicates whether the cvxpy coefficient estimation reached an
+            optimal solution
+        """
+
+        self.model = model
+        self._qtdmri_coef = qtdmri_coef
+        self.us = us
+        self.ut = ut
+        self.tau_scaling = tau_scaling
+        self.R = R
+        self.lopt = lopt
+        self.alpha = alpha
+        self.cvxpy_solution_optimal = cvxpy_solution_optimal
+
+    def qtdmri_to_mapmri_coef(self, tau):
+        """This function converts the qtdmri coefficients to mapmri
+        coefficients for a given tau [1]_. The conversion is performed by a
+        matrix multiplication that evaluates the time-dependent part of the
+        basis and multiplies it with the coefficients, after which
+        coefficients with the same spatial orders are summed up, resulting
+        in mapmri coefficients.
+
+        Parameters
+        ----------
+        tau : float
+            diffusion time (big_delta - small_delta / 3.) in seconds
+
+        References
+        ----------
+        .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+           Representation of dMRI in Space and Time", Medical Image Analysis,
+           2017.
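+
+        Examples
+        --------
+        A hypothetical sketch, assuming a fitted ``QtdmriFit`` instance
+        named ``qtfit`` (e.g. from ``QtdmriModel(gtab).fit(data)``):
+
+        >>> mapmri_coef = qtfit.qtdmri_to_mapmri_coef(0.02)  # doctest: +SKIP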
+ """ + if self.model.cartesian: + II = self.model.cache_get('qtdmri_to_mapmri_matrix', + key=(tau)) + if II is None: + II = qtdmri_to_mapmri_matrix(self.model.radial_order, + self.model.time_order, self.ut, + self.tau_scaling * tau) + self.model.cache_set('qtdmri_to_mapmri_matrix', + (tau), II) + else: + II = self.model.cache_get('qtdmri_isotropic_to_mapmri_matrix', + key=(tau)) + if II is None: + II = qtdmri_isotropic_to_mapmri_matrix(self.model.radial_order, + self.model.time_order, + self.ut, + self.tau_scaling * tau) + self.model.cache_set('qtdmri_isotropic_to_mapmri_matrix', + (tau), II) + mapmri_coef = np.dot(II, self._qtdmri_coef) + return mapmri_coef + + def sparsity_abs(self, threshold=0.99): + """As a measure of sparsity, calculates the number of largest + coefficients needed to absolute sum up to 99% of the total absolute sum + of all coefficients""" + if not 0. < threshold < 1.: + msg = "sparsity threshold must be between zero and one" + raise ValueError(msg) + total_weight = np.sum(abs(self._qtdmri_coef)) + absolute_normalized_coef_array = ( + np.sort(abs(self._qtdmri_coef))[::-1] / total_weight) + current_weight = 0. + counter = 0 + while current_weight < threshold: + current_weight += absolute_normalized_coef_array[counter] + counter += 1 + return counter + + def sparsity_density(self, threshold=0.99): + """As a measure of sparsity, calculates the number of largest + coefficients needed to squared sum up to 99% of the total squared sum + of all coefficients""" + if not 0. < threshold < 1.: + msg = "sparsity threshold must be between zero and one" + raise ValueError(msg) + total_weight = np.sum(self._qtdmri_coef ** 2) + squared_normalized_coef_array = ( + np.sort(self._qtdmri_coef ** 2)[::-1] / total_weight) + current_weight = 0. + counter = 0 + while current_weight < threshold: + current_weight += squared_normalized_coef_array[counter] + counter += 1 + return counter + + def odf(self, sphere, tau, s=2): + r""" Calculates the analytical Orientation Distribution Function (ODF) + for a given diffusion time tau from the signal, [1]_ Eq. (32). The + qtdmri coefficients are first converted to mapmri coefficients + following [2]. + + Parameters + ---------- + sphere : dipy sphere object + sphere object with vertice orientations to compute the ODF on. + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + s : unsigned int + radial moment of the ODF + + References + ---------- + .. [1] Ozarslan E. et. al, "Mean apparent propagator (MAP) MRI: A novel + diffusion imaging method for mapping tissue microstructure", + NeuroImage, 2013. + .. [2] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_coef = self.qtdmri_to_mapmri_coef(tau) + if self.model.cartesian: + v_ = sphere.vertices + v = np.dot(v_, self.R) + I_s = mapmri.mapmri_odf_matrix(self.model.radial_order, self.us, + s, v) + odf = np.dot(I_s, mapmri_coef) + else: + II = self.model.cache_get('ODF_matrix', key=(sphere, s)) + if II is None: + II = mapmri.mapmri_isotropic_odf_matrix( + self.model.radial_order, 1, s, sphere.vertices) + self.model.cache_set('ODF_matrix', (sphere, s), II) + + odf = self.us[0] ** s * np.dot(II, mapmri_coef) + return odf + + def odf_sh(self, tau, s=2): + r""" Calculates the real analytical odf for a given discrete sphere. + Computes the design matrix of the ODF for the given sphere vertices + and radial moment [1]_ eq. (32). The radial moment s acts as a + sharpening method. 
The analytical equation for the spherical ODF basis + is given in [2]_ eq. (C8). The qtdmri coefficients are first converted + to mapmri coefficients following [3]. + + Parameters + ---------- + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + s : unsigned int + radial moment of the ODF + + References + ---------- + .. [1] Ozarslan E. et. al, "Mean apparent propagator (MAP) MRI: A novel + diffusion imaging method for mapping tissue microstructure", + NeuroImage, 2013. + .. [2]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP + data." NeuroImage (2016). + .. [3] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_coef = self.qtdmri_to_mapmri_coef(tau) + if self.model.cartesian: + msg = 'odf in spherical harmonics not yet implemented for ' + msg += 'cartesian implementation' + raise ValueError(msg) + II = self.model.cache_get('ODF_sh_matrix', + key=(self.model.radial_order, s)) + + if II is None: + II = mapmri.mapmri_isotropic_odf_sh_matrix(self.model.radial_order, + 1, s) + self.model.cache_set('ODF_sh_matrix', (self.model.radial_order, s), + II) + + odf = self.us[0] ** s * np.dot(II, mapmri_coef) + return odf + + def rtpp(self, tau): + r""" Calculates the analytical return to the plane probability (RTPP) + for a given diffusion time tau, [1]_ eq. (42). The analytical formula + for the isotropic MAP-MRI basis was derived in [2]_ eq. (C11). The + qtdmri coefficients are first converted to mapmri coefficients + following [3]. + + Parameters + ---------- + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + + References + ---------- + .. [1] Ozarslan E. et. al, "Mean apparent propagator (MAP) MRI: A novel + diffusion imaging method for mapping tissue microstructure", + NeuroImage, 2013. + .. [2]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP + data." NeuroImage (2016). + .. [3] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_coef = self.qtdmri_to_mapmri_coef(tau) + + if self.model.cartesian: + ind_mat = mapmri.mapmri_index_matrix(self.model.radial_order) + Bm = mapmri.b_mat(ind_mat) + sel = Bm > 0. 
# select only relevant coefficients + const = 1 / (np.sqrt(2 * np.pi) * self.us[0]) + ind_sum = (-1.0) ** (ind_mat[sel, 0] / 2.0) + rtpp_vec = const * Bm[sel] * ind_sum * mapmri_coef[sel] + rtpp = rtpp_vec.sum() + return rtpp + else: + ind_mat = mapmri.mapmri_isotropic_index_matrix( + self.model.radial_order + ) + rtpp_vec = np.zeros(int(ind_mat.shape[0])) + count = 0 + for n in range(0, self.model.radial_order + 1, 2): + for j in range(1, 2 + n // 2): + ll = n + 2 - 2 * j + const = (-1 / 2.0) ** (ll / 2) / np.sqrt(np.pi) + matsum = 0 + for k in range(0, j): + matsum += ( + (-1) ** k * + mapmri.binomialfloat(j + ll - 0.5, j - k - 1) * + gamma(ll / 2 + k + 1 / 2.0) / + (factorial(k) * 0.5 ** (ll / 2 + 1 / 2.0 + k))) + for m in range(-ll, ll + 1): + rtpp_vec[count] = const * matsum + count += 1 + direction = np.array(self.R[:, 0], ndmin=2) + r, theta, phi = cart2sphere(direction[:, 0], direction[:, 1], + direction[:, 2]) + + rtpp = mapmri_coef * (1 / self.us[0]) *\ + rtpp_vec * real_sph_harm(ind_mat[:, 2], ind_mat[:, 1], + theta, phi) + return rtpp.sum() + + def rtap(self, tau): + r""" Calculates the analytical return to the axis probability (RTAP) + for a given diffusion time tau, [1]_ eq. (40, 44a). The analytical + formula for the isotropic MAP-MRI basis was derived in [2]_ eq. (C11). + The qtdmri coefficients are first converted to mapmri coefficients + following [3]. + + Parameters + ---------- + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + + References + ---------- + .. [1] Ozarslan E. et. al, "Mean apparent propagator (MAP) MRI: A novel + diffusion imaging method for mapping tissue microstructure", + NeuroImage, 2013. + .. [2]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP + data." NeuroImage (2016). + .. [3] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_coef = self.qtdmri_to_mapmri_coef(tau) + + if self.model.cartesian: + ind_mat = mapmri.mapmri_index_matrix(self.model.radial_order) + Bm = mapmri.b_mat(ind_mat) + sel = Bm > 0. # select only relevant coefficients + const = 1 / (2 * np.pi * np.prod(self.us[1:])) + ind_sum = (-1.0) ** ((np.sum(ind_mat[sel, 1:], axis=1) / 2.0)) + rtap_vec = const * Bm[sel] * ind_sum * mapmri_coef[sel] + rtap = np.sum(rtap_vec) + else: + ind_mat = mapmri.mapmri_isotropic_index_matrix( + self.model.radial_order + ) + rtap_vec = np.zeros(int(ind_mat.shape[0])) + count = 0 + + for n in range(0, self.model.radial_order + 1, 2): + for j in range(1, 2 + n // 2): + ll = n + 2 - 2 * j + kappa = ((-1) ** (j - 1) * 2. ** (-(ll + 3) / 2.0)) / np.pi + matsum = 0 + for k in range(0, j): + matsum += ((-1) ** k * + mapmri.binomialfloat(j + ll - 0.5, + j - k - 1) * + gamma((ll + 1) / 2.0 + k)) /\ + (factorial(k) * 0.5 ** ((ll + 1) / 2.0 + k)) + for m in range(-ll, ll + 1): + rtap_vec[count] = kappa * matsum + count += 1 + rtap_vec *= 2 + + direction = np.array(self.R[:, 0], ndmin=2) + r, theta, phi = cart2sphere(direction[:, 0], + direction[:, 1], direction[:, 2]) + rtap_vec = mapmri_coef * (1 / self.us[0] ** 2) *\ + rtap_vec * real_sph_harm(ind_mat[:, 2], ind_mat[:, 1], + theta, phi) + rtap = rtap_vec.sum() + return rtap + + def rtop(self, tau): + r""" Calculates the analytical return to the origin probability (RTOP) + for a given diffusion time tau [1]_ eq. (36, 43). The analytical + formula for the isotropic MAP-MRI basis was derived in [2]_ eq. (C11). 
+ The qtdmri coefficients are first converted to mapmri coefficients + following [3]. + + Parameters + ---------- + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + + References + ---------- + .. [1] Ozarslan E. et. al, "Mean apparent propagator (MAP) MRI: A novel + diffusion imaging method for mapping tissue microstructure", + NeuroImage, 2013. + .. [2]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP + data." NeuroImage (2016). + .. [3] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_coef = self.qtdmri_to_mapmri_coef(tau) + + if self.model.cartesian: + ind_mat = mapmri.mapmri_index_matrix(self.model.radial_order) + Bm = mapmri.b_mat(ind_mat) + const = 1 / (np.sqrt(8 * np.pi ** 3) * np.prod(self.us)) + ind_sum = (-1.0) ** (np.sum(ind_mat, axis=1) / 2) + rtop_vec = const * ind_sum * Bm * mapmri_coef + rtop = rtop_vec.sum() + else: + ind_mat = mapmri.mapmri_isotropic_index_matrix( + self.model.radial_order + ) + Bm = mapmri.b_mat_isotropic(ind_mat) + const = 1 / (2 * np.sqrt(2.0) * np.pi ** (3 / 2.0)) + rtop_vec = const * (-1.0) ** (ind_mat[:, 0] - 1) * Bm + rtop = (1 / self.us[0] ** 3) * rtop_vec * mapmri_coef + rtop = rtop.sum() + return rtop + + def msd(self, tau): + r""" Calculates the analytical Mean Squared Displacement (MSD) for a + given diffusion time tau. It is defined as the Laplacian of the origin + of the estimated signal [1]_. The analytical formula for the MAP-MRI + basis was derived in [2]_ eq. (C13, D1). The qtdmri coefficients are + first converted to mapmri coefficients following [3]. + + Parameters + ---------- + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + + References + ---------- + .. [1] Cheng, J., 2014. Estimation and Processing of Ensemble Average + Propagator and Its Features in Diffusion MRI. Ph.D. Thesis. + .. [2]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP + data." NeuroImage (2016). + .. [3] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_coef = self.qtdmri_to_mapmri_coef(tau) + mu = self.us + if self.model.cartesian: + ind_mat = mapmri.mapmri_index_matrix(self.model.radial_order) + Bm = mapmri.b_mat(ind_mat) + sel = Bm > 0. # select only relevant coefficients + ind_sum = np.sum(ind_mat[sel], axis=1) + nx, ny, nz = ind_mat[sel].T + + numerator = (-1) ** (0.5 * (-ind_sum)) * np.pi ** (3 / 2.0) *\ + ((1 + 2 * nx) * mu[0] ** 2 + (1 + 2 * ny) * + mu[1] ** 2 + (1 + 2 * nz) * mu[2] ** 2) + + denominator = np.sqrt(2. ** (-ind_sum) * factorial(nx) * + factorial(ny) * factorial(nz)) *\ + gamma(0.5 - 0.5 * nx) * gamma(0.5 - 0.5 * ny) *\ + gamma(0.5 - 0.5 * nz) + + msd_vec = mapmri_coef[sel] * (numerator / denominator) + msd = msd_vec.sum() + else: + ind_mat = mapmri.mapmri_isotropic_index_matrix( + self.model.radial_order + ) + Bm = mapmri.b_mat_isotropic(ind_mat) + sel = Bm > 0. # select only relevant coefficients + msd_vec = (4 * ind_mat[sel, 0] - 1) * Bm[sel] + msd = self.us[0] ** 2 * msd_vec * mapmri_coef[sel] + msd = msd.sum() + return msd + + def qiv(self, tau): + r""" Calculates the analytical Q-space Inverse Variance (QIV) for given + diffusion time tau. 
+ It is defined as the inverse of the Laplacian of the origin of the + estimated propagator [1]_ eq. (22). The analytical formula for the + MAP-MRI basis was derived in [2]_ eq. (C14, D2). The qtdmri + coefficients are first converted to mapmri coefficients following [3]. + + Parameters + ---------- + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + + References + ---------- + .. [1] Hosseinbor et al. "Bessel fourier orientation reconstruction + (bfor): An analytical diffusion propagator reconstruction for + hybrid diffusion imaging and computation of q-space indices. + NeuroImage 64, 2013, 650–670. + .. [2]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP + data." NeuroImage (2016). + .. [3] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_coef = self.qtdmri_to_mapmri_coef(tau) + ux, uy, uz = self.us + if self.model.cartesian: + ind_mat = mapmri.mapmri_index_matrix(self.model.radial_order) + Bm = mapmri.b_mat(ind_mat) + sel = Bm > 0 # select only relevant coefficients + nx, ny, nz = ind_mat[sel].T + + numerator = 8 * np.pi ** 2 * (ux * uy * uz) ** 3 *\ + np.sqrt(factorial(nx) * factorial(ny) * factorial(nz)) *\ + gamma(0.5 - 0.5 * nx) * gamma(0.5 - 0.5 * ny) * \ + gamma(0.5 - 0.5 * nz) + + denominator = np.sqrt(2. ** (-1 + nx + ny + nz)) *\ + ((1 + 2 * nx) * uy ** 2 * uz ** 2 + ux ** 2 * + ((1 + 2 * nz) * uy ** 2 + (1 + 2 * ny) * uz ** 2)) + + qiv_vec = mapmri_coef[sel] * (numerator / denominator) + qiv = qiv_vec.sum() + else: + ind_mat = mapmri.mapmri_isotropic_index_matrix( + self.model.radial_order + ) + Bm = mapmri.b_mat_isotropic(ind_mat) + sel = Bm > 0. # select only relevant coefficients + j = ind_mat[sel, 0] + qiv_vec = ((8 * (-1.0) ** (1 - j) * + np.sqrt(2) * np.pi ** (7 / 2.)) / ((4.0 * j - 1) * + Bm[sel])) + qiv = ux ** 5 * qiv_vec * mapmri_coef[sel] + qiv = qiv.sum() + return qiv + + def fitted_signal(self, gtab=None): + """ + Recovers the fitted signal for the given gradient table. If no gradient + table is given it recovers the signal for the gtab of the model object. + """ + if gtab is None: + E = self.predict(self.model.gtab) + else: + E = self.predict(gtab) + return E + + def predict(self, qvals_or_gtab, S0=1.): + r"""Recovers the reconstructed signal for any qvalue array or + gradient table. + """ + tau_scaling = self.tau_scaling + if isinstance(qvals_or_gtab, np.ndarray): + q = qvals_or_gtab[:, :3] + tau = qvals_or_gtab[:, 3] * tau_scaling + else: + gtab = qvals_or_gtab + qvals = gtab.qvals + tau = gtab.tau * tau_scaling + q = qvals[:, None] * gtab.bvecs + + if self.model.cartesian: + if self.model.anisotropic_scaling: + q_rot = np.dot(q, self.R) + M = qtdmri_signal_matrix_(self.model.radial_order, + self.model.time_order, + self.us, self.ut, q_rot, tau, + self.model.normalization) + else: + M = qtdmri_signal_matrix_(self.model.radial_order, + self.model.time_order, + self.us, self.ut, q, tau, + self.model.normalization) + else: + M = qtdmri_isotropic_signal_matrix_( + self.model.radial_order, self.model.time_order, + self.us[0], self.ut, q, tau, + normalization=self.model.normalization) + E = S0 * np.dot(M, self._qtdmri_coef) + return E + + def norm_of_laplacian_signal(self): + """ Calculates the norm of the laplacian of the fitted signal [1]_. + This information could be useful to assess if the extrapolation of the + fitted signal contains spurious oscillations. 
A high laplacian norm may + indicate that these are present, and any q-space indices that + use integrals of the signal may be corrupted (e.g. RTOP, RTAP, RTPP, + QIV). In contrast to [1], the Laplacian now describes oscillations in + the 4-dimensional qt-signal [2]. + + References + ---------- + .. [1]_ Fick, Rutger HJ, et al. "MAPL: Tissue microstructure estimation + using Laplacian-regularized MAP-MRI and its application to HCP + data." NeuroImage (2016). + .. [2] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + if self.model.cartesian: + lap_matrix = qtdmri_laplacian_reg_matrix( + self.model.ind_mat, self.us, self.ut, + self.model.S_mat, self.model.T_mat, self.model.U_mat, + self.model.part1_reg_mat_tau, + self.model.part23_reg_mat_tau, + self.model.part4_reg_mat_tau, + normalization=self.model.normalization + ) + else: + lap_matrix = qtdmri_isotropic_laplacian_reg_matrix( + self.model.ind_mat, self.us, self.ut, + self.model.part1_uq_iso_precomp, + self.model.part1_reg_mat_tau, + self.model.part23_reg_mat_tau, + self.model.part4_reg_mat_tau, + normalization=self.model.normalization + ) + norm_laplacian = np.dot(self._qtdmri_coef, + np.dot(self._qtdmri_coef, lap_matrix)) + return norm_laplacian + + def pdf(self, rt_points): + """ Diffusion propagator on a given set of real points. + if the array r_points is non writeable, then intermediate + results are cached for faster recalculation + """ + tau_scaling = self.tau_scaling + rt_points_ = rt_points * np.r_[1, 1, 1, tau_scaling] + if self.model.cartesian: + K = qtdmri_eap_matrix_(self.model.radial_order, + self.model.time_order, + self.us, self.ut, rt_points_, + self.model.normalization) + else: + K = qtdmri_isotropic_eap_matrix_( + self.model.radial_order, self.model.time_order, + self.us[0], self.ut, rt_points_, + normalization=self.model.normalization + ) + eap = np.dot(K, self._qtdmri_coef) + return eap + + +def qtdmri_to_mapmri_matrix(radial_order, time_order, ut, tau): + """Generates the matrix that maps the qtdmri coefficients to MAP-MRI + coefficients. The conversion is done by only evaluating the time basis for + a diffusion time tau and summing up coefficients with the same spatial + basis orders [1]. + + Parameters + ---------- + radial_order : unsigned int, + an even integer representing the spatial/radial order of the basis. + time_order : unsigned int, + an integer larger or equal than zero representing the time order + of the basis. + ut : float + temporal scaling factor + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + + References + ---------- + .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. 
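+
+    Examples
+    --------
+    A minimal sketch; ``qtdmri_coef`` stands for a hypothetical fitted
+    qt-dMRI coefficient vector, so the product yields the MAP-MRI
+    coefficients at the requested diffusion time:
+
+    >>> A = qtdmri_to_mapmri_matrix(6, 2, 1.0, 0.02)  # doctest: +SKIP
+    >>> mapmri_coef = np.dot(A, qtdmri_coef)  # doctest: +SKIP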
+ """ + mapmri_ind_mat = mapmri.mapmri_index_matrix(radial_order) + n_elem_mapmri = int(mapmri_ind_mat.shape[0]) + qtdmri_ind_mat = qtdmri_index_matrix(radial_order, time_order) + n_elem_qtdmri = int(qtdmri_ind_mat.shape[0]) + + temporal_storage = np.zeros(time_order + 1) + for o in range(time_order + 1): + temporal_storage[o] = temporal_basis(o, ut, tau) + + counter = 0 + mapmri_mat = np.zeros((n_elem_mapmri, n_elem_qtdmri)) + for nxt, nyt, nzt, o in qtdmri_ind_mat: + index_overlap = np.all([nxt == mapmri_ind_mat[:, 0], + nyt == mapmri_ind_mat[:, 1], + nzt == mapmri_ind_mat[:, 2]], 0) + mapmri_mat[:, counter] = temporal_storage[o] * index_overlap + counter += 1 + return mapmri_mat + + +def qtdmri_isotropic_to_mapmri_matrix(radial_order, time_order, ut, tau): + """Generates the matrix that maps the spherical qtdmri coefficients to + MAP-MRI coefficients. The conversion is done by only evaluating the time + basis for a diffusion time tau and summing up coefficients with the same + spatial basis orders [1]. + + Parameters + ---------- + radial_order : unsigned int, + an even integer representing the spatial/radial order of the basis. + time_order : unsigned int, + an integer larger or equal than zero representing the time order + of the basis. + ut : float + temporal scaling factor + tau : float + diffusion time (big_delta - small_delta / 3.) in seconds + + References + ---------- + .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + mapmri_ind_mat = mapmri.mapmri_isotropic_index_matrix(radial_order) + n_elem_mapmri = int(mapmri_ind_mat.shape[0]) + qtdmri_ind_mat = qtdmri_isotropic_index_matrix(radial_order, time_order) + n_elem_qtdmri = int(qtdmri_ind_mat.shape[0]) + + temporal_storage = np.zeros(time_order + 1) + for o in range(time_order + 1): + temporal_storage[o] = temporal_basis(o, ut, tau) + + counter = 0 + mapmri_isotropic_mat = np.zeros((n_elem_mapmri, n_elem_qtdmri)) + for j, ll, m, o in qtdmri_ind_mat: + index_overlap = np.all([j == mapmri_ind_mat[:, 0], + ll == mapmri_ind_mat[:, 1], + m == mapmri_ind_mat[:, 2]], 0) + mapmri_isotropic_mat[:, counter] = temporal_storage[o] * index_overlap + counter += 1 + return mapmri_isotropic_mat + + +def qtdmri_temporal_normalization(ut): + """Normalization factor for the temporal basis""" + return np.sqrt(ut) + + +def qtdmri_mapmri_normalization(mu): + """Normalization factor for Cartesian MAP-MRI basis. The scaling is the + same for every basis function depending only on the spatial scaling + mu. + """ + sqrtC = np.sqrt(8 * np.prod(mu)) * np.pi ** (3. / 4.) + return sqrtC + + +def qtdmri_mapmri_isotropic_normalization(j, l, u0): + """Normalization factor for Spherical MAP-MRI basis. The normalization + for a basis function with orders [j,l,m] depends only on orders j,l and + the isotropic scale factor. + """ + sqrtC = ((2 * np.pi) ** (3. / 2.) * + np.sqrt(2 ** l * u0 ** 3 * gamma(j) / gamma(j + l + 1. 
+    return sqrtC
+
+
+def qtdmri_signal_matrix_(radial_order, time_order, us, ut, q, tau,
+                          normalization=False):
+    """Function to generate the qtdmri signal basis."""
+    M = qtdmri_signal_matrix(radial_order, time_order, us, ut, q, tau)
+    if normalization:
+        sqrtC = qtdmri_mapmri_normalization(us)
+        sqrtut = qtdmri_temporal_normalization(ut)
+        sqrtCut = sqrtC * sqrtut
+        M *= sqrtCut
+    return M
+
+
+def qtdmri_signal_matrix(radial_order, time_order, us, ut, q, tau):
+    r"""Constructs the design matrix as a product of three separate radial,
+    angular and temporal design matrices. It precomputes the relevant basis
+    orders for each one and finally puts them together according to the
+    index matrix.
+    """
+    ind_mat = qtdmri_index_matrix(radial_order, time_order)
+
+    n_dat = int(q.shape[0])
+    n_elem = int(ind_mat.shape[0])
+    qx, qy, qz = q.T
+    mux, muy, muz = us
+
+    temporal_storage = np.zeros((n_dat, time_order + 1))
+    for o in range(time_order + 1):
+        temporal_storage[:, o] = temporal_basis(o, ut, tau)
+
+    Qx_storage = np.array(np.zeros((n_dat, radial_order + 1 + 4)),
+                          dtype=complex)
+    Qy_storage = np.array(np.zeros((n_dat, radial_order + 1 + 4)),
+                          dtype=complex)
+    Qz_storage = np.array(np.zeros((n_dat, radial_order + 1 + 4)),
+                          dtype=complex)
+    for n in range(radial_order + 1 + 4):
+        Qx_storage[:, n] = mapmri.mapmri_phi_1d(n, qx, mux)
+        Qy_storage[:, n] = mapmri.mapmri_phi_1d(n, qy, muy)
+        Qz_storage[:, n] = mapmri.mapmri_phi_1d(n, qz, muz)
+
+    counter = 0
+    Q = np.zeros((n_dat, n_elem))
+    for nx, ny, nz, o in ind_mat:
+        Q[:, counter] = (np.real(
+            Qx_storage[:, nx] * Qy_storage[:, ny] * Qz_storage[:, nz]) *
+            temporal_storage[:, o]
+        )
+        counter += 1
+
+    return Q
+
+
+def qtdmri_eap_matrix(radial_order, time_order, us, ut, grid):
+    r"""Constructs the design matrix as a product of three separate radial,
+    angular and temporal design matrices.
It precomputes the relevant basis + orders for each one and finally puts them together according to the index + matrix + """ + ind_mat = qtdmri_index_matrix(radial_order, time_order) + rx, ry, rz, tau = grid.T + + n_dat = int(rx.shape[0]) + n_elem = int(ind_mat.shape[0]) + mux, muy, muz = us + + temporal_storage = np.zeros((n_dat, time_order + 1)) + for o in range(time_order + 1): + temporal_storage[:, o] = temporal_basis(o, ut, tau) + + Kx_storage = np.zeros((n_dat, radial_order + 1)) + Ky_storage = np.zeros((n_dat, radial_order + 1)) + Kz_storage = np.zeros((n_dat, radial_order + 1)) + for n in range(radial_order + 1): + Kx_storage[:, n] = mapmri.mapmri_psi_1d(n, rx, mux) + Ky_storage[:, n] = mapmri.mapmri_psi_1d(n, ry, muy) + Kz_storage[:, n] = mapmri.mapmri_psi_1d(n, rz, muz) + + counter = 0 + K = np.zeros((n_dat, n_elem)) + for nx, ny, nz, o in ind_mat: + K[:, counter] = ( + Kx_storage[:, nx] * Ky_storage[:, ny] * Kz_storage[:, nz] * + temporal_storage[:, o] + ) + counter += 1 + + return K + + +def qtdmri_isotropic_signal_matrix_(radial_order, time_order, us, ut, q, tau, + normalization=False): + M = qtdmri_isotropic_signal_matrix( + radial_order, time_order, us, ut, q, tau + ) + if normalization: + ind_mat = qtdmri_isotropic_index_matrix(radial_order, time_order) + j, ll = ind_mat[:, :2].T + sqrtut = qtdmri_temporal_normalization(ut) + sqrtC = qtdmri_mapmri_isotropic_normalization(j, ll, us) + sqrtCut = sqrtC * sqrtut + M = M * sqrtCut[None, :] + return M + + +def qtdmri_isotropic_signal_matrix(radial_order, time_order, us, ut, q, tau): + ind_mat = qtdmri_isotropic_index_matrix(radial_order, time_order) + qvals, theta, phi = cart2sphere(q[:, 0], q[:, 1], q[:, 2]) + + n_dat = int(qvals.shape[0]) + n_elem = int(ind_mat.shape[0]) + + num_j = int(np.max(ind_mat[:, 0])) + num_o = int(time_order + 1) + num_l = int(radial_order // 2 + 1) + num_m = int(radial_order * 2 + 1) + + # Radial Basis + radial_storage = np.zeros([num_j, num_l, n_dat]) + for j in range(1, num_j + 1): + for ll in range(0, radial_order + 1, 2): + radial_storage[j - 1, ll // 2, :] = radial_basis_opt( + j, ll, us, qvals) + + # Angular Basis + angular_storage = np.zeros([num_l, num_m, n_dat]) + for ll in range(0, radial_order + 1, 2): + for m in range(-ll, ll + 1): + angular_storage[ll // 2, m + ll, :] = ( + angular_basis_opt(ll, m, qvals, theta, phi) + ) + + # Temporal Basis + temporal_storage = np.zeros([num_o + 1, n_dat]) + for o in range(0, num_o + 1): + temporal_storage[o, :] = temporal_basis(o, ut, tau) + + # Construct full design matrix + M = np.zeros((n_dat, n_elem)) + counter = 0 + for j, ll, m, o in ind_mat: + M[:, counter] = (radial_storage[j - 1, ll // 2, :] * + angular_storage[ll // 2, m + ll, :] * + temporal_storage[o, :]) + counter += 1 + return M + + +def qtdmri_eap_matrix_(radial_order, time_order, us, ut, grid, + normalization=False): + sqrtC = 1. + sqrtut = 1. + sqrtCut = 1. 
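+    # These unit defaults are used verbatim when `normalization` is False,
+    # leaving the EAP design matrix unscaled.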
+    if normalization:
+        sqrtC = qtdmri_mapmri_normalization(us)
+        sqrtut = qtdmri_temporal_normalization(ut)
+        sqrtCut = sqrtC * sqrtut
+    K_tau = (
+        qtdmri_eap_matrix(radial_order, time_order, us, ut, grid) * sqrtCut
+    )
+    return K_tau
+
+
+def qtdmri_isotropic_eap_matrix_(radial_order, time_order, us, ut, grid,
+                                 normalization=False):
+    K = qtdmri_isotropic_eap_matrix(
+        radial_order, time_order, us, ut, grid
+    )
+    if normalization:
+        ind_mat = qtdmri_isotropic_index_matrix(radial_order, time_order)
+        j, ll = ind_mat[:, :2].T
+        sqrtut = qtdmri_temporal_normalization(ut)
+        sqrtC = qtdmri_mapmri_isotropic_normalization(j, ll, us)
+        sqrtCut = sqrtC * sqrtut
+        K = K * sqrtCut[None, :]
+    return K
+
+
+def qtdmri_isotropic_eap_matrix(radial_order, time_order, us, ut, grid):
+    r"""Constructs the design matrix as a product of three separate radial,
+    angular and temporal design matrices. It precomputes the relevant basis
+    orders for each one and finally puts them together according to the
+    index matrix.
+    """
+
+    rx, ry, rz, tau = grid.T
+    R, theta, phi = cart2sphere(rx, ry, rz)
+    theta[np.isnan(theta)] = 0
+
+    ind_mat = qtdmri_isotropic_index_matrix(radial_order, time_order)
+    n_dat = int(R.shape[0])
+    n_elem = int(ind_mat.shape[0])
+
+    num_j = int(np.max(ind_mat[:, 0]))
+    num_o = int(time_order + 1)
+    num_l = int(radial_order // 2 + 1)
+    num_m = int(radial_order * 2 + 1)
+
+    # Radial Basis
+    radial_storage = np.zeros([num_j, num_l, n_dat])
+    for j in range(1, num_j + 1):
+        for ll in range(0, radial_order + 1, 2):
+            radial_storage[j - 1, ll // 2, :] = radial_basis_EAP_opt(
+                j, ll, us, R)
+
+    # Angular Basis
+    angular_storage = np.zeros([num_j, num_l, num_m, n_dat])
+    for j in range(1, num_j + 1):
+        for ll in range(0, radial_order + 1, 2):
+            for m in range(-ll, ll + 1):
+                angular_storage[j - 1, ll // 2, m + ll, :] = (
+                    angular_basis_EAP_opt(j, ll, m, R, theta, phi)
+                )
+
+    # Temporal Basis
+    temporal_storage = np.zeros([num_o + 1, n_dat])
+    for o in range(0, num_o + 1):
+        temporal_storage[o, :] = temporal_basis(o, ut, tau)
+
+    # Construct full design matrix
+    M = np.zeros((n_dat, n_elem))
+    counter = 0
+    for j, ll, m, o in ind_mat:
+        M[:, counter] = (radial_storage[j - 1, ll // 2, :] *
+                         angular_storage[j - 1, ll // 2, m + ll, :] *
+                         temporal_storage[o, :])
+        counter += 1
+    return M
+
+
+def radial_basis_opt(j, l, us, q):
+    """ Spatial basis dependent on the spatial scaling factor us.
+    """
+    const = (
+        us ** l * np.exp(-2 * np.pi ** 2 * us ** 2 * q ** 2) *
+        genlaguerre(j - 1, l + 0.5)(4 * np.pi ** 2 * us ** 2 * q ** 2)
+    )
+    return const
+
+
+def angular_basis_opt(l, m, q, theta, phi):
+    """ Angular basis independent of the spatial scaling factor us. Though
+    it includes q, it is independent of the data and can be precomputed.
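+
+    Examples
+    --------
+    A minimal sanity check: for ``l = m = 0`` the prefactors reduce to
+    ``sqrt(4 * pi)`` and, since the real spherical harmonic ``Y_0^0`` equals
+    ``1 / sqrt(4 * pi)``, the basis function is the constant 1 at any
+    q-space point (the sample point below is arbitrary).
+
+    >>> import numpy as np
+    >>> out = angular_basis_opt(0, 0, np.r_[1.], np.r_[0.], np.r_[0.])
+    >>> np.allclose(out, 1.)
+    True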
+ """ + const = ( + (-1) ** (l / 2) * np.sqrt(4.0 * np.pi) * + (2 * np.pi ** 2 * q ** 2) ** (l / 2) * + real_sph_harm(m, l, theta, phi) + ) + return const + + +def radial_basis_EAP_opt(j, l, us, r): + radial_part = ( + (us ** 3) ** (-1) / (us ** 2) ** (l / 2) * + np.exp(- r ** 2 / (2 * us ** 2)) * + genlaguerre(j - 1, l + 0.5)(r ** 2 / us ** 2) + ) + return radial_part + + +def angular_basis_EAP_opt(j, l, m, r, theta, phi): + angular_part = ( + (-1) ** (j - 1) * (np.sqrt(2) * np.pi) ** (-1) * + (r ** 2 / 2) ** (l / 2) * real_sph_harm(m, l, theta, phi) + ) + return angular_part + + +def temporal_basis(o, ut, tau): + """ Temporal basis dependent on temporal scaling factor ut + """ + const = np.exp(-ut * tau / 2.0) * special.laguerre(o)(ut * tau) + return const + + +def qtdmri_index_matrix(radial_order, time_order): + """Computes the SHORE basis order indices according to [1]. + """ + index_matrix = [] + for n in range(0, radial_order + 1, 2): + for i in range(0, n + 1): + for j in range(0, n - i + 1): + for o in range(0, time_order + 1): + index_matrix.append([n - i - j, j, i, o]) + + return np.array(index_matrix) + + +def qtdmri_isotropic_index_matrix(radial_order, time_order): + """Computes the SHORE basis order indices according to [1]. + """ + index_matrix = [] + for n in range(0, radial_order + 1, 2): + for j in range(1, 2 + n // 2): + ll = n + 2 - 2 * j + for m in range(-ll, ll + 1): + for o in range(0, time_order + 1): + index_matrix.append([j, ll, m, o]) + return np.array(index_matrix) + + +def qtdmri_laplacian_reg_matrix(ind_mat, us, ut, + S_mat=None, T_mat=None, U_mat=None, + part1_ut_precomp=None, + part23_ut_precomp=None, + part4_ut_precomp=None, + normalization=False): + """Computes the cartesian qt-dMRI Laplacian regularization matrix. If + given, uses precomputed matrices for temporal and spatial regularization + matrices to speed up computation. Follows the the formulation of Appendix B + in [1]. + + References + ---------- + .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. + """ + if S_mat is None or T_mat is None or U_mat is None: + radial_order = ind_mat[:, :3].max() + S_mat, T_mat, U_mat = mapmri.mapmri_STU_reg_matrices(radial_order) + + part1_us = mapmri.mapmri_laplacian_reg_matrix(ind_mat[:, :3], us, + S_mat, T_mat, U_mat) + part23_us = part23_reg_matrix_q(ind_mat, U_mat, T_mat, us) + part4_us = part4_reg_matrix_q(ind_mat, U_mat, us) + + if part1_ut_precomp is None: + part1_ut = part1_reg_matrix_tau(ind_mat, ut) + else: + part1_ut = part1_ut_precomp / ut + if part23_ut_precomp is None: + part23_ut = part23_reg_matrix_tau(ind_mat, ut) + else: + part23_ut = part23_ut_precomp * ut + if part4_ut_precomp is None: + part4_ut = part4_reg_matrix_tau(ind_mat, ut) + else: + part4_ut = part4_ut_precomp * ut ** 3 + + regularization_matrix = ( + part1_us * part1_ut + part23_us * part23_ut + part4_us * part4_ut + ) + + if normalization: + temporal_normalization = qtdmri_temporal_normalization(ut) ** 2 + spatial_normalization = qtdmri_mapmri_normalization(us) ** 2 + regularization_matrix *= temporal_normalization * spatial_normalization + return regularization_matrix + + +def qtdmri_isotropic_laplacian_reg_matrix(ind_mat, us, ut, + part1_uq_iso_precomp=None, + part1_ut_precomp=None, + part23_ut_precomp=None, + part4_ut_precomp=None, + normalization=False): + """Computes the spherical qt-dMRI Laplacian regularization matrix. 
If
+    given, it uses precomputed temporal and spatial regularization matrices
+    to speed up the computation. Follows the formulation of Appendix C
+    in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    if part1_uq_iso_precomp is None:
+        part1_us = (
+            mapmri.mapmri_isotropic_laplacian_reg_matrix_from_index_matrix(
+                ind_mat[:, :3], us[0]
+            )
+        )
+    else:
+        part1_us = part1_uq_iso_precomp * us[0]
+
+    if part1_ut_precomp is None:
+        part1_ut = part1_reg_matrix_tau(ind_mat, ut)
+    else:
+        part1_ut = part1_ut_precomp / ut
+
+    if part23_ut_precomp is None:
+        part23_ut = part23_reg_matrix_tau(ind_mat, ut)
+    else:
+        part23_ut = part23_ut_precomp * ut
+
+    if part4_ut_precomp is None:
+        part4_ut = part4_reg_matrix_tau(ind_mat, ut)
+    else:
+        part4_ut = part4_ut_precomp * ut ** 3
+
+    part23_us = part23_iso_reg_matrix_q(ind_mat, us[0])
+    part4_us = part4_iso_reg_matrix_q(ind_mat, us[0])
+
+    regularization_matrix = (
+        part1_us * part1_ut + part23_us * part23_ut + part4_us * part4_ut
+    )
+
+    if normalization:
+        temporal_normalization = qtdmri_temporal_normalization(ut) ** 2
+        j, ll = ind_mat[:, :2].T
+        pre_spatial_norm = qtdmri_mapmri_isotropic_normalization(j, ll, us[0])
+        spatial_normalization = np.outer(pre_spatial_norm, pre_spatial_norm)
+        regularization_matrix *= temporal_normalization * spatial_normalization
+    return regularization_matrix
+
+
+def part23_reg_matrix_q(ind_mat, U_mat, T_mat, us):
+    """Partial Cartesian spatial Laplacian regularization matrix following
+    the second line of Eq. (B2) in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    ux, uy, uz = us
+    x, y, z, _ = ind_mat.T
+    n_elem = int(ind_mat.shape[0])
+    LR = np.zeros((n_elem, n_elem))
+    for i in range(n_elem):
+        for k in range(i, n_elem):
+            val = 0
+            if x[i] == x[k] and y[i] == y[k]:
+                val += (
+                    (uz / (ux * uy)) *
+                    U_mat[x[i], x[k]] * U_mat[y[i], y[k]] * T_mat[z[i], z[k]]
+                )
+            if x[i] == x[k] and z[i] == z[k]:
+                val += (
+                    (uy / (ux * uz)) *
+                    U_mat[x[i], x[k]] * T_mat[y[i], y[k]] * U_mat[z[i], z[k]]
+                )
+            if y[i] == y[k] and z[i] == z[k]:
+                val += (
+                    (ux / (uy * uz)) *
+                    T_mat[x[i], x[k]] * U_mat[y[i], y[k]] * U_mat[z[i], z[k]]
+                )
+            LR[i, k] = LR[k, i] = val
+    return LR
+
+
+def part23_iso_reg_matrix_q(ind_mat, us):
+    """Partial spherical spatial Laplacian regularization matrix following
+    the equation below Eq. (C4) in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    n_elem = int(ind_mat.shape[0])
+
+    LR = np.zeros((n_elem, n_elem))
+
+    for i in range(n_elem):
+        for k in range(i, n_elem):
+            if ind_mat[i, 1] == ind_mat[k, 1] and \
+                    ind_mat[i, 2] == ind_mat[k, 2]:
+                ji = ind_mat[i, 0]
+                jk = ind_mat[k, 0]
+                ll = ind_mat[i, 1]
+                if ji == (jk + 1):
+                    LR[i, k] = LR[k, i] = (
+                        2. ** (-ll) * -gamma(3 / 2.0 + jk + ll) / gamma(jk)
+                    )
+                elif ji == jk:
+                    LR[i, k] = LR[k, i] = 2. ** (-(ll + 1)) *\
+                        (1 - 4 * ji - 2 * ll) *\
+                        gamma(1 / 2.0 + ji + ll) / gamma(ji)
+                elif ji == (jk - 1):
+                    LR[i, k] = LR[k, i] = 2. ** (-ll) *\
+                        -gamma(3 / 2.0 + ji + ll) / gamma(ji)
+    return LR / us
+
+
+def part4_reg_matrix_q(ind_mat, U_mat, us):
+    """Partial Cartesian spatial Laplacian regularization matrix following
+    Eq. (B2) in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    ux, uy, uz = us
+    x, y, z, _ = ind_mat.T
+    n_elem = int(ind_mat.shape[0])
+    LR = np.zeros((n_elem, n_elem))
+    for i in range(n_elem):
+        for k in range(i, n_elem):
+            if x[i] == x[k] and y[i] == y[k] and z[i] == z[k]:
+                LR[i, k] = LR[k, i] = (
+                    (1. / (ux * uy * uz)) * U_mat[x[i], x[k]] *
+                    U_mat[y[i], y[k]] * U_mat[z[i], z[k]]
+                )
+    return LR
+
+
+def part4_iso_reg_matrix_q(ind_mat, us):
+    """Partial spherical spatial Laplacian regularization matrix following
+    the equation below Eq. (C4) in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    n_elem = int(ind_mat.shape[0])
+    LR = np.zeros((n_elem, n_elem))
+    for i in range(n_elem):
+        for k in range(i, n_elem):
+            if ind_mat[i, 0] == ind_mat[k, 0] and \
+                    ind_mat[i, 1] == ind_mat[k, 1] and \
+                    ind_mat[i, 2] == ind_mat[k, 2]:
+                ji = ind_mat[i, 0]
+                ll = ind_mat[i, 1]
+                LR[i, k] = LR[k, i] = (
+                    2. ** (-(ll + 2)) * gamma(1 / 2.0 + ji + ll) /
+                    (np.pi ** 2 * gamma(ji))
+                )
+
+    return LR / us ** 3
+
+
+def part1_reg_matrix_tau(ind_mat, ut):
+    """Partial temporal Laplacian regularization matrix following
+    Appendix B in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    n_elem = int(ind_mat.shape[0])
+    LD = np.zeros((n_elem, n_elem))
+    for i in range(n_elem):
+        for k in range(i, n_elem):
+            oi = ind_mat[i, 3]
+            ok = ind_mat[k, 3]
+            if oi == ok:
+                LD[i, k] = LD[k, i] = 1. / ut
+    return LD
+
+
+def part23_reg_matrix_tau(ind_mat, ut):
+    """Partial temporal Laplacian regularization matrix following
+    Appendix B in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    n_elem = int(ind_mat.shape[0])
+    LD = np.zeros((n_elem, n_elem))
+    for i in range(n_elem):
+        for k in range(i, n_elem):
+            oi = ind_mat[i, 3]
+            ok = ind_mat[k, 3]
+            if oi == ok:
+                LD[i, k] = LD[k, i] = 1 / 2.
+            else:
+                LD[i, k] = LD[k, i] = np.abs(oi - ok)
+    return ut * LD
+
+
+def part4_reg_matrix_tau(ind_mat, ut):
+    """Partial temporal Laplacian regularization matrix following
+    Appendix B in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    n_elem = int(ind_mat.shape[0])
+    LD = np.zeros((n_elem, n_elem))
+
+    for i in range(n_elem):
+        for k in range(i, n_elem):
+            oi = ind_mat[i, 3]
+            ok = ind_mat[k, 3]
+
+            sum1 = 0
+            for p in range(1, min([ok, oi]) + 1 + 1):
+                sum1 += (oi - p) * (ok - p) * H(min([oi, ok]) - p)
+
+            sum2 = 0
+            for p in range(0, min(ok - 2, oi - 1) + 1):
+                sum2 += p
+
+            sum3 = 0
+            for p in range(0, min(ok - 1, oi - 2) + 1):
+                sum3 += p
+
+            LD[i, k] = LD[k, i] = (
+                0.25 * np.abs(oi - ok) + (1 / 16.) * mapmri.delta(oi, ok) +
+                min([oi, ok]) + sum1 + H(oi - 1) * H(ok - 1) *
+                (oi + ok - 2 + sum2 + sum3 + H(abs(oi - ok) - 1) *
+                 (abs(oi - ok) - 1) * min([ok - 1, oi - 1]))
+            )
+    return LD * ut ** 3
+
+
+def H(value):
+    """Step function H(x) = 1 if x >= 0 and 0 otherwise. Used for the
+    temporal Laplacian matrix."""
+    if value >= 0:
+        return 1
+    return 0
+
+
+def generalized_crossvalidation(data, M, LR, startpoint=5e-4):
+    r"""Generalized Cross Validation Function [1]_.
+
+    References
+    ----------
+    .. [1] Craven et al. "Smoothing Noisy Data with Spline Functions."
+        NUMER MATH 31.4 (1978): 377-403.
+    """
+    MMt = np.dot(M.T, M)
+    K = len(data)
+    input_stuff = (data, M, MMt, K, LR)
+
+    bounds = ((1e-5, 1),)
+    res = fmin_l_bfgs_b(lambda x,
+                        input_stuff: GCV_cost_function(x, input_stuff),
+                        (startpoint), args=(input_stuff,), approx_grad=True,
+                        bounds=bounds, disp=False, pgtol=1e-10, factr=10.)
+    return res[0][0]
+
+
+def GCV_cost_function(weight, arguments):
+    r"""Generalized Cross Validation Function that is iterated [1]_.
+
+    References
+    ----------
+    .. [1] Craven et al. "Smoothing Noisy Data with Spline Functions."
+        NUMER MATH 31.4 (1978): 377-403.
+    """
+    data, M, MMt, K, LR = arguments
+    S = np.dot(np.dot(M, np.linalg.pinv(MMt + weight * LR)), M.T)
+    trS = np.matrix.trace(S)
+    normyytilde = np.linalg.norm(data - np.dot(S, data), 2)
+    gcv_value = normyytilde / (K - trS)
+    return gcv_value
+
+
+def qtdmri_isotropic_scaling(data, q, tau):
+    """ Estimates the isotropic spatial and temporal scale factors (us, ut)
+    by fitting exponential signal decays over q and over the diffusion time
+    tau.
+    """
+    dataclip = np.clip(data, 1e-05, 1.)
+    logE = -np.log(dataclip)
+    logE_q = logE / (2 * np.pi ** 2)
+    logE_tau = logE * 2
+
+    B_q = np.array([q * q])
+    inv_B_q = np.linalg.pinv(B_q)
+
+    B_tau = np.array([tau])
+    inv_B_tau = np.linalg.pinv(B_tau)
+
+    us = np.sqrt(np.dot(logE_q, inv_B_q))
+    ut = np.dot(logE_tau, inv_B_tau)
+    return us, ut
+
+
+def qtdmri_anisotropic_scaling(data, q, bvecs, tau):
+    """ Estimates the anisotropic spatial scale factors us, the temporal
+    scale factor ut and the rotation matrix R by fitting a tensor to the
+    signal decay over q and an exponential decay over the diffusion time tau.
+    """
+    dataclip = np.clip(data, 1e-05, 10e10)
+    logE = -np.log(dataclip)
+    logE_q = logE / (2 * np.pi ** 2)
+    logE_tau = logE * 2
+
+    B_q = design_matrix_spatial(bvecs, q)
+    inv_B_q = np.linalg.pinv(B_q)
+    A = np.dot(inv_B_q, logE_q)
+
+    evals, R = dti.decompose_tensor(dti.from_lower_triangular(A))
+    us = np.sqrt(evals)
+
+    B_tau = np.array([tau])
+    inv_B_tau = np.linalg.pinv(B_tau)
+
+    ut = np.dot(logE_tau, inv_B_tau)
+
+    return us, ut, R
+
+
+def design_matrix_spatial(bvecs, qvals, dtype=None):
+    """ Constructs the design matrix for DTI weighted least squares or
+    least squares fitting. (Basser et al., 1994a)
+
+    Parameters
+    ----------
+    bvecs : array (N x 3)
+        unit b-vectors of the acquisition.
+    qvals : array (N,)
+        corresponding q-values in 1/mm
+
+    Returns
+    -------
+    design_matrix : array (N, 6)
+        Design matrix or B matrix assuming a Gaussian distributed tensor
+        model, with lower-triangular column ordering
+        design_matrix[j, :] = (Bxx, Bxy, Byy, Bxz, Byz, Bzz)
+    """
+    B = np.zeros((bvecs.shape[0], 6))
+    B[:, 0] = bvecs[:, 0] * bvecs[:, 0] * 1. * qvals ** 2  # Bxx
+    B[:, 1] = bvecs[:, 0] * bvecs[:, 1] * 2. * qvals ** 2  # Bxy
+    B[:, 2] = bvecs[:, 1] * bvecs[:, 1] * 1. * qvals ** 2  # Byy
+    B[:, 3] = bvecs[:, 0] * bvecs[:, 2] * 2. * qvals ** 2  # Bxz
+    B[:, 4] = bvecs[:, 1] * bvecs[:, 2] * 2. * qvals ** 2  # Byz
+    B[:, 5] = bvecs[:, 2] * bvecs[:, 2] * 1. * qvals ** 2  # Bzz
+    return B
+
+
+def create_rt_space_grid(grid_size_r, max_radius_r, grid_size_tau,
+                         min_radius_tau, max_radius_tau):
+    """ Generates an EAP grid (for a potential positivity constraint)."""
+    tau_list = np.linspace(min_radius_tau, max_radius_tau, grid_size_tau)
+    constraint_grid_tau = np.c_[0., 0., 0., 0.]
+    for tau in tau_list:
+        constraint_grid = mapmri.create_rspace(grid_size_r, max_radius_r)
+        constraint_grid_tau = np.vstack(
+            [constraint_grid_tau,
+             np.c_[constraint_grid, np.zeros(constraint_grid.shape[0]) + tau]]
+        )
+    return constraint_grid_tau[1:]
+
+
+def qtdmri_number_of_coefficients(radial_order, time_order):
+    """Computes the total number of coefficients of the qtdmri basis given a
+    radial and temporal order. The formula is given below Eq. (9) in [1]_.
+
+    References
+    ----------
+    .. [1] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized
+        Representation of dMRI in Space and Time", Medical Image Analysis,
+        2017.
+    """
+    F = np.floor(radial_order / 2.)
+    Msym = (F + 1) * (F + 2) * (4 * F + 3) / 6
+    M_total = Msym * (time_order + 1)
+    return M_total
+
+
+def l1_crossvalidation(b0s_mask, E, M, weight_array=np.linspace(0, .4, 21)):
+    """Cross-validation function to find the optimal weight alpha for the
+    sparsity regularization."""
+    dwi_mask = ~b0s_mask
+    b0_mask = b0s_mask
+    dwi_indices = np.arange(E.shape[0])[dwi_mask]
+    b0_indices = np.arange(E.shape[0])[b0_mask]
+    random.shuffle(dwi_indices)
+
+    sub0 = dwi_indices[0::5]
+    sub1 = dwi_indices[1::5]
+    sub2 = dwi_indices[2::5]
+    sub3 = dwi_indices[3::5]
+    sub4 = dwi_indices[4::5]
+
+    test0 = np.hstack((b0_indices, sub1, sub2, sub3, sub4))
+    test1 = np.hstack((b0_indices, sub0, sub2, sub3, sub4))
+    test2 = np.hstack((b0_indices, sub0, sub1, sub3, sub4))
+    test3 = np.hstack((b0_indices, sub0, sub1, sub2, sub4))
+    test4 = np.hstack((b0_indices, sub0, sub1, sub2, sub3))
+
+    cv_list = (
+        (sub0, test0),
+        (sub1, test1),
+        (sub2, test2),
+        (sub3, test3),
+        (sub4, test4)
+    )
+
+    errorlist = np.zeros((5, weight_array.shape[0]))
+    errorlist[:, 0] = 100.
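+    # The first column is a large sentinel error so the sweep below always
+    # accepts the first tested weight; the loop then stops as soon as the
+    # cross-validation error increases again.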
+    optimal_alpha_sub = np.zeros(5)
+    for i, (sub, test) in enumerate(cv_list):
+        counter = 1
+        cv_old = errorlist[i, 0]
+        cv_new = errorlist[i, 0]
+        while cv_old >= cv_new and counter < weight_array.shape[0]:
+            alpha = weight_array[counter]
+            c = cvxpy.Variable(M.shape[1])
+            design_matrix = cvxpy.Constant(M[test])
+            design_matrix_to_recover = cvxpy.Constant(M[sub])
+            data = cvxpy.Constant(E[test])
+            objective = cvxpy.Minimize(
+                cvxpy.sum_squares(design_matrix * c - data) +
+                alpha * cvxpy.norm1(c)
+            )
+            constraints = []
+            prob = cvxpy.Problem(objective, constraints)
+            prob.solve(solver="ECOS", verbose=False)
+            recovered_signal = design_matrix_to_recover * c
+            errorlist[i, counter] = np.mean(
+                (E[sub] - np.asarray(recovered_signal.value).squeeze()) ** 2)
+            cv_old = errorlist[i, counter - 1]
+            cv_new = errorlist[i, counter]
+            counter += 1
+        optimal_alpha_sub[i] = weight_array[counter - 1]
+    optimal_alpha = optimal_alpha_sub.mean()
+    return optimal_alpha
+
+
+def elastic_crossvalidation(b0s_mask, E, M, L, lopt,
+                            weight_array=np.linspace(0, .2, 21)):
+    """Cross-validation function to find the optimal weight alpha for the
+    sparsity regularization when Laplacian regularization is also used."""
+    dwi_mask = ~b0s_mask
+    b0_mask = b0s_mask
+    dwi_indices = np.arange(E.shape[0])[dwi_mask]
+    b0_indices = np.arange(E.shape[0])[b0_mask]
+    random.shuffle(dwi_indices)
+
+    sub0 = dwi_indices[0::5]
+    sub1 = dwi_indices[1::5]
+    sub2 = dwi_indices[2::5]
+    sub3 = dwi_indices[3::5]
+    sub4 = dwi_indices[4::5]
+
+    test0 = np.hstack((b0_indices, sub1, sub2, sub3, sub4))
+    test1 = np.hstack((b0_indices, sub0, sub2, sub3, sub4))
+    test2 = np.hstack((b0_indices, sub0, sub1, sub3, sub4))
+    test3 = np.hstack((b0_indices, sub0, sub1, sub2, sub4))
+    test4 = np.hstack((b0_indices, sub0, sub1, sub2, sub3))
+
+    cv_list = (
+        (sub0, test0),
+        (sub1, test1),
+        (sub2, test2),
+        (sub3, test3),
+        (sub4, test4)
+    )
+
+    errorlist = np.zeros((5, weight_array.shape[0]))
+    errorlist[:, 0] = 100.
+    optimal_alpha_sub = np.zeros(5)
+    for i, (sub, test) in enumerate(cv_list):
+        counter = 1
+        cv_old = errorlist[i, 0]
+        cv_new = errorlist[i, 0]
+        c = cvxpy.Variable(M.shape[1])
+        design_matrix = cvxpy.Constant(M[test])
+        design_matrix_to_recover = cvxpy.Constant(M[sub])
+        data = cvxpy.Constant(E[test])
+        constraints = []
+        while cv_old >= cv_new and counter < weight_array.shape[0]:
+            alpha = weight_array[counter]
+            objective = cvxpy.Minimize(
+                cvxpy.sum_squares(design_matrix * c - data) +
+                alpha * cvxpy.norm1(c) +
+                lopt * cvxpy.quad_form(c, L)
+            )
+            prob = cvxpy.Problem(objective, constraints)
+            prob.solve(solver="ECOS", verbose=False)
+            recovered_signal = design_matrix_to_recover * c
+            errorlist[i, counter] = np.mean(
+                (E[sub] - np.asarray(recovered_signal.value).squeeze()) ** 2)
+            cv_old = errorlist[i, counter - 1]
+            cv_new = errorlist[i, counter]
+            counter += 1
+        optimal_alpha_sub[i] = weight_array[counter - 1]
+    optimal_alpha = optimal_alpha_sub.mean()
+    return optimal_alpha
+
+
+def visualise_gradient_table_G_Delta_rainbow(
+        gtab,
+        big_delta_start=None, big_delta_end=None, G_start=None, G_end=None,
+        bval_isolines=np.r_[0, 250, 1000, 2500, 5000, 7500, 10000, 14000],
+        alpha_shading=0.6):
+    """This function visualizes a q-tau acquisition scheme as a function of
+    gradient strength and pulse separation (big_delta). It represents every
+    measurement at its G and big_delta position, regardless of b-vector,
+    with a background of b-value isolines for reference. It assumes there is
+    only one unique pulse length (small_delta) in the acquisition scheme.
+
+    Parameters
+    ----------
+    gtab : GradientTable object
+        constructed gradient table with big_delta and small_delta given as
+        inputs.
+    big_delta_start : float, optional
+        minimum big_delta that is plotted, in seconds
+    big_delta_end : float, optional
+        maximum big_delta that is plotted, in seconds
+    G_start : float, optional
+        minimum gradient strength that is plotted, in T/m
+    G_end : float, optional
+        maximum gradient strength that is plotted, in T/m
+    bval_isolines : array, optional
+        array of b-value isolines that are plotted in the background
+    alpha_shading : float between [0-1]
+        optional shading of the b-value colors in the background
+    """
+    Delta = gtab.big_delta  # in seconds
+    delta = gtab.small_delta  # in seconds
+    G = gtab.gradient_strength * 1e3  # in SI units T/m
+
+    if len(np.unique(delta)) > 1:
+        msg = "This acquisition has multiple small_delta values. "
+        msg += "This visualization assumes there is only one small_delta."
+        raise ValueError(msg)
+
+    if big_delta_start is None:
+        big_delta_start = 0.005
+    if big_delta_end is None:
+        big_delta_end = Delta.max() + 0.004
+    if G_start is None:
+        G_start = 0.
+    if G_end is None:
+        G_end = G.max() + .05
+
+    Delta_ = np.linspace(big_delta_start, big_delta_end, 50)
+    G_ = np.linspace(G_start, G_end, 50)
+    Delta_grid, G_grid = np.meshgrid(Delta_, G_)
+    dummy_bvecs = np.tile([0, 0, 1], (len(G_grid.ravel()), 1))
+    gtab_grid = gradient_table_from_gradient_strength_bvecs(
+        G_grid.ravel() / 1e3, dummy_bvecs, Delta_grid.ravel(), delta[0]
+    )
+    bvals_ = gtab_grid.bvals.reshape(G_grid.shape)
+
+    plt.contourf(Delta_, G_, bvals_,
+                 levels=bval_isolines,
+                 cmap='rainbow', alpha=alpha_shading)
+    cb = plt.colorbar(spacing="proportional")
+    cb.ax.tick_params(labelsize=16)
+    plt.scatter(Delta, G, c='k', s=25)
+
+    plt.xlim(big_delta_start, big_delta_end)
+    plt.ylim(G_start, G_end)
+    cb.set_label('b-value ($s$/$mm^2$)', fontsize=18)
+    plt.xlabel(r'Pulse Separation $\Delta$ [sec]', fontsize=18)
+    plt.ylabel('Gradient Strength [T/m]', fontsize=18)
+    return None
diff --git a/dipy/reconst/shm.py b/dipy/reconst/shm.py
index c878a632ea..82f0985ebf 100755
--- a/dipy/reconst/shm.py
+++ b/dipy/reconst/shm.py
@@ -28,6 +28,7 @@
 from numpy import concatenate, diag, diff, empty, eye, sqrt, unique, dot
 from numpy.linalg import pinv, svd
 from numpy.random import randint
+import warnings
 
 from dipy.reconst.odf import OdfModel, OdfFit
 from dipy.core.geometry import cart2sphere
@@ -241,8 +242,8 @@ def real_sph_harm(m, n, theta, phi):
 def real_sym_sh_mrtrix(sh_order, theta, phi):
     """
-    Compute real spherical harmonics as in mrtrix, where the real harmonic
-    $Y^m_n$ is defined to be::
+    Compute real spherical harmonics as in Tournier 2007 [2]_, where the real
+    harmonic $Y^m_n$ is defined to be::
 
         Real($Y^m_n$)       if m > 0
         $Y^0_n$             if m = 0
@@ -264,13 +265,24 @@ def real_sym_sh_mrtrix(sh_order, theta, phi):
     --------
     y_mn : real float
         The real harmonic $Y^m_n$ sampled at `theta` and `phi` as
-        implemented in mrtrix. Warning: the basis is Tournier et al
-        2004 and 2007 is slightly different.
+        implemented in mrtrix. Warning: the basis is Tournier et al.
+        2007 [2]_; 2004 [1]_ is slightly different.
     m : array
         The order of the harmonics.
     n : array
         The degree of the harmonics.
 
+    References
+    ----------
+    .. [1] Tournier J.D., Calamante F., Gadian D.G. and Connelly A.
+        Direct estimation of the fibre orientation density function from
+        diffusion-weighted MRI data using spherical deconvolution.
+        NeuroImage. 2004;23:1176-1185.
+    .. [2] Tournier J.D., Calamante F. and Connelly A. Robust determination
+        of the fibre orientation distribution in diffusion MRI:
+        Non-negativity constrained super-resolved spherical deconvolution.
+        NeuroImage. 2007;35(4):1459-1472.
+
     """
     m, n = sph_harm_ind_list(sh_order)
 
     phi = np.reshape(phi, [-1, 1])
@@ -287,8 +299,8 @@ def real_sym_sh_basis(sh_order, theta, phi):
 
     Samples the basis functions up to order `sh_order` at points on the
     sphere given by `theta` and `phi`. The basis functions are defined here
-    the same way as in fibernavigator [1]_ where the real harmonic $Y^m_n$ is
-    defined to be:
+    the same way as in Descoteaux et al. 2007 [1]_ where the real harmonic
+    $Y^m_n$ is defined to be:
 
         Imag($Y^m_n$) * sqrt(2)     if m > 0
         $Y^0_n$                     if m = 0
@@ -317,7 +329,9 @@ def real_sym_sh_basis(sh_order, theta, phi):
 
     References
     ----------
-    .. [1] https://github.com/scilus/fibernavigator
+    .. [1] Descoteaux, M., Angelino, E., Fitzgibbons, S. and Deriche, R.
+        Regularized, Fast, and Robust Analytical Q-ball Imaging.
+        Magn. Reson. Med. 2007;58:497-510.
 
     """
     m, n = sph_harm_ind_list(sh_order)
@@ -330,7 +344,9 @@ def real_sym_sh_basis(sh_order, theta, phi):
 
 sph_harm_lookup = {None: real_sym_sh_basis,
                    "mrtrix": real_sym_sh_mrtrix,
-                   "fibernav": real_sym_sh_basis}
+                   "fibernav": real_sym_sh_basis,
+                   "tournier07": real_sym_sh_mrtrix,
+                   "descoteaux07": real_sym_sh_basis}
 
 
 def sph_harm_ind_list(sh_order):
@@ -861,11 +877,11 @@ def sf_to_sh(sf, sphere, sh_order=4, basis_type=None, smooth=0.0):
     sh_order : int, optional
         Maximum SH order in the SH fit. For `sh_order`, there will be
        ``(sh_order + 1) * (sh_order + 2) / 2`` SH coefficients (default 4).
-    basis_type : {None, 'mrtrix', 'fibernav'}
-        ``None`` for the default dipy basis,
-        ``mrtrix`` for the MRtrix basis, and
-        ``fibernav`` for the FiberNavigator basis
-        (default ``None``).
+    basis_type : {None, 'tournier07', 'descoteaux07'}
+        ``None`` for the default DIPY basis,
+        ``tournier07`` for the Tournier 2007 [2]_ basis, and
+        ``descoteaux07`` for the Descoteaux 2007 [1]_ basis
+        (``None`` defaults to ``descoteaux07``).
     smooth : float, optional
         Lambda-regularization in the SH fit (default 0.0).
 
@@ -874,7 +890,29 @@ def sf_to_sh(sf, sphere, sh_order=4, basis_type=None, smooth=0.0):
     sh : ndarray
         SH coefficients representing the input function.
 
+    References
+    ----------
+    .. [1] Descoteaux, M., Angelino, E., Fitzgibbons, S. and Deriche, R.
+        Regularized, Fast, and Robust Analytical Q-ball Imaging.
+        Magn. Reson. Med. 2007;58:497-510.
+    .. [2] Tournier J.D., Calamante F. and Connelly A. Robust determination
+        of the fibre orientation distribution in diffusion MRI:
+        Non-negativity constrained super-resolved spherical deconvolution.
+        NeuroImage. 2007;35(4):1459-1472.
+
     """
+
+    if basis_type == "fibernav":
+        warnings.warn("sh basis type `fibernav` is deprecated as of version" +
+                      " 0.15 of DIPY and will be removed in a future " +
+                      "version. Please use `descoteaux07` instead",
+                      DeprecationWarning)
+    if basis_type == "mrtrix":
+        warnings.warn("sh basis type `mrtrix` is deprecated as of version" +
+                      " 0.15 of DIPY and will be removed in a future " +
+                      "version. Please use `tournier07` instead",
+                      DeprecationWarning)
+
     sph_harm_basis = sph_harm_lookup.get(basis_type)
 
     if sph_harm_basis is None:
@@ -900,18 +938,40 @@ def sh_to_sf(sh, sphere, sh_order, basis_type=None):
     sh_order : int, optional
         Maximum SH order in the SH fit. For `sh_order`, there will be
         ``(sh_order + 1) * (sh_order + 2) / 2`` SH coefficients (default 4).
-    basis_type : {None, 'mrtrix', 'fibernav'}
-        ``None`` for the default dipy basis,
-        ``mrtrix`` for the MRtrix basis, and
-        ``fibernav`` for the FiberNavigator basis
-        (default ``None``).
+    basis_type : {None, 'tournier07', 'descoteaux07'}
+        ``None`` for the default DIPY basis,
+        ``tournier07`` for the Tournier 2007 [2]_ basis, and
+        ``descoteaux07`` for the Descoteaux 2007 [1]_ basis
+        (``None`` defaults to ``descoteaux07``).
 
     Returns
     -------
     sf : ndarray
         Spherical function values on the `sphere`.
 
+    References
+    ----------
+    .. [1] Descoteaux, M., Angelino, E., Fitzgibbons, S. and Deriche, R.
+        Regularized, Fast, and Robust Analytical Q-ball Imaging.
+        Magn. Reson. Med. 2007;58:497-510.
+    .. [2] Tournier J.D., Calamante F. and Connelly A. Robust determination
+        of the fibre orientation distribution in diffusion MRI:
+        Non-negativity constrained super-resolved spherical deconvolution.
+        NeuroImage. 2007;35(4):1459-1472.
+
     """
+
+    if basis_type == 'fibernav':
+        warnings.warn("sh basis type `fibernav` is deprecated as of version" +
+                      " 0.15 of DIPY and will be removed in a future " +
+                      "version. Please use `descoteaux07` instead",
+                      DeprecationWarning)
+    elif basis_type == 'mrtrix':
+        warnings.warn("sh basis type `mrtrix` is deprecated as of version" +
+                      " 0.15 of DIPY and will be removed in a future " +
+                      "version. Please use `tournier07` instead",
+                      DeprecationWarning)
+
     sph_harm_basis = sph_harm_lookup.get(basis_type)
 
     if sph_harm_basis is None:
@@ -935,11 +995,11 @@ def sh_to_sf_matrix(sphere, sh_order, basis_type=None, return_inv=True,
     sh_order : int, optional
         Maximum SH order in the SH fit. For `sh_order`, there will be
         ``(sh_order + 1) * (sh_order + 2) / 2`` SH coefficients (default 4).
-    basis_type : {None, 'mrtrix', 'fibernav'}
-        ``None`` for the default dipy basis,
-        ``mrtrix`` for the MRtrix basis, and
-        ``fibernav`` for the FiberNavigator basis
-        (default ``None``).
+    basis_type : {None, 'tournier07', 'descoteaux07'}
+        ``None`` for the default DIPY basis,
+        ``tournier07`` for the Tournier 2007 [2]_ basis, and
+        ``descoteaux07`` for the Descoteaux 2007 [1]_ basis
+        (``None`` defaults to ``descoteaux07``).
     return_inv : bool
        If True then the inverse of the matrix is also returned
     smooth : float, optional
@@ -953,7 +1013,29 @@ def sh_to_sf_matrix(sphere, sh_order, basis_type=None, return_inv=True,
     invB : ndarray
         Inverse of B.
 
+    References
+    ----------
+    .. [1] Descoteaux, M., Angelino, E., Fitzgibbons, S. and Deriche, R.
+        Regularized, Fast, and Robust Analytical Q-ball Imaging.
+        Magn. Reson. Med. 2007;58:497-510.
+    .. [2] Tournier J.D., Calamante F. and Connelly A. Robust determination
+        of the fibre orientation distribution in diffusion MRI:
+        Non-negativity constrained super-resolved spherical deconvolution.
+        NeuroImage. 2007;35(4):1459-1472.
+
     """
+
+    if basis_type == 'fibernav':
+        warnings.warn("sh basis type `fibernav` is deprecated as of version" +
+                      " 0.15 of DIPY and will be removed in a future " +
+                      "version. Please use `descoteaux07` instead",
+                      DeprecationWarning)
+    elif basis_type == 'mrtrix':
+        warnings.warn("sh basis type `mrtrix` is deprecated as of version" +
+                      " 0.15 of DIPY and will be removed in a future " +
+                      "version. Please use `tournier07` instead",
+                      DeprecationWarning)
+
     sph_harm_basis = sph_harm_lookup.get(basis_type)
 
     if sph_harm_basis is None:
@@ -997,9 +1079,20 @@ def calculate_max_order(n_coeffs):
 
     Finally, the positive value is chosen between the two options.
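+
+    Examples
+    --------
+    For instance, 15 coefficients can only come from a basis of maximum
+    (even) order 4, since ``(4 + 1) * (4 + 2) / 2 == 15``:
+
+    >>> calculate_max_order(15)
+    4
+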
""" - L1 = (-3 + np.sqrt(1 + 8 * n_coeffs)) / 2 - L2 = (-3 - np.sqrt(1 + 8 * n_coeffs)) / 2 - return np.int(max([L1, L2])) + # L2 is negative for all positive values of n_coeffs, so we don't + # bother even computing it: + # L2 = (-3 - np.sqrt(1 + 8 * n_coeffs)) / 2 + # L1 is always the larger value, so we go with that: + L1 = (-3 + np.sqrt(1 + 8 * n_coeffs)) / 2.0 + # Check that it is a whole even number: + if L1.is_integer() and not np.mod(L1, 2): + return int(L1) + else: + # Otherwise, the input didn't make sense: + raise ValueError("The input to ``calculate_max_order`` was ", + "%s, but that is not a valid number" % n_coeffs, + "of coefficients for a spherical harmonics ", + "basis set.") def anisotropic_power(sh_coeffs, norm_factor=0.00001, power=2, diff --git a/dipy/reconst/shore.py b/dipy/reconst/shore.py index 0ffb4e8264..41cca7b738 100644 --- a/dipy/reconst/shore.py +++ b/dipy/reconst/shore.py @@ -152,9 +152,9 @@ def __init__(self, with respect to the SHORE basis and compute the real and analytical ODF. - from dipy.data import get_data,get_sphere + from dipy.data import get_fnames,get_sphere sphere = get_sphere('symmetric724') - fimg, fbvals, fbvecs = get_data('ISBI_testing_2shells_table') + fimg, fbvals, fbvecs = get_fnames('ISBI_testing_2shells_table') bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) from dipy.sims.voxel import SticksAndBall @@ -269,7 +269,7 @@ def fit(self, data): self.cache_set( 'shore_matrix_positive_constraint', (self.pos_grid, self.pos_radius), psi) - constraints = [M0[0] * c == 1., psi * c > 1e-3] + constraints = [(M0[0] * c) == 1., (psi * c) >= 1e-3] prob = cvxpy.Problem(objective, constraints) try: prob.solve(solver=self.cvxpy_solver) diff --git a/dipy/reconst/tests/test_cross_validation.py b/dipy/reconst/tests/test_cross_validation.py index ac8f36fc16..10e1515495 100644 --- a/dipy/reconst/tests/test_cross_validation.py +++ b/dipy/reconst/tests/test_cross_validation.py @@ -17,7 +17,7 @@ # We'll set these globally: -fdata, fbval, fbvec = dpd.get_data('small_64D') +fdata, fbval, fbvec = dpd.get_fnames('small_64D') def test_coeff_of_determination(): diff --git a/dipy/reconst/tests/test_csdeconv.py b/dipy/reconst/tests/test_csdeconv.py index 312e29217e..427368fb00 100644 --- a/dipy/reconst/tests/test_csdeconv.py +++ b/dipy/reconst/tests/test_csdeconv.py @@ -5,7 +5,7 @@ from numpy.testing import (assert_, assert_equal, assert_almost_equal, assert_array_almost_equal, run_module_suite, assert_array_equal, assert_warns) -from dipy.data import get_sphere, get_data, default_sphere, small_sphere +from dipy.data import get_sphere, get_fnames, default_sphere, small_sphere from dipy.sims.voxel import (multi_tensor, single_tensor, multi_tensor_odf, @@ -31,6 +31,7 @@ import dipy.reconst.dti as dti from dipy.reconst.dti import fractional_anisotropy from dipy.core.sphere import Sphere +from dipy.io.gradients import read_bvals_bvecs def test_recursive_response_calibration(): @@ -40,10 +41,9 @@ def test_recursive_response_calibration(): SNR = 100 S0 = 1 - _, fbvals, fbvecs = get_data('small_64D') + _, fbvals, fbvecs = get_fnames('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) sphere = get_sphere('symmetric724') gtab = gradient_table(bvals, bvecs) @@ -109,9 +109,8 @@ def test_recursive_response_calibration(): def test_auto_response(): - fdata, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + fdata, fbvals, fbvecs = 
get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) data = nib.load(fdata).get_data() gtab = gradient_table(bvals, bvecs) @@ -153,9 +152,8 @@ def test_fa_inferior(FA, fa_thr): def test_response_from_mask(): - fdata, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + fdata, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) data = nib.load(fdata).get_data() gtab = gradient_table(bvals, bvecs) @@ -193,10 +191,9 @@ def test_csdeconv(): SNR = 100 S0 = 1 - _, fbvals, fbvecs = get_data('small_64D') + _, fbvals, fbvecs = get_fnames('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], [0.0015, 0.0003, 0.0003])) @@ -260,9 +257,8 @@ def test_odfdeconv(): SNR = 100 S0 = 1 - _, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + _, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], [0.0015, 0.0003, 0.0003])) @@ -318,9 +314,8 @@ def test_odfdeconv(): def test_odf_sh_to_sharp(): SNR = None S0 = 1 - _, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + _, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], [0.0015, 0.0003, 0.0003])) @@ -378,10 +373,9 @@ def test_r2_term_odf_sharp(): S0 = 1 angle = 45 # 45 degrees is a very tight angle to disentangle - _, fbvals, fbvecs = get_data('small_64D') # get_data('small_64D') + _, fbvals, fbvecs = get_fnames('small_64D') # get_fnames('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) sphere = get_sphere('symmetric724') gtab = gradient_table(bvals, bvecs) @@ -424,9 +418,8 @@ def test_csd_predict(): """ SNR = 100 S0 = 1 - _, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + _, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], [0.0015, 0.0003, 0.0003])) @@ -476,9 +469,8 @@ def test_csd_predict_multi(): """ S0 = 123. 
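+    # S0 deliberately differs from 1 so that amplitude scaling in the
+    # prediction is not trivial.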
- _, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + _, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) response = (np.array([0.0015, 0.0003, 0.0003]), S0) csd = ConstrainedSphericalDeconvModel(gtab, response) @@ -495,10 +487,9 @@ def test_csd_predict_multi(): def test_sphere_scaling_csdmodel(): """Check that mirroring regularization sphere does not change the result of the model""" - _, fbvals, fbvecs = get_data('small_64D') + _, fbvals, fbvecs = get_fnames('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], @@ -533,9 +524,8 @@ def test_default_lambda_csdmodel(): sphere = default_sphere # Create gradient table - _, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + _, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) # Some response function @@ -551,9 +541,8 @@ def test_default_lambda_csdmodel(): def test_csd_superres(): """ Check the quality of csdfit with high SH order. """ - _, fbvals, fbvecs = get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + _, fbvals, fbvecs = get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) # img, gtab = read_stanford_hardi() diff --git a/dipy/reconst/tests/test_dki.py b/dipy/reconst/tests/test_dki.py index f04628c939..6c4636a755 100644 --- a/dipy/reconst/tests/test_dki.py +++ b/dipy/reconst/tests/test_dki.py @@ -12,7 +12,7 @@ from dipy.sims.voxel import multi_tensor_dki from dipy.io.gradients import read_bvals_bvecs from dipy.core.gradients import gradient_table -from dipy.data import get_data +from dipy.data import get_fnames from dipy.reconst.dti import (from_lower_triangular, decompose_tensor) from dipy.reconst.dki import (mean_kurtosis, carlson_rf, carlson_rd, axial_kurtosis, radial_kurtosis, _positive_evals, @@ -22,7 +22,7 @@ from dipy.data import get_sphere from dipy.core.geometry import (sphere2cart, perpendicular_directions) -fimg, fbvals, fbvecs = get_data('small_64D') +fimg, fbvals, fbvecs = get_fnames('small_64D') bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) @@ -720,15 +720,15 @@ def test_multi_voxel_kurtosis_maximum(): # TEST - when no sphere is given k_max = dki.kurtosis_maximum(dkiF.model_params) - assert_almost_equal(k_max, RK, decimal=5) + assert_almost_equal(k_max, RK, decimal=4) # TEST - when sphere is given k_max = dki.kurtosis_maximum(dkiF.model_params, sphere) - assert_almost_equal(k_max, RK, decimal=5) + assert_almost_equal(k_max, RK, decimal=4) # TEST - when mask is given mask = np.ones((2, 2, 2), dtype='bool') mask[1, 1, 1] = 0 RK[1, 1, 1] = 0 k_max = dki.kurtosis_maximum(dkiF.model_params, mask=mask) - assert_almost_equal(k_max, RK, decimal=5) + assert_almost_equal(k_max, RK, decimal=4) diff --git a/dipy/reconst/tests/test_dki_micro.py b/dipy/reconst/tests/test_dki_micro.py index f05c7995bc..bdb08ec60f 100644 --- a/dipy/reconst/tests/test_dki_micro.py +++ b/dipy/reconst/tests/test_dki_micro.py @@ -10,12 +10,12 @@ from dipy.sims.voxel import (multi_tensor_dki, _check_directions, multi_tensor) from dipy.io.gradients import read_bvals_bvecs from dipy.core.gradients import gradient_table -from dipy.data import get_data +from dipy.data 
import get_fnames from dipy.reconst.dti import (eig_from_lo_tri) from dipy.data import get_sphere -fimg, fbvals, fbvecs = get_data('small_64D') +fimg, fbvals, fbvecs = get_fnames('small_64D') bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) diff --git a/dipy/reconst/tests/test_dsi.py b/dipy/reconst/tests/test_dsi.py index fa31561d3f..fc26f6d7c9 100644 --- a/dipy/reconst/tests/test_dsi.py +++ b/dipy/reconst/tests/test_dsi.py @@ -4,7 +4,7 @@ run_module_suite, assert_array_equal, assert_raises) -from dipy.data import get_data, dsi_voxels +from dipy.data import get_fnames, dsi_voxels from dipy.reconst.dsi import DiffusionSpectrumModel from dipy.reconst.odf import gfa from dipy.direction.peaks import peak_directions @@ -16,6 +16,7 @@ from dipy.core.subdivide_octahedron import create_unit_sphere from dipy.core.sphere_stats import angular_similarity +from dipy.testing import setup_test def test_dsi(): # load symmetric 724 sphere @@ -23,7 +24,7 @@ def test_dsi(): # load icosahedron sphere sphere2 = create_unit_sphere(5) - btable = np.loadtxt(get_data('dsi515btable')) + btable = np.loadtxt(get_fnames('dsi515btable')) gtab = gradient_table(btable[:, 0], btable[:, 1:]) data, golden_directions = SticksAndBall(gtab, d=0.0015, S0=100, angles=[(0, 0), (90, 0)], diff --git a/dipy/reconst/tests/test_dsi_deconv.py b/dipy/reconst/tests/test_dsi_deconv.py index ef69ffd88e..1bc09c5f1d 100644 --- a/dipy/reconst/tests/test_dsi_deconv.py +++ b/dipy/reconst/tests/test_dsi_deconv.py @@ -4,7 +4,7 @@ run_module_suite, assert_array_equal, assert_raises) -from dipy.data import get_data, dsi_deconv_voxels +from dipy.data import get_fnames, dsi_deconv_voxels from dipy.reconst.dsi import DiffusionSpectrumDeconvModel from dipy.reconst.odf import gfa from dipy.direction.peaks import peak_directions @@ -17,13 +17,14 @@ from dipy.core.sphere_stats import angular_similarity from dipy.reconst.tests.test_dsi import sticks_and_ball_dummies +from dipy.testing import setup_test def test_dsi(): # load symmetric 724 sphere sphere = get_sphere('symmetric724') # load icosahedron sphere sphere2 = create_unit_sphere(5) - btable = np.loadtxt(get_data('dsi515btable')) + btable = np.loadtxt(get_fnames('dsi515btable')) gtab = gradient_table(btable[:, 0], btable[:, 1:]) data, golden_directions = SticksAndBall(gtab, d=0.0015, S0=100, angles=[(0, 0), (90, 0)], diff --git a/dipy/reconst/tests/test_dsi_metrics.py b/dipy/reconst/tests/test_dsi_metrics.py index e1d67e82f0..b71bf8bffc 100644 --- a/dipy/reconst/tests/test_dsi_metrics.py +++ b/dipy/reconst/tests/test_dsi_metrics.py @@ -1,15 +1,17 @@ import numpy as np from dipy.reconst.dsi import DiffusionSpectrumModel -from dipy.data import get_data +from dipy.data import get_fnames from dipy.core.gradients import gradient_table from numpy.testing import (assert_almost_equal, run_module_suite) from dipy.sims.voxel import (SticksAndBall, MultiTensor) +from dipy.testing import setup_test + def test_dsi_metrics(): - btable = np.loadtxt(get_data('dsi4169btable')) + btable = np.loadtxt(get_fnames('dsi4169btable')) gtab = gradient_table(btable[:, 0], btable[:, 1:]) data, golden_directions = SticksAndBall(gtab, d=0.0015, S0=100, angles=[(0, 0), (60, 0)], diff --git a/dipy/reconst/tests/test_dti.py b/dipy/reconst/tests/test_dti.py index 99c709bf89..270eead4fd 100644 --- a/dipy/reconst/tests/test_dti.py +++ b/dipy/reconst/tests/test_dti.py @@ -23,7 +23,7 @@ _decompose_tensor_nan) from dipy.io.bvectxt import read_bvec_file -from dipy.data import get_data, dsi_voxels, 
get_sphere +from dipy.data import get_fnames, dsi_voxels, get_sphere from dipy.core.subdivide_octahedron import create_unit_sphere import dipy.core.gradients as grad @@ -54,7 +54,7 @@ def test_tensor_algebra(): def test_odf_with_zeros(): - fdata, fbval, fbvec = get_data('small_25') + fdata, fbval, fbvec = get_fnames('small_25') gtab = grad.gradient_table(fbval, fbvec) data = nib.load(fdata).get_data() dm = dti.TensorModel(gtab) @@ -66,7 +66,7 @@ def test_odf_with_zeros(): def test_tensor_model(): - fdata, fbval, fbvec = get_data('small_25') + fdata, fbval, fbvec = get_fnames('small_25') data1 = nib.load(fdata).get_data() gtab1 = grad.gradient_table(fbval, fbvec) data2, gtab2 = dsi_voxels() @@ -109,7 +109,7 @@ def test_tensor_model(): # Make some synthetic data b0 = 1000. - bvecs, bvals = read_bvec_file(get_data('55dir_grad.bvec')) + bvecs, bvals = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table_from_bvals_bvecs(bvals, bvecs.T) # The first b value is 0., so we take the second one: B = bvals[1] @@ -337,7 +337,7 @@ def test_wls_and_ls_fit(): # Recall: D = [Dxx,Dyy,Dzz,Dxy,Dxz,Dyz,log(S_0)] and D ~ 10^-4 mm^2 /s b0 = 1000. - bvec, bval = read_bvec_file(get_data('55dir_grad.bvec')) + bvec, bval = read_bvec_file(get_fnames('55dir_grad.bvec')) B = bval[1] # Scale the eigenvalues and tensor by the B value so the units match D = np.array([1., 1., 1., 0., 0., 1., -np.log(b0) * B]) / B @@ -396,7 +396,7 @@ def test_masked_array_with_tensor(): mask = np.array([[True, False, False, True], [True, False, True, False]]) - bvec, bval = read_bvec_file(get_data('55dir_grad.bvec')) + bvec, bval = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table_from_bvals_bvecs(bval, bvec.T) tensor_model = TensorModel(gtab) @@ -421,7 +421,7 @@ def test_masked_array_with_tensor(): def test_fit_method_error(): - bvec, bval = read_bvec_file(get_data('55dir_grad.bvec')) + bvec, bval = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table_from_bvals_bvecs(bval, bvec.T) # This should work (smoke-testing!): @@ -466,7 +466,7 @@ def test_from_lower_triangular(): def test_all_constant(): - bvecs, bvals = read_bvec_file(get_data('55dir_grad.bvec')) + bvecs, bvals = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table_from_bvals_bvecs(bvals, bvecs.T) fit_methods = ['LS', 'OLS', 'NNLS', 'RESTORE'] for _ in fit_methods: @@ -477,7 +477,7 @@ def test_all_constant(): def test_all_zeros(): - bvecs, bvals = read_bvec_file(get_data('55dir_grad.bvec')) + bvecs, bvals = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table_from_bvals_bvecs(bvals, bvecs.T) fit_methods = ['LS', 'OLS', 'NNLS', 'RESTORE'] for _ in fit_methods: @@ -525,7 +525,7 @@ def test_mask(): def test_nnls_jacobian_fucn(): b0 = 1000. - bvecs, bval = read_bvec_file(get_data('55dir_grad.bvec')) + bvecs, bval = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table(bval, bvecs) B = bval[1] @@ -562,7 +562,7 @@ def test_nlls_fit_tensor(): """ b0 = 1000. 
- bvecs, bval = read_bvec_file(get_data('55dir_grad.bvec')) + bvecs, bval = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table(bval, bvecs) B = bval[1] @@ -609,7 +609,7 @@ def test_nlls_fit_tensor(): npt.assert_raises(ValueError, tensor_model.fit, Y) # Use NLLS with some actual 4D data: - data, bvals, bvecs = get_data('small_25') + data, bvals, bvecs = get_fnames('small_25') gtab = grad.gradient_table(bvals, bvecs) tm1 = dti.TensorModel(gtab, fit_method='NLLS') dd = nib.load(data).get_data() @@ -625,7 +625,7 @@ def test_restore(): Test the implementation of the RESTORE algorithm """ b0 = 1000. - bvecs, bval = read_bvec_file(get_data('55dir_grad.bvec')) + bvecs, bval = read_bvec_file(get_fnames('55dir_grad.bvec')) gtab = grad.gradient_table(bval, bvecs) B = bval[1] @@ -712,7 +712,7 @@ def test_predict(): assert_array_almost_equal(dmfit.predict(gtab), S) assert_array_almost_equal(dm.predict(dmfit.model_params, S0=100), S) - fdata, fbvals, fbvecs = get_data() + fdata, fbvals, fbvecs = get_fnames() data = nib.load(fdata).get_data() # Make the data cube a bit larger: data = np.tile(data.T, 2).T @@ -773,8 +773,9 @@ def test_eig_from_lo_tri(): lo_tri = lower_triangular(dmfit.quadratic_form) assert_array_almost_equal(dti.eig_from_lo_tri(lo_tri), dmfit.model_params) + def test_min_signal_alone(): - fdata, fbvals, fbvecs = get_data() + fdata, fbvals, fbvecs = get_fnames() data = nib.load(fdata).get_data() gtab = grad.gradient_table(fbvals, fbvecs) @@ -782,7 +783,9 @@ def test_min_signal_alone(): ten_model = dti.TensorModel(gtab) fit_alone = ten_model.fit(data[idx]) fit_together = ten_model.fit(data) - npt.assert_array_almost_equal(fit_together.model_params[idx], fit_alone.model_params, decimal=12) + npt.assert_almost_equal(fit_together.model_params[idx], + fit_alone.model_params) + def test_decompose_tensor_nan(): D_fine = np.array([1.7e-3, 0.0, 0.3e-3, 0.0, 0.0, 0.2e-3]) diff --git a/dipy/reconst/tests/test_forecast.py b/dipy/reconst/tests/test_forecast.py index 9cb08d907d..290d580950 100644 --- a/dipy/reconst/tests/test_forecast.py +++ b/dipy/reconst/tests/test_forecast.py @@ -50,11 +50,11 @@ def test_forecast_positive_constrain(): sphere = get_sphere('repulsion100') fodf = f_fit.odf(sphere, clip_negative=False) - assert_equal(fodf[fodf < 0].sum(), 0) + assert_almost_equal(fodf[fodf < 0].sum(), 0, 2) coeff = f_fit.sh_coeff c0 = np.sqrt(1.0/(4*np.pi)) - assert_almost_equal(coeff[0], c0, 10) + assert_almost_equal(coeff[0], c0, 5) def test_forecast_csd(): diff --git a/dipy/reconst/tests/test_fwdti.py b/dipy/reconst/tests/test_fwdti.py index 4c5f68d641..28e52f48f9 100644 --- a/dipy/reconst/tests/test_fwdti.py +++ b/dipy/reconst/tests/test_fwdti.py @@ -17,9 +17,9 @@ all_tensor_evecs, multi_tensor_dki) from dipy.io.gradients import read_bvals_bvecs from dipy.core.gradients import gradient_table -from dipy.data import get_data +from dipy.data import get_fnames -fimg, fbvals, fbvecs = get_data('small_64D') +fimg, fbvals, fbvecs = get_fnames('small_64D') bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) diff --git a/dipy/reconst/tests/test_gqi.py b/dipy/reconst/tests/test_gqi.py index 4e6709c1a2..61a74f3c68 100644 --- a/dipy/reconst/tests/test_gqi.py +++ b/dipy/reconst/tests/test_gqi.py @@ -1,5 +1,5 @@ import numpy as np -from dipy.data import get_data, dsi_voxels +from dipy.data import get_fnames, dsi_voxels from dipy.core.sphere import Sphere from dipy.core.gradients import gradient_table from dipy.sims.voxel import SticksAndBall @@ -20,7 +20,7 @@ 
def test_gqi(): sphere = get_sphere('symmetric724') # load icosahedron sphere sphere2 = create_unit_sphere(5) - btable = np.loadtxt(get_data('dsi515btable')) + btable = np.loadtxt(get_fnames('dsi515btable')) bvals = btable[:, 0] bvecs = btable[:, 1:] gtab = gradient_table(bvals, bvecs) diff --git a/dipy/reconst/tests/test_ivim.py b/dipy/reconst/tests/test_ivim.py index a1e45631b8..bec63b2483 100644 --- a/dipy/reconst/tests/test_ivim.py +++ b/dipy/reconst/tests/test_ivim.py @@ -14,7 +14,7 @@ import numpy as np from numpy.testing import (assert_array_equal, assert_array_almost_equal, assert_raises, assert_array_less, run_module_suite, - assert_warns, dec) + dec) from dipy.reconst.ivim import ivim_prediction, IvimModel from dipy.core.gradients import gradient_table, generate_bvecs @@ -31,7 +31,7 @@ 500., 600., 700., 800., 900., 1000.]) N = len(bvals) bvecs = generate_bvecs(N) -gtab = gradient_table(bvals, bvecs.T) +gtab = gradient_table(bvals, bvecs.T, b0_threshold=0) S0, f, D_star, D = 1000.0, 0.132, 0.00885, 0.000921 # params for a single voxel @@ -65,7 +65,7 @@ 500., 600., 700., 800., 900., 1000.]) bvecs_no_b0 = generate_bvecs(N) -gtab_no_b0 = gradient_table(bvals_no_b0, bvecs.T) +gtab_no_b0 = gradient_table(bvals_no_b0, bvecs.T, b0_threshold=0) bvals_with_multiple_b0 = np.array([0., 0., 0., 0., 40., 60., 80., 100., 120., 140., 160., 180., 200., 300., 400., @@ -73,10 +73,12 @@ bvecs_with_multiple_b0 = generate_bvecs(N) gtab_with_multiple_b0 = gradient_table(bvals_with_multiple_b0, - bvecs_with_multiple_b0.T) + bvecs_with_multiple_b0.T, + b0_threshold=0) noisy_single = np.array([4243.71728516, 4317.81298828, 4244.35693359, - 4439.36816406, 4420.06201172, 4152.30078125, 4114.34912109, 4104.59375, 4151.61914062, + 4439.36816406, 4420.06201172, 4152.30078125, + 4114.34912109, 4104.59375, 4151.61914062, 4003.58374023, 4013.68408203, 3906.39428711, 3909.06079102, 3495.27197266, 3402.57006836, 3163.10180664, 2896.04003906, 2663.7253418, @@ -200,6 +202,23 @@ def test_with_higher_S0(): assert_array_almost_equal(ivim_fit.model_params, params2) +def test_b0_threshold_greater_than0(): + """ + Added test case for default b0_threshold set to 50. + Checks if error is thrown correctly. + """ + bvals_b0t = np.array([50., 10., 20., 30., 40., 60., 80., 100., + 120., 140., 160., 180., 200., 300., 400., + 500., 600., 700., 800., 900., 1000.]) + N = len(bvals_b0t) + bvecs = generate_bvecs(N) + gtab = gradient_table(bvals_b0t, bvecs.T) + with assert_raises(ValueError) as vae: + _ = IvimModel(gtab) + b0_s = "The IVIM model requires a measurement at b==0. 
As of " + assert b0_s in vae.exception + + def test_bounds_x0(): """ Test to check if setting bounds for signal where initial value is diff --git a/dipy/reconst/tests/test_mapmri.py b/dipy/reconst/tests/test_mapmri.py index 8c0f7e0f90..f2f55b3be4 100644 --- a/dipy/reconst/tests/test_mapmri.py +++ b/dipy/reconst/tests/test_mapmri.py @@ -1,29 +1,31 @@ +import platform +import time +from math import factorial + +from scipy.special import gamma +import scipy.integrate as integrate import numpy as np -from dipy.data import get_gtab_taiwan_dsi from numpy.testing import (assert_almost_equal, assert_array_almost_equal, assert_equal, run_module_suite, assert_raises) + +from dipy.data import get_gtab_taiwan_dsi from dipy.reconst.mapmri import MapmriModel, mapmri_index_matrix from dipy.reconst import dti, mapmri from dipy.sims.voxel import (MultiTensor, multi_tensor_pdf, single_tensor, cylinders_and_ball_soderman) -from scipy.special import gamma -from math import factorial from dipy.data import get_sphere from dipy.sims.voxel import add_noise -import scipy.integrate as integrate from dipy.core.sphere_stats import angular_similarity from dipy.direction.peaks import peak_directions from dipy.reconst.odf import gfa from dipy.reconst.tests.test_dsi import sticks_and_ball_dummies from dipy.core.subdivide_octahedron import create_unit_sphere from dipy.reconst.shm import sh_to_sf -import time - def int_func(n): f = np.sqrt(2) * factorial(n) / float(((gamma(1 + n / 2.0)) * @@ -281,9 +283,12 @@ def test_mapmri_isotropic_static_scale_factor(radial_order=6): # test if indeed the scale factor is fixed now assert_equal(np.all(mapf_scale_stat_reg_stat.mu == mu), True) - # test if computation time is shorter - assert_equal(time_scale_stat_reg_stat < time_scale_adapt_reg_stat, - True) + + # test if computation time is shorter (except on Windows): + if not platform.system() == "Windows": + assert_equal(time_scale_stat_reg_stat < time_scale_adapt_reg_stat, + True) + # check if the fitted signal is the same assert_almost_equal(mapf_scale_stat_reg_stat.fitted_signal(), mapf_scale_adapt_reg_stat.fitted_signal()) diff --git a/dipy/reconst/tests/test_qtdmri.py b/dipy/reconst/tests/test_qtdmri.py new file mode 100644 index 0000000000..6f29f709cb --- /dev/null +++ b/dipy/reconst/tests/test_qtdmri.py @@ -0,0 +1,578 @@ +import numpy as np +from dipy.data import get_gtab_taiwan_dsi +from numpy.testing import (assert_, + assert_almost_equal, + assert_array_almost_equal, + assert_equal, + assert_raises, + run_module_suite) +from dipy.reconst import qtdmri, mapmri +from dipy.sims.voxel import MultiTensor +from dipy.data import get_sphere +from dipy.sims.voxel import add_noise +import scipy.integrate as integrate +from dipy.core.gradients import gradient_table_from_qvals_bvecs + + +def generate_gtab4D(number_of_tau_shells=4, delta=0.01): + """Generates testing gradient table for 4D qt-dMRI scheme""" + gtab = get_gtab_taiwan_dsi() + qvals = np.tile(gtab.bvals / 100., number_of_tau_shells) + bvecs = np.tile(gtab.bvecs, (number_of_tau_shells, 1)) + pulse_separation = [] + for ps in np.linspace(0.02, 0.05, number_of_tau_shells): + pulse_separation = np.append(pulse_separation, + np.tile(ps, gtab.bvals.shape[0])) + pulse_duration = np.tile(delta, qvals.shape[0]) + gtab_4d = gradient_table_from_qvals_bvecs(qvals=qvals, bvecs=bvecs, + big_delta=pulse_separation, + small_delta=pulse_duration, + b0_threshold=0) + return gtab_4d + + +def generate_signal_crossing(gtab, lambda1, lambda2, lambda3, angle2=60): + mevals = np.array(([lambda1, 
lambda2, lambda3], + [lambda1, lambda2, lambda3])) + angl = [(0, 0), (angle2, 0)] + S, sticks = MultiTensor(gtab, mevals, S0=1.0, angles=angl, + fractions=[50, 50], snr=None) + return S + + +def test_input_parameters(): + gtab_4d = generate_gtab4D() + + # odd radial order + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, radial_order=3) + + # negative radial order + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, radial_order=-1) + + # negative time order + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, time_order=-1) + + # non-bool laplacian_regularization + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + laplacian_regularization='test') + + # non-"GCV" string for laplacian_weighting + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + laplacian_regularization=True, + laplacian_weighting='test') + + # negative laplacian_weighting + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + laplacian_regularization=True, + laplacian_weighting=-1.) + + # non-bool l1_regularization + assert_raises(ValueError, qtdmri.QtdmriModel, + gtab_4d, l1_regularization='test') + + # non-"CV" string for l1_weighting + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + l1_regularization=True, + l1_weighting='test') + + # negative l1_weighting is caught + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + l1_regularization=True, + l1_weighting=-1.) + + # non-bool cartesian is caught + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + cartesian='test') + + # non-bool anisotropic_scaling is caught + assert_raises(ValueError, qtdmri.QtdmriModel, + gtab_4d, anisotropic_scaling='test') + + # non-bool constrain_q0 is caught + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, constrain_q0='test') + + # negative bval_threshold is caught + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, bval_threshold=-1) + + # negative eigenvalue_threshold is caught + assert_raises(ValueError, qtdmri.QtdmriModel, + gtab_4d, eigenvalue_threshold=-1) + + # unavailable cvxpy solver is caught + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + laplacian_regularization=True, + cvxpy_solver='test') + + # non-normalized non-cartesian l1-regularization is caught + assert_raises(ValueError, qtdmri.QtdmriModel, gtab_4d, + l1_regularization=True, cartesian=False, + normalization=False) + + +def test_orthogonality_temporal_basis_functions(): + # numerical integration parameters + ut = 10 + tmin = 0 + tmax = 100 + + int1 = integrate.quad(lambda t: + qtdmri.temporal_basis(1, ut, t) * + qtdmri.temporal_basis(2, ut, t), tmin, tmax) + int2 = integrate.quad(lambda t: + qtdmri.temporal_basis(2, ut, t) * + qtdmri.temporal_basis(3, ut, t), tmin, tmax) + int3 = integrate.quad(lambda t: + qtdmri.temporal_basis(3, ut, t) * + qtdmri.temporal_basis(4, ut, t), tmin, tmax) + int4 = integrate.quad(lambda t: + qtdmri.temporal_basis(4, ut, t) * + qtdmri.temporal_basis(5, ut, t), tmin, tmax) + + assert_almost_equal(int1, 0.) + assert_almost_equal(int2, 0.) + assert_almost_equal(int3, 0.) + assert_almost_equal(int4, 0.)
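+# Why consecutive orders integrate to zero: the qt-dMRI temporal basis of
+# Fick et al. is (up to normalization) an exponentially weighted Laguerre
+# polynomial, T_o(t) ~ exp(-ut * t / 2) * L_o(ut * t), and Laguerre
+# polynomials are orthogonal under exactly this weight. A minimal standalone
+# sketch of the same check, assuming that form and using scipy's genlaguerre
+# rather than dipy's implementation (illustrative only, not the library code):
+#
+#     from scipy.special import genlaguerre
+#     ut = 10.
+#     def basis(o, t):
+#         return np.exp(-ut * t / 2.) * genlaguerre(o, 0)(ut * t)
+#     val, _ = integrate.quad(lambda t: basis(1, t) * basis(2, t), 0, 100)
+#     # val is ~0: distinct orders are orthogonal under this weight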
+ + +def test_normalization_time(): + ut = 10 + tmin = 0 + tmax = 100 + + int0 = integrate.quad(lambda t: + qtdmri.qtdmri_temporal_normalization(ut) ** 2 * + qtdmri.temporal_basis(0, ut, t) * + qtdmri.temporal_basis(0, ut, t), tmin, tmax)[0] + int1 = integrate.quad(lambda t: + qtdmri.qtdmri_temporal_normalization(ut) ** 2 * + qtdmri.temporal_basis(1, ut, t) * + qtdmri.temporal_basis(1, ut, t), tmin, tmax)[0] + int2 = integrate.quad(lambda t: + qtdmri.qtdmri_temporal_normalization(ut) ** 2 * + qtdmri.temporal_basis(2, ut, t) * + qtdmri.temporal_basis(2, ut, t), tmin, tmax)[0] + + assert_almost_equal(int0, 1.) + assert_almost_equal(int1, 1.) + assert_almost_equal(int2, 1.) + + +def test_anisotropic_isotropic_equivalence(radial_order=4, time_order=2): + # generate qt-scheme and arbitrary synthetic crossing data. + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + + # initialize both cartesian and spherical models without any kind of + # regularization + qtdmri_mod_aniso = qtdmri.QtdmriModel(gtab_4d, radial_order=radial_order, + time_order=time_order, + cartesian=True, + anisotropic_scaling=False) + qtdmri_mod_iso = qtdmri.QtdmriModel(gtab_4d, radial_order=radial_order, + time_order=time_order, + cartesian=False, + anisotropic_scaling=False) + + # both implementations fit the same signal + qtdmri_fit_cart = qtdmri_mod_aniso.fit(S) + qtdmri_fit_sphere = qtdmri_mod_iso.fit(S) + + # same signal fit + assert_array_almost_equal(qtdmri_fit_cart.fitted_signal(), + qtdmri_fit_sphere.fitted_signal()) + + # same PDF reconstruction + rt_grid = qtdmri.create_rt_space_grid(5, 20e-3, 5, 0.02, .05) + pdf_aniso = qtdmri_fit_cart.pdf(rt_grid) + pdf_iso = qtdmri_fit_sphere.pdf(rt_grid) + assert_array_almost_equal(pdf_aniso / pdf_aniso.max(), + pdf_iso / pdf_aniso.max()) + + # same norm of the laplacian + norm_laplacian_aniso = qtdmri_fit_cart.norm_of_laplacian_signal() + norm_laplacian_iso = qtdmri_fit_sphere.norm_of_laplacian_signal() + assert_almost_equal(norm_laplacian_aniso / norm_laplacian_aniso, + norm_laplacian_iso / norm_laplacian_aniso) + + # all q-space indices are the same for arbitrary tau + tau = 0.02 + assert_almost_equal(qtdmri_fit_cart.rtop(tau), qtdmri_fit_sphere.rtop(tau)) + assert_almost_equal(qtdmri_fit_cart.rtap(tau), qtdmri_fit_sphere.rtap(tau)) + assert_almost_equal(qtdmri_fit_cart.rtpp(tau), qtdmri_fit_sphere.rtpp(tau)) + assert_almost_equal(qtdmri_fit_cart.msd(tau), qtdmri_fit_sphere.msd(tau)) + assert_almost_equal(qtdmri_fit_cart.qiv(tau), qtdmri_fit_sphere.qiv(tau)) + + # ODF estimation is the same + sphere = get_sphere() + assert_array_almost_equal(qtdmri_fit_cart.odf(sphere, tau, s=0), + qtdmri_fit_sphere.odf(sphere, tau, s=0)) + + +def test_cartesian_normalization(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + + qtdmri_mod_aniso = qtdmri.QtdmriModel(gtab_4d, radial_order=radial_order, + time_order=time_order, + cartesian=True, + normalization=False) + qtdmri_mod_aniso_norm = qtdmri.QtdmriModel(gtab_4d, + radial_order=radial_order, + time_order=time_order, + cartesian=True, + normalization=True) + qtdmri_fit_aniso = qtdmri_mod_aniso.fit(S) + qtdmri_fit_aniso_norm = qtdmri_mod_aniso_norm.fit(S) + assert_array_almost_equal(qtdmri_fit_aniso.fitted_signal(), + qtdmri_fit_aniso_norm.fitted_signal()) + rt_grid = qtdmri.create_rt_space_grid(5, 20e-3, 5, 0.02, .05) + pdf_aniso = qtdmri_fit_aniso.pdf(rt_grid) +
pdf_aniso_norm = qtdmri_fit_aniso_norm.pdf(rt_grid) + assert_array_almost_equal(pdf_aniso / pdf_aniso.max(), + pdf_aniso_norm / pdf_aniso.max()) + norm_laplacian = qtdmri_fit_aniso.norm_of_laplacian_signal() + norm_laplacian_norm = qtdmri_fit_aniso_norm.norm_of_laplacian_signal() + assert_array_almost_equal(norm_laplacian / norm_laplacian, + norm_laplacian_norm / norm_laplacian) + + +def test_spherical_normalization(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + + qtdmri_mod_aniso = qtdmri.QtdmriModel(gtab_4d, radial_order=radial_order, + time_order=time_order, + cartesian=False, + normalization=False) + qtdmri_mod_aniso_norm = qtdmri.QtdmriModel(gtab_4d, + radial_order=radial_order, + time_order=time_order, + cartesian=False, + normalization=True) + qtdmri_fit = qtdmri_mod_aniso.fit(S) + qtdmri_fit_norm = qtdmri_mod_aniso_norm.fit(S) + assert_array_almost_equal(qtdmri_fit.fitted_signal(), + qtdmri_fit_norm.fitted_signal()) + + rt_grid = qtdmri.create_rt_space_grid(5, 20e-3, 5, 0.02, .05) + pdf = qtdmri_fit.pdf(rt_grid) + pdf_norm = qtdmri_fit_norm.pdf(rt_grid) + assert_array_almost_equal(pdf / pdf.max(), + pdf_norm / pdf.max()) + + norm_laplacian = qtdmri_fit.norm_of_laplacian_signal() + norm_laplacian_norm = qtdmri_fit_norm.norm_of_laplacian_signal() + assert_array_almost_equal(norm_laplacian / norm_laplacian, + norm_laplacian_norm / norm_laplacian) + + +def test_anisotropic_reduced_MSE(radial_order=0, time_order=0): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + qtdmri_mod_aniso = qtdmri.QtdmriModel(gtab_4d, radial_order=radial_order, + time_order=time_order, + cartesian=True, + anisotropic_scaling=True) + qtdmri_mod_iso = qtdmri.QtdmriModel(gtab_4d, radial_order=radial_order, + time_order=time_order, + cartesian=True, + anisotropic_scaling=False) + qtdmri_fit_aniso = qtdmri_mod_aniso.fit(S) + qtdmri_fit_iso = qtdmri_mod_iso.fit(S) + mse_aniso = np.mean((S - qtdmri_fit_aniso.fitted_signal()) ** 2) + mse_iso = np.mean((S - qtdmri_fit_iso.fitted_signal()) ** 2) + assert_(mse_aniso < mse_iso) + + +def test_number_of_coefficients(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + qtdmri_mod = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order) + qtdmri_fit = qtdmri_mod.fit(S) + number_of_coef_model = qtdmri_fit._qtdmri_coef.shape[0] + number_of_coef_analytic = qtdmri.qtdmri_number_of_coefficients( + radial_order, time_order + ) + assert_equal(number_of_coef_model, number_of_coef_analytic) + + +def test_calling_cartesian_laplacian_with_precomputed_matrices( + radial_order=4, time_order=2, ut=2e-3, us=np.r_[1e-3, 2e-3, 3e-3]): + ind_mat = qtdmri.qtdmri_index_matrix(radial_order, time_order) + part4_reg_mat_tau = qtdmri.part4_reg_matrix_tau(ind_mat, 1.) + part23_reg_mat_tau = qtdmri.part23_reg_matrix_tau(ind_mat, 1.) + part1_reg_mat_tau = qtdmri.part1_reg_matrix_tau(ind_mat, 1.) 
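+    # Rationale for the precomputation (judging from the call structure
+    # below): the qt-dMRI Laplacian regularization matrix is assembled from
+    # purely spatial factors (the MAP-MRI S, T, U matrices) and purely
+    # temporal factors (the part1/part23/part4 tau matrices above), each
+    # evaluated here at unit scale so they can be computed once and then
+    # rescaled for any (us, ut). The assertion at the end checks that passing
+    # the precomputed factors reproduces the matrix built from scratch.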
+ S_mat, T_mat, U_mat = mapmri.mapmri_STU_reg_matrices(radial_order) + + laplacian_matrix_precomputed = qtdmri.qtdmri_laplacian_reg_matrix( + ind_mat, us, ut, S_mat, T_mat, U_mat, + part1_reg_mat_tau, part23_reg_mat_tau, part4_reg_mat_tau + ) + laplacian_matrix_regular = qtdmri.qtdmri_laplacian_reg_matrix( + ind_mat, us, ut) + assert_array_almost_equal(laplacian_matrix_precomputed, + laplacian_matrix_regular) + + +def test_calling_spherical_laplacian_with_precomputed_matrices( + radial_order=4, time_order=2, ut=2e-3, us=np.r_[2e-3, 2e-3, 2e-3]): + ind_mat = qtdmri.qtdmri_isotropic_index_matrix(radial_order, time_order) + part4_reg_mat_tau = qtdmri.part4_reg_matrix_tau(ind_mat, 1.) + part23_reg_mat_tau = qtdmri.part23_reg_matrix_tau(ind_mat, 1.) + part1_reg_mat_tau = qtdmri.part1_reg_matrix_tau(ind_mat, 1.) + part1_uq_iso_precomp = ( + mapmri.mapmri_isotropic_laplacian_reg_matrix_from_index_matrix( + ind_mat[:, :3], 1. + ) + ) + laplacian_matrix_precomp = qtdmri.qtdmri_isotropic_laplacian_reg_matrix( + ind_mat, us, ut, + part1_uq_iso_precomp=part1_uq_iso_precomp, + part1_ut_precomp=part1_reg_mat_tau, + part23_ut_precomp=part23_reg_mat_tau, + part4_ut_precomp=part4_reg_mat_tau) + laplacian_matrix_regular = qtdmri.qtdmri_isotropic_laplacian_reg_matrix( + ind_mat, us, ut) + assert_array_almost_equal(laplacian_matrix_precomp, + laplacian_matrix_regular) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_q0_constraint_and_unity_of_ODFs(radial_order=6, time_order=2): + gtab_4d = generate_gtab4D() + tau = gtab_4d.tau + + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + # first test without regularization + qtdmri_mod_ls = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order + ) + qtdmri_fit_ls = qtdmri_mod_ls.fit(S) + fitted_signal = qtdmri_fit_ls.fitted_signal() + # only first tau_point is normalized with least squares. + E_q0_first_tau = fitted_signal[ + np.all([tau == tau.min(), gtab_4d.b0s_mask], axis=0) + ] + assert_almost_equal(float(E_q0_first_tau), 1.) + + # now with cvxpy regularization cartesian + qtdmri_mod_lap = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + laplacian_regularization=True, laplacian_weighting=1e-4 + ) + qtdmri_fit_lap = qtdmri_mod_lap.fit(S) + fitted_signal = qtdmri_fit_lap.fitted_signal() + E_q0_first_tau = fitted_signal[ + np.all([tau == tau.min(), gtab_4d.b0s_mask], axis=0) + ] + E_q0_last_tau = fitted_signal[ + np.all([tau == tau.max(), gtab_4d.b0s_mask], axis=0) + ] + assert_almost_equal(E_q0_first_tau[0], 1.) + assert_almost_equal(E_q0_last_tau[0], 1.) + + # check if odf in spherical harmonics for cartesian raises an error + try: + qtdmri_fit_lap.odf_sh(tau=tau.max()) + assert_equal(True, False) + except ValueError: + print('missing spherical harmonics cartesian ODF caught.') + + # now with cvxpy regularization spherical + qtdmri_mod_lap = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + laplacian_regularization=True, laplacian_weighting=1e-4, + cartesian=False + ) + qtdmri_fit_lap = qtdmri_mod_lap.fit(S) + fitted_signal = qtdmri_fit_lap.fitted_signal() + E_q0_first_tau = fitted_signal[ + np.all([tau == tau.min(), gtab_4d.b0s_mask], axis=0) + ] + E_q0_last_tau = fitted_signal[ + np.all([tau == tau.max(), gtab_4d.b0s_mask], axis=0) + ] + assert_almost_equal(float(E_q0_first_tau), 1.) + assert_almost_equal(float(E_q0_last_tau), 1.) 
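+    # E(q=0, tau) is the normalized signal at b=0 for a given diffusion
+    # time, so a properly constrained fit should return exactly 1 there for
+    # every tau shell (S(b=0) / S0 == 1). The assertions above verify this
+    # at the first and last tau; the block below additionally checks that
+    # the marginal ODF integrates to unity.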
+ + + # test if marginal ODF integral in sh is equal to one + # Integral of Y00 spherical harmonic is 1 / (2 * np.sqrt(np.pi)) + # division with this results in normalization + odf_sh = qtdmri_fit_lap.odf_sh(s=0, tau=tau.max()) + odf_integral = odf_sh[0] * (2 * np.sqrt(np.pi)) + assert_almost_equal(odf_integral, 1.) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_laplacian_reduces_laplacian_norm(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + + qtdmri_mod_no_laplacian = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + laplacian_regularization=True, laplacian_weighting=0. + ) + qtdmri_mod_laplacian = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + laplacian_regularization=True, laplacian_weighting=1e-4 + ) + + qtdmri_fit_no_laplacian = qtdmri_mod_no_laplacian.fit(S) + qtdmri_fit_laplacian = qtdmri_mod_laplacian.fit(S) + + laplacian_norm_no_reg = qtdmri_fit_no_laplacian.norm_of_laplacian_signal() + laplacian_norm_reg = qtdmri_fit_laplacian.norm_of_laplacian_signal() + + assert_(laplacian_norm_no_reg > laplacian_norm_reg) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_spherical_laplacian_reduces_laplacian_norm(radial_order=4, + time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + + qtdmri_mod_no_laplacian = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + cartesian=False, laplacian_regularization=True, laplacian_weighting=0. + ) + qtdmri_mod_laplacian = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + cartesian=False, laplacian_regularization=True, + laplacian_weighting=1e-4 + ) + + qtdmri_fit_no_laplacian = qtdmri_mod_no_laplacian.fit(S) + qtdmri_fit_laplacian = qtdmri_mod_laplacian.fit(S) + + laplacian_norm_no_reg = qtdmri_fit_no_laplacian.norm_of_laplacian_signal() + laplacian_norm_reg = qtdmri_fit_laplacian.norm_of_laplacian_signal() + + assert_(laplacian_norm_no_reg > laplacian_norm_reg) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_laplacian_GCV_higher_weight_with_noise(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + S_noise = add_noise(S, S0=1., snr=10) + + qtdmri_mod_laplacian_GCV = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + laplacian_regularization=True, laplacian_weighting="GCV" + ) + + qtdmri_fit_no_noise = qtdmri_mod_laplacian_GCV.fit(S) + qtdmri_fit_noise = qtdmri_mod_laplacian_GCV.fit(S_noise) + + assert_(qtdmri_fit_noise.lopt > qtdmri_fit_no_noise.lopt) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_l1_increases_sparsity(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + + qtdmri_mod_no_l1 = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + l1_regularization=True, l1_weighting=0.
+ ) + qtdmri_mod_l1 = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + l1_regularization=True, l1_weighting=.1 + ) + + qtdmri_fit_no_l1 = qtdmri_mod_no_l1.fit(S) + qtdmri_fit_l1 = qtdmri_mod_l1.fit(S) + + sparsity_abs_no_reg = qtdmri_fit_no_l1.sparsity_abs() + sparsity_abs_reg = qtdmri_fit_l1.sparsity_abs() + assert_(sparsity_abs_no_reg > sparsity_abs_reg) + + sparsity_density_no_reg = qtdmri_fit_no_l1.sparsity_density() + sparsity_density_reg = qtdmri_fit_l1.sparsity_density() + assert_(sparsity_density_no_reg > sparsity_density_reg) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_spherical_l1_increases_sparsity(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + + qtdmri_mod_no_l1 = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + l1_regularization=True, cartesian=False, normalization=True, + l1_weighting=0. + ) + qtdmri_mod_l1 = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + l1_regularization=True, cartesian=False, normalization=True, + l1_weighting=.1 + ) + + qtdmri_fit_no_l1 = qtdmri_mod_no_l1.fit(S) + qtdmri_fit_l1 = qtdmri_mod_l1.fit(S) + + sparsity_abs_no_reg = qtdmri_fit_no_l1.sparsity_abs() + sparsity_abs_reg = qtdmri_fit_l1.sparsity_abs() + assert_equal(sparsity_abs_no_reg > sparsity_abs_reg, True) + + sparsity_density_no_reg = qtdmri_fit_no_l1.sparsity_density() + sparsity_density_reg = qtdmri_fit_l1.sparsity_density() + assert_(sparsity_density_no_reg > sparsity_density_reg) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_l1_CV(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + S_noise = add_noise(S, S0=1., snr=10) + qtdmri_mod_l1_cv = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + l1_regularization=True, l1_weighting="CV" + ) + qtdmri_fit_noise = qtdmri_mod_l1_cv.fit(S_noise) + assert_(qtdmri_fit_noise.alpha >= 0) + + +@np.testing.dec.skipif(not qtdmri.have_cvxpy) +def test_elastic_GCV_CV(radial_order=4, time_order=2): + gtab_4d = generate_gtab4D() + l1, l2, l3 = [0.0015, 0.0003, 0.0003] + S = generate_signal_crossing(gtab_4d, l1, l2, l3) + S_noise = add_noise(S, S0=1., snr=10) + qtdmri_mod_elastic = qtdmri.QtdmriModel( + gtab_4d, radial_order=radial_order, time_order=time_order, + l1_regularization=True, l1_weighting="CV", + laplacian_regularization=True, laplacian_weighting="GCV" + ) + qtdmri_fit_noise = qtdmri_mod_elastic.fit(S_noise) + assert_(qtdmri_fit_noise.lopt >= 0) + assert_(qtdmri_fit_noise.alpha >= 0) + + +@np.testing.dec.skipif(not qtdmri.have_plt) +def test_visualise_gradient_table_G_Delta_rainbow(): + gtab_4d = generate_gtab4D() + qtdmri.visualise_gradient_table_G_Delta_rainbow(gtab_4d) + + gtab_4d.small_delta[4] += 0.001 # so now the gtab has multiple small_delta + assert_raises(ValueError, + qtdmri.visualise_gradient_table_G_Delta_rainbow, gtab_4d) + + +if __name__ == '__main__': + run_module_suite() diff --git a/dipy/reconst/tests/test_sfm.py b/dipy/reconst/tests/test_sfm.py index b835d3a266..bbd85ac4a5 100644 --- a/dipy/reconst/tests/test_sfm.py +++ b/dipy/reconst/tests/test_sfm.py @@ -7,6 +7,7 @@ import dipy.sims.voxel as sims import dipy.core.optimize as opt import dipy.reconst.cross_validation as xval +from dipy.io.gradients import read_bvals_bvecs def test_design_matrix(): @@ -21,7 +22,7 
@@ def test_design_matrix(): @npt.dec.skipif(not sfm.has_sklearn) def test_sfm(): - fdata, fbvals, fbvecs = dpd.get_data() + fdata, fbvals, fbvecs = dpd.get_fnames() data = nib.load(fdata).get_data() gtab = grad.gradient_table(fbvals, fbvecs) for iso in [sfm.ExponentialIsotropicModel, None]: @@ -51,9 +52,8 @@ def test_sfm(): def test_predict(): SNR = 1000 S0 = 100 - _, fbvals, fbvecs = dpd.get_data('small_64D') - bvals = np.load(fbvals) - bvecs = np.load(fbvecs) + _, fbvals, fbvecs = dpd.get_fnames('small_64D') + bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = grad.gradient_table(bvals, bvecs) mevals = np.array(([0.0015, 0.0003, 0.0003], [0.0015, 0.0003, 0.0003])) @@ -73,7 +73,7 @@ def test_predict(): def test_sfm_background(): - fdata, fbvals, fbvecs = dpd.get_data() + fdata, fbvals, fbvecs = dpd.get_fnames() data = nib.load(fdata).get_data() gtab = grad.gradient_table(fbvals, fbvecs) to_fit = data[0, 0, 0] @@ -84,7 +84,7 @@ def test_sfm_background(): def test_sfm_stick(): - fdata, fbvals, fbvecs = dpd.get_data() + fdata, fbvals, fbvecs = dpd.get_fnames() data = nib.load(fdata).get_data() gtab = grad.gradient_table(fbvals, fbvecs) sfmodel = sfm.SparseFascicleModel(gtab, solver='NNLS', @@ -118,7 +118,7 @@ class EvenSillierSolver(object): def fit(self, X, y): self.coef_ = np.ones(X.shape[-1]) - fdata, fbvals, fbvecs = dpd.get_data() + fdata, fbvals, fbvecs = dpd.get_fnames() gtab = grad.gradient_table(fbvals, fbvecs) sfmodel = sfm.SparseFascicleModel(gtab, solver=SillySolver()) @@ -131,7 +131,7 @@ def fit(self, X, y): @npt.dec.skipif(not sfm.has_sklearn) def test_exponential_iso(): - fdata, fbvals, fbvecs = dpd.get_data() + fdata, fbvals, fbvecs = dpd.get_fnames() data_dti = nib.load(fdata).get_data() gtab_dti = grad.gradient_table(fbvals, fbvecs) data_multi, gtab_multi = dpd.dsi_deconv_voxels() diff --git a/dipy/reconst/tests/test_shm.py b/dipy/reconst/tests/test_shm.py index f23ebdd626..034de7bf0a 100644 --- a/dipy/reconst/tests/test_shm.py +++ b/dipy/reconst/tests/test_shm.py @@ -106,16 +106,16 @@ def test_real_sym_sh_mrtrix(): def test_real_sym_sh_basis(): # This test should do for now - # The mrtrix basis should be the same as re-ordering and re-scaling the - # fibernav basis + # The tournier07 basis should be the same as re-ordering and re-scaling the + # descoteaux07 basis new_order = [0, 5, 4, 3, 2, 1, 14, 13, 12, 11, 10, 9, 8, 7, 6] sphere = hemi_icosahedron.subdivide(2) basis, m, n = real_sym_sh_mrtrix(4, sphere.theta, sphere.phi) expected = basis[:, new_order] expected *= np.where(m == 0, 1., np.sqrt(2)) - fibernav_basis, m, n = real_sym_sh_basis(4, sphere.theta, sphere.phi) - assert_array_almost_equal(fibernav_basis, expected) + descoteaux07_basis, m, n = real_sym_sh_basis(4, sphere.theta, sphere.phi) + assert_array_almost_equal(descoteaux07_basis, expected) def test_smooth_pinv(): @@ -360,14 +360,34 @@ def test_sf_to_sh(): odf2 = sh_to_sf(odf_sh, sphere, 8) assert_array_almost_equal(odf, odf2, 2) - odf_sh = sf_to_sh(odf, sphere, 8, "mrtrix") - odf2 = sh_to_sf(odf_sh, sphere, 8, "mrtrix") + odf_sh = sf_to_sh(odf, sphere, 8, "tournier07") + odf2 = sh_to_sf(odf_sh, sphere, 8, "tournier07") assert_array_almost_equal(odf, odf2, 2) - odf_sh = sf_to_sh(odf, sphere, 8, "fibernav") - odf2 = sh_to_sf(odf_sh, sphere, 8, "fibernav") + # Test the basis naming deprecation + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", DeprecationWarning) + odf_sh_mrtrix = sf_to_sh(odf, sphere, 8, "mrtrix") + odf2_mrtrix = sh_to_sf(odf_sh_mrtrix, sphere, 8, 
"mrtrix") + assert_array_almost_equal(odf, odf2_mrtrix, 2) + assert len(w) != 0 + assert issubclass(w[-1].category, DeprecationWarning) + warnings.simplefilter("default", DeprecationWarning) + + odf_sh = sf_to_sh(odf, sphere, 8, "descoteaux07") + odf2 = sh_to_sf(odf_sh, sphere, 8, "descoteaux07") assert_array_almost_equal(odf, odf2, 2) + # Test the basis naming deprecation + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", DeprecationWarning) + odf_sh_fibernav = sf_to_sh(odf, sphere, 8, "fibernav") + odf2_fibernav = sh_to_sf(odf_sh_fibernav, sphere, 8, "fibernav") + assert_array_almost_equal(odf, odf2_fibernav, 2) + assert len(w) != 0 + assert issubclass(w[-1].category, DeprecationWarning) + warnings.simplefilter("default", DeprecationWarning) + # 2D case odf2d = np.vstack((odf2, odf)) odf2d_sh = sf_to_sh(odf2d, sphere, 8) @@ -375,6 +395,8 @@ def test_sf_to_sh(): assert_array_almost_equal(odf2d, odf2d_sf, 2) +test_sf_to_sh() + def test_faster_sph_harm(): sh_order = 8 @@ -457,7 +479,8 @@ def test_calculate_max_order(): for o, n in zip(orders, n_coeffs): assert_equal(calculate_max_order(n), o) + assert_raises(ValueError, calculate_max_order, 29) -if __name__ == "__main__": - import nose - nose.runmodule() +#if __name__ == "__main__": +# import nose +# nose.runmodule() diff --git a/dipy/reconst/tests/test_shore.py b/dipy/reconst/tests/test_shore.py index d25771e6fe..daaaeeff9a 100644 --- a/dipy/reconst/tests/test_shore.py +++ b/dipy/reconst/tests/test_shore.py @@ -53,7 +53,7 @@ def test_shore_positive_constrain(): pos_radius=20e-03) asmfit = asm.fit(data.S) eap = asmfit.pdf_grid(11, 20e-03) - assert_equal(eap[eap < 0].sum(), 0) + assert_almost_equal(eap[eap < 0].sum(), 0, 3) def test_shore_fitting_no_constrain_e0(): diff --git a/dipy/segment/benchmarks/bench_quickbundles.py b/dipy/segment/benchmarks/bench_quickbundles.py index 3d4cc873d6..ac7b3b214c 100644 --- a/dipy/segment/benchmarks/bench_quickbundles.py +++ b/dipy/segment/benchmarks/bench_quickbundles.py @@ -16,7 +16,7 @@ import numpy as np import nibabel as nib -from dipy.data import get_data +from dipy.data import get_fnames import dipy.tracking.streamline as streamline_utils from dipy.segment.metric import Metric @@ -43,7 +43,7 @@ def bench_quickbundles(): repeat = 10 nb_points = 12 - streams, hdr = nib.trackvis.read(get_data('fornix')) + streams, hdr = nib.trackvis.read(get_fnames('fornix')) fornix = [s[0].astype(dtype) for s in streams] fornix = streamline_utils.set_number_of_points(fornix, nb_points) diff --git a/dipy/segment/bundles.py b/dipy/segment/bundles.py index de92eeffed..5efed50667 100644 --- a/dipy/segment/bundles.py +++ b/dipy/segment/bundles.py @@ -11,14 +11,71 @@ from time import time from itertools import chain -from dipy.tracking.streamline import Streamlines +from dipy.tracking.streamline import Streamlines, length from nibabel.affines import apply_affine +def check_range(streamline, gt, lt): + length_s = length(streamline) + if (length_s > gt) & (length_s < lt): + return True + else: + return False + + +def bundle_adjacency(dtracks0, dtracks1, threshold): + """ Find bundle adjacency between two given tracks/bundles + + Parameters + ---------- + dtracks0 : Streamlines + dtracks1 : Streamlines + threshold: float + References + ---------- + .. [Garyfallidis12] Garyfallidis E. et al., QuickBundles a method for + tractography simplification, Frontiers in Neuroscience, + vol 6, no 175, 2012. 
+ """ + d01 = bundles_distances_mdf(dtracks0, dtracks1) + + pair12 = [] + + for i in range(len(dtracks0)): + if np.min(d01[i, :]) < threshold: + j = np.argmin(d01[i, :]) + pair12.append((i, j)) + + pair12 = np.array(pair12) + pair21 = [] + + # solo2 = [] + for i in range(len(dtracks1)): + if np.min(d01[:, i]) < threshold: + j = np.argmin(d01[:, i]) + pair21.append((i, j)) + + pair21 = np.array(pair21) + A = len(pair12) / np.float(len(dtracks0)) + B = len(pair21) / np.float(len(dtracks1)) + res = 0.5 * (A + B) + return res + + +def ba_analysis(recognized_bundle, expert_bundle, threshold=2.): + + recognized_bundle = set_number_of_points(recognized_bundle, 20) + + expert_bundle = set_number_of_points(expert_bundle, 20) + + return bundle_adjacency(recognized_bundle, expert_bundle, threshold) + + class RecoBundles(object): - def __init__(self, streamlines, cluster_map=None, clust_thr=15, nb_pts=20, - seed=42, verbose=True): + def __init__(self, streamlines, greater_than=50, less_than=1000000, + cluster_map=None, clust_thr=15, nb_pts=20, + rng=None, verbose=True): """ Recognition of bundles Extract bundles from a participants' tractograms using model bundles @@ -29,12 +86,17 @@ def __init__(self, streamlines, cluster_map=None, clust_thr=15, nb_pts=20, ---------- streamlines : Streamlines The tractogram in which you want to recognize bundles. + greater_than : int, optional + Keep streamlines that have length greater than + this value (default 50) + less_than : int, optional + Keep streamlines have length less than this value (default 1000000) cluster_map : QB map Provide existing clustering to start RB faster (default None). clust_thr : float Distance threshold in mm for clustering `streamlines` - seed : int - Setup for random number generator (default 42). + rng : RandomState + If None define RandomState in initialization function. nb_pts : int Number of points per streamline (default 20) @@ -51,16 +113,28 @@ def __init__(self, streamlines, cluster_map=None, clust_thr=15, nb_pts=20, bundles using local and global streamline-based registration and clustering, Neuroimage, 2017. """ - self.streamlines = streamlines - + map_ind = np.zeros(len(streamlines)) + for i in range(len(streamlines)): + map_ind[i] = check_range(streamlines[i], greater_than, less_than) + map_ind = map_ind.astype(bool) + + self.orig_indices = np.array(list(range(0, len(streamlines)))) + self.filtered_indices = np.array(self.orig_indices[map_ind]) + self.streamlines = Streamlines(streamlines[map_ind]) + print("target brain streamlines length = ", len(streamlines)) + print("After refining target brain streamlines length = ", + len(self.streamlines)) self.nb_streamlines = len(self.streamlines) self.verbose = verbose self.start_thr = [40, 25, 20] + if rng is None: + self.rng = np.random.RandomState() + else: + self.rng = rng if cluster_map is None: - self._cluster_streamlines(clust_thr=clust_thr, nb_pts=nb_pts, - seed=seed) + self._cluster_streamlines(clust_thr=clust_thr, nb_pts=nb_pts) else: if self.verbose: t = time() @@ -77,9 +151,7 @@ def __init__(self, streamlines, cluster_map=None, clust_thr=15, nb_pts=20, print(' Total loading duration %0.3f sec. 
\n' % (time() - t,)) - def _cluster_streamlines(self, clust_thr, nb_pts, seed): - - rng = np.random.RandomState(seed=seed) + def _cluster_streamlines(self, clust_thr, nb_pts): if self.verbose: t = time() @@ -93,7 +165,8 @@ def _cluster_streamlines(self, clust_thr, nb_pts, seed): thresholds = self.start_thr + [clust_thr] merged_cluster_map = qbx_and_merge(self.streamlines, thresholds, - nb_pts, None, rng, self.verbose) + nb_pts, None, self.rng, + self.verbose) self.cluster_map = merged_cluster_map self.centroids = merged_cluster_map.centroids @@ -106,7 +179,7 @@ def _cluster_streamlines(self, clust_thr, nb_pts, seed): print(' Total duration %0.3f sec. \n' % (time() - t,)) def recognize(self, model_bundle, model_clust_thr, - reduction_thr=20, + reduction_thr=10, reduction_distance='mdf', slr=True, slr_metric=None, @@ -114,7 +187,7 @@ def recognize(self, model_bundle, model_clust_thr, slr_bounds=None, slr_select=(400, 600), slr_method='L-BFGS-B', - pruning_thr=10, + pruning_thr=5, pruning_distance='mdf'): """ Recognize the model_bundle in self.streamlines @@ -148,8 +221,6 @@ def recognize(self, model_bundle, model_clust_thr, Recognized bundle in the space of the model tractogram recognized_labels : array Indices of recognized bundle in the original tractogram - recognized_bundle : Streamlines - Recognized bundle in the space of the original tractogram References ---------- @@ -157,6 +228,7 @@ def recognize(self, model_bundle, model_clust_thr, bundles using local and global streamline-based registration and clustering, Neuroimage, 2017. """ + if self.verbose: t = time() print('## Recognize given bundle ## \n') @@ -164,14 +236,18 @@ def recognize(self, model_bundle, model_clust_thr, model_centroids = self._cluster_model_bundle( model_bundle, model_clust_thr=model_clust_thr) + neighb_streamlines, neighb_indices = self._reduce_search_space( model_centroids, reduction_thr=reduction_thr, reduction_distance=reduction_distance) + if len(neighb_streamlines) == 0: - return Streamlines([]), [], Streamlines([]) + return Streamlines([]), [] + if slr: - transf_streamlines = self._register_neighb_to_model( + + transf_streamlines, slr1_bmd = self._register_neighb_to_model( model_bundle, neighb_streamlines, metric=slr_metric, @@ -180,6 +256,7 @@ def recognize(self, model_bundle, model_clust_thr, select_model=slr_select[0], select_target=slr_select[1], method=slr_method) + else: transf_streamlines = neighb_streamlines @@ -193,9 +270,163 @@ def recognize(self, model_bundle, model_clust_thr, if self.verbose: print('Total duration of recognition time is %0.3f sec.\n' % (time()-t,)) - # return recognized bundle in original streamlines, labels of - # recognized bundle and transformed recognized bundle - return pruned_streamlines, labels, self.streamlines[labels] + # return recognized bundle, labels of + # recognized bundle + + return pruned_streamlines, self.filtered_indices[labels] + + def refine(self, model_bundle, pruned_streamlines, model_clust_thr, + reduction_thr=14, + reduction_distance='mdf', + slr=True, + slr_metric=None, + slr_x0=None, + slr_bounds=None, + slr_select=(400, 600), + slr_method='L-BFGS-B', + pruning_thr=6, + pruning_distance='mdf'): + """ Refine and recognize the model_bundle in self.streamlines + This method expects already pruned streamlines as input. It refines + the first output of RecoBundles by applying a second local SLR + (optional) and a second round of pruning. This method is useful when + dealing with noisy data or when we want to extract small tracks from + tractograms.
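+
+        A typical call pattern (an illustrative sketch, assuming an already
+        initialized RecoBundles instance ``rb`` and a ``model_bundle``)::
+
+            pruned, labels = rb.recognize(model_bundle, model_clust_thr=5.)
+            refined, refined_labels = rb.refine(model_bundle, pruned,
+                                                model_clust_thr=5.)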
+ + Parameters + ---------- + model_bundle : Streamlines + pruned_streamlines : Streamlines + model_clust_thr : float + reduction_thr : float + reduction_distance : string + mdf or mam (default mdf) + slr : bool + Use Streamline-based Linear Registration (SLR) locally + (default True) + slr_metric : BundleMinDistanceMetric + slr_x0 : array + (default None) + slr_bounds : array + (default None) + slr_select : tuple + Number of streamlines to select from the model bundle and its + neighborhood to perform the local SLR. + slr_method : string + Optimization method (default 'L-BFGS-B') + pruning_thr : float + pruning_distance : string + MDF ('mdf') and MAM ('mam') + + Returns + ------- + recognized_transf : Streamlines + Recognized bundle in the space of the model tractogram + recognized_labels : array + Indices of recognized bundle in the original tractogram + + References + ---------- + .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter + bundles using local and global streamline-based registration and + clustering, Neuroimage, 2017. + """ + + if self.verbose: + t = time() + print('## Refine recognize given bundle ## \n') + + model_centroids = self._cluster_model_bundle( + model_bundle, + model_clust_thr=model_clust_thr) + + pruned_model_centroids = self._cluster_model_bundle( + pruned_streamlines, + model_clust_thr=model_clust_thr) + + neighb_streamlines, neighb_indices = self._reduce_search_space( + pruned_model_centroids, + reduction_thr=reduction_thr, + reduction_distance=reduction_distance) + + if len(neighb_streamlines) == 0: # if no streamlines recognized + return Streamlines([]), [] + + if self.verbose: + print("2nd local Slr") + + if slr: + transf_streamlines, slr2_bmd = self._register_neighb_to_model( + model_bundle, + neighb_streamlines, + metric=slr_metric, + x0=slr_x0, + bounds=slr_bounds, + select_model=slr_select[0], + select_target=slr_select[1], + method=slr_method) + else: + transf_streamlines = neighb_streamlines + + if self.verbose: + print("pruning after 2nd local Slr") + + pruned_streamlines, labels = self._prune_what_not_in_model( + model_centroids, + transf_streamlines, + neighb_indices, + pruning_thr=pruning_thr, + pruning_distance=pruning_distance) + + if self.verbose: + print('Total duration of recognition time is %0.3f sec.\n' + % (time()-t,)) + + return pruned_streamlines, self.filtered_indices[labels] + + def evaluate_results(self, model_bundle, pruned_streamlines, slr_select): + """ Compare the similarity between two given bundles: the model + bundle and the extracted bundle. + + Parameters + ---------- + model_bundle : Streamlines + pruned_streamlines : Streamlines + slr_select : tuple + Number of streamlines to select from the model bundle and its + neighborhood to perform the local SLR.
+ + Returns + ------- + ba_value : float + bundle adjacency value between model bundle and pruned bundle + bmd_value : float + bundle minimum distance value between model bundle and + pruned bundle + """ + + spruned_streamlines = Streamlines(pruned_streamlines) + recog_centroids = self._cluster_model_bundle( + spruned_streamlines, + model_clust_thr=1.25) + mod_centroids = self._cluster_model_bundle( + model_bundle, + model_clust_thr=1.25) + recog_centroids = Streamlines(recog_centroids) + model_centroids = Streamlines(mod_centroids) + ba_value = ba_analysis(recog_centroids, model_centroids, threshold=10) + + BMD = BundleMinDistanceMetric() + static = select_random_set_of_streamlines(model_bundle, + slr_select[0]) + moving = select_random_set_of_streamlines(pruned_streamlines, + slr_select[1]) + nb_pts = 20 + static = set_number_of_points(static, nb_pts) + moving = set_number_of_points(moving, nb_pts) + + BMD.setup(static, moving) + x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0]) # affine + bmd_value = BMD.distance(x0.tolist()) + + return ba_value, bmd_value def _cluster_model_bundle(self, model_bundle, model_clust_thr, nb_pts=20, select_randomly=500000): @@ -211,7 +442,7 @@ def _cluster_model_bundle(self, model_bundle, model_clust_thr, nb_pts=20, model_cluster_map = qbx_and_merge(model_bundle, thresholds, nb_pts=nb_pts, select_randomly=select_randomly, - rng=None, + rng=self.rng, verbose=self.verbose) model_centroids = model_cluster_map.centroids nb_model_centroids = len(model_centroids) @@ -293,9 +524,9 @@ def _register_neighb_to_model(self, model_bundle, neighb_streamlines, # TODO this can be speeded up by using directly the centroids static = select_random_set_of_streamlines(model_bundle, - select_model) + select_model, rng=self.rng) moving = select_random_set_of_streamlines(neighb_streamlines, - select_target) + select_target, rng=self.rng) static = set_number_of_points(static, nb_pts) moving = set_number_of_points(moving, nb_pts) @@ -326,7 +557,7 @@ def _register_neighb_to_model(self, model_bundle, neighb_streamlines, print(' Duration %0.3f sec.
\n' % (time() - t,)) - return transf_streamlines + return transf_streamlines, slr_bmd def _prune_what_not_in_model(self, model_centroids, transf_streamlines, @@ -348,7 +579,7 @@ def _prune_what_not_in_model(self, model_centroids, rtransf_cluster_map = qbx_and_merge(transf_streamlines, thresholds, nb_pts=20, select_randomly=500000, - rng=None, + rng=self.rng, verbose=self.verbose) if self.verbose: @@ -381,8 +612,7 @@ def _prune_what_not_in_model(self, model_centroids, pruned_indices = [rtransf_cluster_map[i].indices for i in np.where(mins != np.inf)[0]] pruned_indices = list(chain(*pruned_indices)) - pruned_streamlines = [transf_streamlines[i] - for i in pruned_indices] + pruned_streamlines = transf_streamlines[np.array(pruned_indices)] initial_indices = list(chain(*neighb_indices)) final_indices = [initial_indices[i] for i in pruned_indices] diff --git a/dipy/segment/clustering.py b/dipy/segment/clustering.py index a595401dad..c27e4ee5dc 100644 --- a/dipy/segment/clustering.py +++ b/dipy/segment/clustering.py @@ -439,9 +439,9 @@ class QuickBundles(Clustering): Examples -------- >>> from dipy.segment.clustering import QuickBundles - >>> from dipy.data import get_data + >>> from dipy.data import get_fnames >>> from nibabel import trackvis as tv - >>> streams, hdr = tv.read(get_data('fornix')) + >>> streams, hdr = tv.read(get_fnames('fornix')) >>> streamlines = [i[0] for i in streams] >>> # Segment fornix with a treshold of 10mm and streamlines resampled >>> # to 12 points. diff --git a/dipy/segment/tests/test_mask.py b/dipy/segment/tests/test_mask.py index c605f04146..1b2e667651 100644 --- a/dipy/segment/tests/test_mask.py +++ b/dipy/segment/tests/test_mask.py @@ -11,7 +11,7 @@ from numpy.testing import (assert_equal, assert_almost_equal, run_module_suite) -from dipy.data import get_data +from dipy.data import get_fnames def test_mask(): @@ -83,7 +83,7 @@ def test_bounding_box(): def test_median_otsu(): - fname = get_data('S0_10') + fname = get_fnames('S0_10') img = nib.load(fname) data = img.get_data() data = np.squeeze(data.astype('f8')) diff --git a/dipy/segment/tests/test_mrf.py b/dipy/segment/tests/test_mrf.py index a91c1d5e11..c4928a2691 100644 --- a/dipy/segment/tests/test_mrf.py +++ b/dipy/segment/tests/test_mrf.py @@ -1,6 +1,6 @@ import numpy as np import numpy.testing as npt -from dipy.data import get_data +from dipy.data import get_fnames from dipy.sims.voxel import add_noise from dipy.segment.mrf import (ConstantObservationModel, IteratedConditionalModes) @@ -8,7 +8,7 @@ # Load a coronal slice from a T1-weighted MRI -fname = get_data('t1_coronal_slice') +fname = get_fnames('t1_coronal_slice') single_slice = np.load(fname) # Stack a few copies to form a 3D volume diff --git a/dipy/segment/tests/test_qb.py b/dipy/segment/tests/test_qb.py index e906eac597..edadbccf7e 100644 --- a/dipy/segment/tests/test_qb.py +++ b/dipy/segment/tests/test_qb.py @@ -1,11 +1,11 @@ import nibabel as nib from nose.tools import assert_equal -from dipy.data import get_data +from dipy.data import get_fnames from dipy.segment.quickbundles import QuickBundles def test_qbundles(): - streams, hdr = nib.trackvis.read(get_data('fornix')) + streams, hdr = nib.trackvis.read(get_fnames('fornix')) T = [s[0] for s in streams] qb = QuickBundles(T, 10., 12) qb.virtuals() diff --git a/dipy/segment/tests/test_rb.py b/dipy/segment/tests/test_rb.py deleted file mode 100644 index 8010f714af..0000000000 --- a/dipy/segment/tests/test_rb.py +++ /dev/null @@ -1,133 +0,0 @@ -import numpy as np -import nibabel as nib -from 
numpy.testing import assert_equal, run_module_suite -from dipy.data import get_data -from dipy.segment.bundles import RecoBundles -from dipy.tracking.distances import bundles_distances_mam -from dipy.tracking.streamline import Streamlines -from dipy.segment.clustering import qbx_and_merge - - -streams, hdr = nib.trackvis.read(get_data('fornix')) -fornix = [s[0] for s in streams] - -f = Streamlines(fornix) -f1 = f.copy() - -f2 = f1[:20].copy() -f2._data += np.array([50, 0, 0]) - -f3 = f1[200:].copy() -f3._data += np.array([100, 0, 0]) - -f.extend(f2) -f.extend(f3) - - -def test_rb_check_defaults(): - - rb = RecoBundles(f, clust_thr=10) - rec_trans, rec_labels, recognized = rb.recognize(model_bundle=f2, - model_clust_thr=5., - reduction_thr=10) - D = bundles_distances_mam(f2, recognized) - - # check if the bundle is recognized correctly - for row in D: - assert_equal(row.min(), 0) - - -def test_rb_disable_slr(): - - rb = RecoBundles(f, clust_thr=10) - - rec_trans, rec_labels, recognized = rb.recognize(model_bundle=f2, - model_clust_thr=5., - reduction_thr=10, - slr=False) - - D = bundles_distances_mam(f2, recognized) - - # check if the bundle is recognized correctly - for row in D: - assert_equal(row.min(), 0) - - -def test_rb_no_verbose_and_mam(): - - rb = RecoBundles(f, clust_thr=10, verbose=False) - - rec_trans, rec_labels, recognized = rb.recognize(model_bundle=f2, - model_clust_thr=5., - reduction_thr=10, - slr=True, - pruning_distance='mam') - - D = bundles_distances_mam(f2, recognized) - - # check if the bundle is recognized correctly - for row in D: - assert_equal(row.min(), 0) - - -def test_rb_clustermap(): - - cluster_map = qbx_and_merge(f, thresholds=[40, 25, 20, 10]) - - rb = RecoBundles(f, cluster_map=cluster_map, clust_thr=10) - rec_trans, rec_labels, recognized = rb.recognize(model_bundle=f2, - model_clust_thr=5., - reduction_thr=10) - D = bundles_distances_mam(f2, recognized) - - # check if the bundle is recognized correctly - for row in D: - assert_equal(row.min(), 0) - - -def test_rb_no_neighb(): - # what if no neighbors are found? 
No recognition - - b = Streamlines(fornix) - b1 = b.copy() - - b2 = b1[:20].copy() - b2._data += np.array([100, 0, 0]) - - b3 = b1[:20].copy() - b3._data += np.array([300, 0, 0]) - - b.extend(b3) - - rb = RecoBundles(b, clust_thr=10) - rec_trans, rec_labels, recognized = rb.recognize(model_bundle=b2, - model_clust_thr=5., - reduction_thr=10) - - assert_equal(len(recognized), 0) - assert_equal(len(rec_labels), 0) - assert_equal(len(rec_trans), 0) - - -def test_rb_reduction_mam(): - - rb = RecoBundles(f, clust_thr=10, verbose=True) - - rec_trans, rec_labels, recognized = rb.recognize(model_bundle=f2, - model_clust_thr=5., - reduction_thr=10, - reduction_distance='mam', - slr=True, - slr_metric='asymmetric', - pruning_distance='mam') - - D = bundles_distances_mam(f2, recognized) - - # check if the bundle is recognized correctly - for row in D: - assert_equal(row.min(), 0) - - -if __name__ == '__main__': - - run_module_suite() diff --git a/dipy/segment/tests/test_refine_rb.py b/dipy/segment/tests/test_refine_rb.py new file mode 100644 index 0000000000..116be429cb --- /dev/null +++ b/dipy/segment/tests/test_refine_rb.py @@ -0,0 +1,207 @@ +import numpy as np +import nibabel as nib +from numpy.testing import assert_equal, run_module_suite +from dipy.data import get_fnames +from dipy.segment.bundles import RecoBundles +from dipy.tracking.distances import bundles_distances_mam +from dipy.tracking.streamline import Streamlines +from dipy.segment.clustering import qbx_and_merge + + +streams, hdr = nib.trackvis.read(get_fnames('fornix')) +fornix = [s[0] for s in streams] + +f = Streamlines(fornix) +f1 = f.copy() + +f2 = f1[:20].copy() +f2._data += np.array([50, 0, 0]) + +f3 = f1[200:].copy() +f3._data += np.array([100, 0, 0]) + +f.extend(f2) +f.extend(f3) + + +def test_rb_check_defaults(): + + rb = RecoBundles(f, greater_than=0, clust_thr=10) + + rec_trans, rec_labels = rb.recognize(model_bundle=f2, + model_clust_thr=5., + reduction_thr=10) + + D = bundles_distances_mam(f2, f[rec_labels]) + + # check if the bundle is recognized correctly + if len(f2) == len(rec_labels): + for row in D: + assert_equal(row.min(), 0) + + refine_trans, refine_labels = rb.refine(model_bundle=f2, + pruned_streamlines=rec_trans, + model_clust_thr=5., + reduction_thr=10) + + D = bundles_distances_mam(f2, f[refine_labels]) + + # check if the bundle is recognized correctly + for row in D: + assert_equal(row.min(), 0) + + +def test_rb_disable_slr(): + + rb = RecoBundles(f, greater_than=0, clust_thr=10) + + rec_trans, rec_labels = rb.recognize(model_bundle=f2, + model_clust_thr=5., + reduction_thr=10, + slr=False) + + D = bundles_distances_mam(f2, f[rec_labels]) + + # check if the bundle is recognized correctly + if len(f2) == len(rec_labels): + for row in D: + assert_equal(row.min(), 0) + + refine_trans, refine_labels = rb.refine(model_bundle=f2, + pruned_streamlines=rec_trans, + model_clust_thr=5., + reduction_thr=10) + + D = bundles_distances_mam(f2, f[refine_labels]) + + # check if the bundle is recognized correctly + for row in D: + assert_equal(row.min(), 0) + + +def test_rb_no_verbose_and_mam(): + + rb = RecoBundles(f, greater_than=0, clust_thr=10, verbose=False) + + rec_trans, rec_labels = rb.recognize(model_bundle=f2, + model_clust_thr=5., + reduction_thr=10, + slr=True, + pruning_distance='mam') + + D = bundles_distances_mam(f2, f[rec_labels]) + + # check if the bundle is recognized correctly + if len(f2) == len(rec_labels): + for row in D: + assert_equal(row.min(), 0) + + refine_trans, refine_labels = 
rb.refine(model_bundle=f2, + pruned_streamlines=rec_trans, + model_clust_thr=5., + reduction_thr=10) + + D = bundles_distances_mam(f2, f[refine_labels]) + + # check if the bundle is recognized correctly + for row in D: + assert_equal(row.min(), 0) + + +def test_rb_clustermap(): + + cluster_map = qbx_and_merge(f, thresholds=[40, 25, 20, 10]) + + rb = RecoBundles(f, greater_than=0, less_than=1000000, + cluster_map=cluster_map, clust_thr=10) + rec_trans, rec_labels = rb.recognize(model_bundle=f2, + model_clust_thr=5., + reduction_thr=10) + + D = bundles_distances_mam(f2, f[rec_labels]) + + # check if the bundle is recognized correctly + if len(f2) == len(rec_labels): + for row in D: + assert_equal(row.min(), 0) + + refine_trans, refine_labels = rb.refine(model_bundle=f2, + pruned_streamlines=rec_trans, + model_clust_thr=5., + reduction_thr=10) + + D = bundles_distances_mam(f2, f[refine_labels]) + + # check if the bundle is recognized correctly + for row in D: + assert_equal(row.min(), 0) + + +def test_rb_no_neighb(): + # what if no neighbors are found? No recognition + + b = Streamlines(fornix) + b1 = b.copy() + + b2 = b1[:20].copy() + b2._data += np.array([100, 0, 0]) + + b3 = b1[:20].copy() + b3._data += np.array([300, 0, 0]) + + b.extend(b3) + + rb = RecoBundles(b, greater_than=0, clust_thr=10) + + rec_trans, rec_labels = rb.recognize(model_bundle=b2, + model_clust_thr=5., + reduction_thr=10) + + if len(rec_trans) > 0: + refine_trans, refine_labels = rb.refine(model_bundle=b2, + pruned_streamlines=rec_trans, + model_clust_thr=5., + reduction_thr=10) + + assert_equal(len(refine_labels), 0) + assert_equal(len(refine_trans), 0) + + else: + assert_equal(len(rec_labels), 0) + assert_equal(len(rec_trans), 0) + + +def test_rb_reduction_mam(): + + rb = RecoBundles(f, greater_than=0, clust_thr=10, verbose=True) + + rec_trans, rec_labels = rb.recognize(model_bundle=f2, + model_clust_thr=5., + reduction_thr=10, + reduction_distance='mam', + slr=True, + slr_metric='asymmetric', + pruning_distance='mam') + + D = bundles_distances_mam(f2, f[rec_labels]) + + # check if the bundle is recognized correctly + if len(f2) == len(rec_labels): + for row in D: + assert_equal(row.min(), 0) + + refine_trans, refine_labels = rb.refine(model_bundle=f2, + pruned_streamlines=rec_trans, + model_clust_thr=5., + reduction_thr=10) + + D = bundles_distances_mam(f2, f[refine_labels]) + + # check if the bundle is recognized correctly + for row in D: + assert_equal(row.min(), 0) + + +if __name__ == '__main__': + + run_module_suite() diff --git a/dipy/segment/tissue.py b/dipy/segment/tissue.py index 2efb8fb81d..b283152c0c 100644 --- a/dipy/segment/tissue.py +++ b/dipy/segment/tissue.py @@ -117,6 +117,8 @@ def classify(self, image, nclasses, beta, tolerance=None, max_iter=None): else: max_iter = 100 + if tolerance is None: + tolerance = 1e-05 for i in range(max_iter): if self.verbose: @@ -143,8 +145,6 @@ def classify(self, image, nclasses, beta, tolerance=None, max_iter=None): self.energies.append(energy) self.energies_sum.append(energy[energy > -np.inf].sum()) - if tolerance is None: - tolerance = 1e-05 if i % 10 == 0 and i != 0: diff --git a/dipy/sims/phantom.py b/dipy/sims/phantom.py index 91743c8f67..a46e47a284 100644 --- a/dipy/sims/phantom.py +++ b/dipy/sims/phantom.py @@ -4,7 +4,7 @@ from dipy.sims.voxel import SingleTensor, diffusion_evals import dipy.sims.voxel as vox from dipy.core.geometry import vec2vec_rotmat -from dipy.data import get_data +from dipy.data import get_fnames from dipy.core.gradients import 
gradient_table @@ -144,7 +144,7 @@ def orbital_phantom(gtab=None, """ if gtab is None: - fimg, fbvals, fbvecs = get_data('small_64D') + fimg, fbvals, fbvecs = get_fnames('small_64D') gtab = gradient_table(fbvals, fbvecs) if func is None: diff --git a/dipy/sims/tests/test_phantom.py b/dipy/sims/tests/test_phantom.py index 5399d737b1..6a088ec101 100644 --- a/dipy/sims/tests/test_phantom.py +++ b/dipy/sims/tests/test_phantom.py @@ -4,15 +4,15 @@ from numpy.testing import (assert_, assert_array_almost_equal, run_module_suite) -from dipy.data import get_data +from dipy.data import get_fnames from dipy.reconst.dti import TensorModel from dipy.sims.phantom import orbital_phantom from dipy.core.gradients import gradient_table +from dipy.io.gradients import read_bvals_bvecs -fimg, fbvals, fbvecs = get_data('small_64D') -bvals = np.load(fbvals) -bvecs = np.load(fbvecs) +fimg, fbvals, fbvecs = get_fnames('small_64D') +bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) bvecs[np.isnan(bvecs)] = 0 gtab = gradient_table(bvals, bvecs) diff --git a/dipy/sims/tests/test_voxel.py b/dipy/sims/tests/test_voxel.py index 6fa463dea4..5eddd2c0e4 100644 --- a/dipy/sims/tests/test_voxel.py +++ b/dipy/sims/tests/test_voxel.py @@ -9,12 +9,12 @@ sticks_and_ball, multi_tensor_dki, kurtosis_element, dki_signal) # from dipy.core.geometry import vec2vec_rotmat -from dipy.data import get_data, get_sphere +from dipy.data import get_fnames, get_sphere from dipy.core.gradients import gradient_table from dipy.io.gradients import read_bvals_bvecs -fimg, fbvals, fbvecs = get_data('small_64D') +fimg, fbvals, fbvecs = get_fnames('small_64D') bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) @@ -114,7 +114,7 @@ def test_multi_tensor(): # assert_(odf.shape == (len(vertices),)) # assert_(np.all(odf <= 1) & np.all(odf >= 0)) - fimg, fbvals, fbvecs = get_data('small_101D') + fimg, fbvals, fbvecs = get_fnames('small_101D') bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) gtab = gradient_table(bvals, bvecs) diff --git a/dipy/sims/voxel.py b/dipy/sims/voxel.py index e4a5405640..61c0219820 100644 --- a/dipy/sims/voxel.py +++ b/dipy/sims/voxel.py @@ -399,10 +399,10 @@ def multi_tensor(gtab, mevals, S0=1., angles=[(0, 0), (90, 0)], -------- >>> import numpy as np >>> from dipy.sims.voxel import multi_tensor - >>> from dipy.data import get_data + >>> from dipy.data import get_fnames >>> from dipy.core.gradients import gradient_table >>> from dipy.io.gradients import read_bvals_bvecs - >>> fimg, fbvals, fbvecs = get_data('small_101D') + >>> fimg, fbvals, fbvecs = get_fnames('small_101D') >>> bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) >>> gtab = gradient_table(bvals, bvecs) >>> mevals=np.array(([0.0015, 0.0003, 0.0003],[0.0015, 0.0003, 0.0003])) @@ -471,10 +471,10 @@ def multi_tensor_dki(gtab, mevals, S0=1., angles=[(90., 0.), (90., 0.)], -------- >>> import numpy as np >>> from dipy.sims.voxel import multi_tensor_dki - >>> from dipy.data import get_data + >>> from dipy.data import get_fnames >>> from dipy.core.gradients import gradient_table >>> from dipy.io.gradients import read_bvals_bvecs - >>> fimg, fbvals, fbvecs = get_data('small_64D') + >>> fimg, fbvals, fbvecs = get_fnames('small_64D') >>> bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) >>> bvals_2s = np.concatenate((bvals, bvals * 2), axis=0) >>> bvecs_2s = np.concatenate((bvecs, bvecs), axis=0) diff --git a/dipy/testing/__init__.py b/dipy/testing/__init__.py index 5269a19604..b5a8184481 100644 --- a/dipy/testing/__init__.py +++ 
b/dipy/testing/__init__.py @@ -4,6 +4,7 @@ from dipy.testing.decorators import doctest_skip_parser from numpy.testing import assert_array_equal import numpy as np +import scipy from distutils.version import LooseVersion # set path to example data @@ -24,6 +25,7 @@ def assert_arrays_equal(arrays1, arrays2): for arr1, arr2 in zip(arrays1, arrays2): assert_array_equal(arr1, arr2) + def setup_test(): """ Set numpy print options to "legacy" for new versions of numpy @@ -37,3 +39,12 @@ def setup_test(): """ if LooseVersion(np.__version__) >= LooseVersion('1.14'): np.set_printoptions(legacy='1.13') + + # Temporary fix until scipy release in October 2018 + # must be removed after that + # print the first occurrence of matching warnings for each location + # (module + line number) where the warning is issued + if LooseVersion(np.__version__) >= LooseVersion('1.15') and \ + LooseVersion(scipy.version.short_version) <= '1.1.0': + import warnings + warnings.simplefilter("default") diff --git a/dipy/tests/test_scripts.py b/dipy/tests/test_scripts.py index f3ed7769ee..44cc104861 100644 --- a/dipy/tests/test_scripts.py +++ b/dipy/tests/test_scripts.py @@ -19,7 +19,7 @@ import nibabel as nib from nibabel.tmpdirs import InTemporaryDirectory -from dipy.data import get_data +from dipy.data import get_fnames # Quickbundles command-line requires matplotlib: try: @@ -68,7 +68,7 @@ def assert_image_shape_affine(filename, shape, affine): def test_dipy_fit_tensor_again(): with InTemporaryDirectory(): - dwi, bval, bvec = get_data("small_25") + dwi, bval, bvec = get_fnames("small_25") # Copy data to tmp directory shutil.copyfile(dwi, "small_25.nii.gz") shutil.copyfile(bval, "small_25.bval") @@ -90,7 +90,7 @@ def test_dipy_fit_tensor_again(): assert_image_shape_affine("small_25_rd.nii.gz", shape, affine) with InTemporaryDirectory(): - dwi, bval, bvec = get_data("small_25") + dwi, bval, bvec = get_fnames("small_25") # Copy data to tmp directory shutil.copyfile(dwi, "small_25.nii.gz") shutil.copyfile(bval, "small_25.bval") @@ -121,7 +121,7 @@ def test_dipy_fit_tensor_again(): @nt.dec.skipif(no_mpl) def test_qb_commandline(): with InTemporaryDirectory(): - tracks_file = get_data('fornix') + tracks_file = get_fnames('fornix') cmd = ["dipy_quickbundles", tracks_file, '--pkl_file', 'mypickle.pkl', '--out_file', 'tracks300.trk'] out = run_command(cmd) @@ -135,7 +135,7 @@ def test_qb_commandline_output_path_handling(): os.mkdir('output') os.chdir('work') - tracks_file = get_data('fornix') + tracks_file = get_fnames('fornix') # Need to specify an output directory with a "../" style path # to trigger old bug. 
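Side note on the `get_data` -> `get_fnames` rename that runs through the hunks above and below: downstream callers can be kept working during the transition with a thin deprecation alias. A minimal sketch, not part of this diff; it only assumes that `get_fnames` is importable from `dipy.data`::

    import warnings

    from dipy.data import get_fnames


    def get_data(*args, **kwargs):
        # Deprecated alias: forward to get_fnames and warn the caller.
        warnings.warn("get_data is deprecated; use get_fnames instead",
                      DeprecationWarning, stacklevel=2)
        return get_fnames(*args, **kwargs)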
diff --git a/dipy/tracking/benchmarks/bench_streamline.py b/dipy/tracking/benchmarks/bench_streamline.py index 15db139832..3e735fef5f 100644 --- a/dipy/tracking/benchmarks/bench_streamline.py +++ b/dipy/tracking/benchmarks/bench_streamline.py @@ -17,7 +17,7 @@ from numpy.testing import measure from numpy.testing import assert_array_equal, assert_array_almost_equal -from dipy.data import get_data +from dipy.data import get_fnames from nibabel import trackvis as tv from dipy.tracking.streamline import (set_number_of_points, @@ -109,7 +109,7 @@ def bench_length(): def bench_compress_streamlines(): repeat = 10 - fname = get_data('fornix') + fname = get_fnames('fornix') streams, hdr = tv.read(fname) streamlines = [i[0] for i in streams] @@ -119,7 +119,7 @@ def bench_compress_streamlines(): print("Cython time: {0:.3}sec".format(cython_time)) del streamlines - fname = get_data('fornix') + fname = get_fnames('fornix') streams, hdr = tv.read(fname) streamlines = [i[0] for i in streams] python_time = measure("map(compress_streamlines_python, streamlines)", diff --git a/dipy/tracking/distances.pyx b/dipy/tracking/distances.pyx index 615d825fc6..65af6d8b48 100644 --- a/dipy/tracking/distances.pyx +++ b/dipy/tracking/distances.pyx @@ -1513,15 +1513,15 @@ def local_skeleton_clustering(tracks, d_thr=10): Visualization: It is possible to visualize the clustering C from the example - above using the fvtk module:: + above using the dipy.viz module:: - from dipy.viz import fvtk - r=fvtk.ren() + from dipy.viz import window, actor + r=window.Renderer() for c in C: color=np.random.rand(3) for i in C[c]['indices']: - fvtk.add(r,fvtk.line(tracks[i],color)) - fvtk.show(r) + r.add(actor.line(tracks[i],color)) + window.show(r) See Also -------- @@ -1816,18 +1816,18 @@ def larch_3split(tracks, indices=None, thr=10.): Here is an example of how to visualize the clustering above:: - from dipy.viz import fvtk - r=fvtk.ren() - fvtk.add(r,fvtk.line(tracks,fvtk.red)) - fvtk.show(r) + from dipy.viz import window, actor + r=window.Renderer() + r.add(actor.line(tracks,(1,0,0))) + window.show(r) for c in C: color=np.random.rand(3) for i in C[c]['indices']: - fos.add(r,fvtk.line(tracks[i],color)) - fvtk.show(r) + r.add(actor.line(tracks[i],color)) + window.show(r) for c in C: - fvtk.add(r,fos.line(C[c]['rep3']/C[c]['N'],fos.white)) - fvtk.show(r) + r.add(actor.line(C[c]['rep3']/C[c]['N'],(1,1,1))) + window.show(r) ''' cdef: diff --git a/dipy/tracking/eudx.py b/dipy/tracking/eudx.py index 3ced1e9db2..a9edfcfe6d 100644 --- a/dipy/tracking/eudx.py +++ b/dipy/tracking/eudx.py @@ -111,9 +111,9 @@ def __init__(self, a, ind, -------- >>> import nibabel as nib >>> from dipy.reconst.dti import TensorModel, quantize_evecs - >>> from dipy.data import get_data, get_sphere + >>> from dipy.data import get_fnames, get_sphere >>> from dipy.core.gradients import gradient_table - >>> fimg,fbvals,fbvecs = get_data('small_101D') + >>> fimg,fbvals,fbvecs = get_fnames('small_101D') >>> img = nib.load(fimg) >>> affine = img.affine >>> data = img.get_data() diff --git a/dipy/tracking/life.py b/dipy/tracking/life.py index ebcb5eac7e..ea4acfec9e 100644 --- a/dipy/tracking/life.py +++ b/dipy/tracking/life.py @@ -75,18 +75,19 @@ def gradient(f): slice1[axis] = slice(1, -1) slice2[axis] = slice(2, None) slice3[axis] = slice(None, -2) + # 1D equivalent -- out[1:-1] = (f[2:] - f[:-2])/2.0 - out[slice1] = (f[slice2] - f[slice3])/2.0 + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)])/2.0 slice1[axis] = 0 slice2[axis] = 1 slice3[axis] = 0 # 1D
equivalent -- out[0] = (f[1] - f[0]) - out[slice1] = (f[slice2] - f[slice3]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) slice1[axis] = -1 slice2[axis] = -1 slice3[axis] = -2 # 1D equivalent -- out[-1] = (f[-1] - f[-2]) - out[slice1] = (f[slice2] - f[slice3]) + out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) # divide by step size outvals.append(out / dx[axis]) @@ -135,7 +136,7 @@ def grad_tensor(grad, evals): """ # This is the rotation matrix from [1, 0, 0] to this gradient of the sl: - R = la.svd(np.matrix(grad), overwrite_a=True)[2] + R = la.svd([grad], overwrite_a=True)[2] # This is the 3 by 3 tensor after rotation: T = np.dot(np.dot(R, np.diag(evals)), R.T) return T diff --git a/dipy/tracking/local/localtracking.py b/dipy/tracking/local/localtracking.py index ee3aa877c2..92dfab0b84 100644 --- a/dipy/tracking/local/localtracking.py +++ b/dipy/tracking/local/localtracking.py @@ -1,3 +1,5 @@ +import random + import numpy as np from dipy.tracking.local.localtrack import local_tracker, pft_tracker @@ -36,7 +38,7 @@ def _get_voxel_size(affine): def __init__(self, direction_getter, tissue_classifier, seeds, affine, step_size, max_cross=None, maxlen=500, fixedstep=True, - return_all=True): + return_all=True, random_seed=None): """Creates streamlines by using local fiber-tracking. Parameters ---------- @@ -69,6 +71,9 @@ def __init__(self, direction_getter, tissue_classifier, seeds, affine, return_all : bool If true, return all generated streamlines, otherwise only streamlines reaching end points or exiting the image. + random_seed : int + The seed for the random number generator (numpy.random.seed and + random.seed). """ self.direction_getter = direction_getter @@ -88,6 +93,7 @@ def __init__(self, direction_getter, tissue_classifier, seeds, affine, self.max_cross = max_cross self.max_length = maxlen self.return_all = return_all + self.random_seed = random_seed def _tracker(self, seed, first_step, streamline): return local_tracker(self.direction_getter, @@ -116,6 +122,12 @@ def _generate_streamlines(self): B = F.copy() for s in self.seeds: s = np.dot(lin, s) + offset + # Set the random seed in numpy and random + if self.random_seed is not None: + s_random_seed = hash(np.abs((np.sum(s)) + self.random_seed)) \ % (2**32 - 1) + random.seed(s_random_seed) + np.random.seed(s_random_seed) directions = self.direction_getter.initial_direction(s) if directions.size == 0 and self.return_all: # only the seed position @@ -146,7 +158,8 @@ class ParticleFilteringTracking(LocalTracking): def __init__(self, direction_getter, tissue_classifier, seeds, affine, step_size, max_cross=None, maxlen=500, pft_back_tracking_dist=2, pft_front_tracking_dist=1, - pft_max_trial=20, particle_count=15, return_all=True): + pft_max_trial=20, particle_count=15, return_all=True, + random_seed=None): r"""A streamline generator using the particle filtering tractography method [1]_. Parameters ---------- @@ -192,6 +205,9 @@ def __init__(self, direction_getter, tissue_classifier, seeds, affine, return_all : bool If true, return all generated streamlines, otherwise only streamlines reaching end points or exiting the image. + random_seed : int + The seed for the random number generator (numpy.random.seed and + random.seed).
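The `random_seed` parameter documented above makes tracking reproducible by deriving a fresh 32-bit seed from every seed point, so each streamline's randomness does not depend on how many streamlines were generated before it. A standalone sketch of that hashing scheme, mirroring the `_generate_streamlines` hunk above (the helper name is illustrative, not part of the API)::

    import random

    import numpy as np


    def seed_rngs_for_point(seed_point, random_seed):
        # Fold the seed-point coordinates and the global seed into the
        # 32-bit range accepted by numpy.random.seed.
        s_random_seed = hash(np.abs(np.sum(seed_point) + random_seed)) \
            % (2**32 - 1)
        random.seed(s_random_seed)
        np.random.seed(s_random_seed)

    # Re-seeding with the same point and global seed reproduces the stream.
    seed_rngs_for_point(np.array([1., 2., 3.]), random_seed=0)
    a = np.random.random(3)
    seed_rngs_for_point(np.array([1., 2., 3.]), random_seed=0)
    assert np.allclose(a, np.random.random(3))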
References ---------- @@ -239,7 +255,8 @@ def __init__(self, direction_getter, tissue_classifier, seeds, affine, max_cross, maxlen, True, - return_all) + return_all, + random_seed) def _tracker(self, seed, first_step, streamline): return pft_tracker(self.direction_getter, diff --git a/dipy/tracking/local/tests/test_tracking.py b/dipy/tracking/local/tests/test_tracking.py index 4f205b7549..e2c9b711ad 100644 --- a/dipy/tracking/local/tests/test_tracking.py +++ b/dipy/tracking/local/tests/test_tracking.py @@ -5,7 +5,7 @@ from dipy.core.gradients import gradient_table from dipy.core.sphere import HemiSphere, unit_octahedron -from dipy.data import get_data, get_sphere +from dipy.data import get_fnames, get_sphere from dipy.direction import (BootDirectionGetter, ClosestPeakDirectionGetter, DeterministicMaximumDirectionGetter, @@ -213,9 +213,10 @@ def allclose(x, y): for sl in streamlines: npt.assert_(np.allclose(sl, expected[1])) - # The first path is not possible if pmf_threshold > 0.4 - dg = ProbabilisticDirectionGetter.from_pmf(pmf, 90, sphere, - pmf_threshold=0.5) + # The first path is not possible if pmf_threshold > 0.67 + # 0.4/0.6 < 2/3, multiplying the pmf should not change the ratio + dg = ProbabilisticDirectionGetter.from_pmf(10*pmf, 90, sphere, + pmf_threshold=0.67) streamlines = LocalTracking(dg, tc, seeds, np.eye(4), 1.) for sl in streamlines: @@ -239,8 +240,18 @@ def allclose(x, y): npt.assert_(np.all((s + 0.5).astype(int) < mask.shape)) # Test that the number of streamlines returned with return_all=True equals # the number of seed places + npt.assert_(np.array([len(streamlines) == len(seeds)])) + # Test reproducibility + tracking_1 = Streamlines(LocalTracking(dg, tc, seeds, np.eye(4), + 0.5, + random_seed=0)).data + tracking_2 = Streamlines(LocalTracking(dg, tc, seeds, np.eye(4), + 0.5, + random_seed=0)).data + npt.assert_equal(tracking_1, tracking_2) + def test_particle_filtering_tractography(): """This tests that the ParticleFilteringTracking produces @@ -374,6 +385,15 @@ def test_particle_filtering_tractography(): lambda: ParticleFilteringTracking(dg, tc, seeds, np.eye(4), step_size, particle_count=-1)) + # Test reproducibility + tracking_1 = Streamlines(ParticleFilteringTracking(dg, tc, seeds, np.eye(4), + step_size, + random_seed=0)).data + tracking_2 = Streamlines(ParticleFilteringTracking(dg, tc, seeds, np.eye(4), + step_size, + random_seed=0)).data + npt.assert_equal(tracking_1, tracking_2) + def test_maximum_deterministic_tracker(): """This tests that the Maximum Deterministic Direction Getter plays nice @@ -438,11 +458,11 @@ def allclose(x, y): npt.assert_(np.allclose(sl, expected[1])) # Both paths are not possible if 90 degree turns are excluded and - # if pmf_threhold is larger than 0.4. Streamlines should stop at - # the crossing - - dg = DeterministicMaximumDirectionGetter.from_pmf(pmf, 80, sphere, - pmf_threshold=0.5) + # if pmf_threshold is larger than 0.67. Streamlines should stop at + # the crossing. + # 0.4/0.6 < 2/3, multiplying the pmf should not change the ratio + dg = DeterministicMaximumDirectionGetter.from_pmf(10*pmf, 80, sphere, + pmf_threshold=0.67) streamlines = LocalTracking(dg, tc, seeds, np.eye(4), 1.) for sl in streamlines: @@ -700,7 +720,7 @@ def test_affine_transformations(): # TST - in vivo affine example # Sometimes data have affines with tiny shear components.
# For example, the small_101D data-set has some of that: - fdata, _, _ = get_data('small_101D') + fdata, _, _ = get_fnames('small_101D') a6 = nib.load(fdata).affine for affine in [a0, a1, a2, a3, a4, a5, a6]: diff --git a/dipy/tracking/streamline.py b/dipy/tracking/streamline.py index 8c67e2c0f4..ebfcea1362 100644 --- a/dipy/tracking/streamline.py +++ b/dipy/tracking/streamline.py @@ -9,6 +9,7 @@ from nibabel.affines import apply_affine from dipy.tracking.streamlinespeed import set_number_of_points from dipy.tracking.streamlinespeed import length +from dipy.tracking.distances import bundles_distances_mdf from dipy.tracking.streamlinespeed import compress_streamlines import dipy.tracking.utils as ut from dipy.tracking.utils import streamline_near_roi @@ -29,7 +30,6 @@ MEGABYTE = 1024 * 1024 - class _BuildCache(object): def __init__(self, arr_seq, common_shape, dtype): self.offsets = list(arr_seq._offsets) @@ -37,8 +37,12 @@ def __init__(self, arr_seq, common_shape, dtype): self.next_offset = arr_seq._get_next_offset() self.bytes_per_buf = arr_seq._buffer_size * MEGABYTE # Use the passed dtype only if null data array - self.dtype = dtype if arr_seq._data.size == 0 else arr_seq._data.dtype - if arr_seq.common_shape != () and common_shape != arr_seq.common_shape: + if arr_seq._data.size == 0: + self.dtype = dtype + else: + self.dtype = arr_seq._data.dtype + if (arr_seq.common_shape != () and + common_shape != arr_seq.common_shape): raise ValueError( "All dimensions, except the first one, must match exactly") self.common_shape = common_shape @@ -50,34 +54,32 @@ def update_seq(self, arr_seq): arr_seq._offsets = np.array(self.offsets) arr_seq._lengths = np.array(self.lengths) - class Streamlines(ArraySequence): - def __init__(self, *args, **kwargs): super(Streamlines, self).__init__(*args, **kwargs) def append(self, element, cache_build=False): - """ Appends `element` to this array sequence. + """ + Appends `element` to this array sequence. + Append can be a lot faster if it knows that it is appending several - elements instead of a single element. In that case it can cache the - parameters it uses between append operations, in a "build cache". To - tell append to do this, use ``cache_build=True``. If you use - ``cache_build=True``, you need to finalize the append operations with - :meth:`finalize_append`. + elements instead of a single element. In that case it can cache + the parameters it uses between append operations, in a "build + cache". To tell append to do this, use ``cache_build=True``. If + you use ``cache_build=True``, you need to finalize the append + operations with :meth:`finalize_append`. + Parameters ---------- - element : ndarray - Element to append. The shape must match already inserted elements - shape except for the first dimension. - cache_build : {False, True} - Whether to save the build cache from this append routine. If True, - append can assume it is the only player updating `self`, and the - caller must finalize `self` after all append operations, with - ``self.finalize_append()``. - Returns + element : ndarray + Element to append. The shape must match already + inserted elements shape except for the first dimension. + cache_build : {False, True} + Whether to save the build cache + from this append routine. If True, append can assume it is the + only player updating `self`, and the caller must finalize + `self` after all append operations, with + ``self.finalize_append()``.
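The build cache described above pays off when many elements are appended in a row: the offsets and lengths are kept in plain Python lists between calls and written back once. A minimal usage sketch, assuming this patched `Streamlines` class (the new tests further below use the same idiom)::

    import numpy as np

    from dipy.tracking.streamline import Streamlines

    sls = Streamlines()
    for _ in range(100):
        # Reuse one build cache across all appends...
        sls.append(np.random.random((12, 3)), cache_build=True)
    # ...and write the cached offsets/lengths back exactly once.
    sls.finalize_append()
    assert len(sls) == 100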
Returns ------- - None - Notes + None + Notes ----- If you need to add multiple elements you should consider `ArraySequence.extend`. @@ -124,19 +126,20 @@ def extend(self, elements): """ Appends all `elements` to this array sequence. Parameters ---------- - elements : iterable of ndarrays or :class:`ArraySequence` object - If iterable of ndarrays, each ndarray will be concatenated along - the first dimension then appended to the data of this + elements : iterable of ndarrays or :class:`ArraySequence` instance + + If iterable of ndarrays, each ndarray will be concatenated + along the first dimension then appended to the data of this ArraySequence. - If :class:`ArraySequence` object, its data are simply appended to - the data of this ArraySequence. + If :class:`ArraySequence` object, its data are simply appended + to the data of this ArraySequence. + Returns ------- - None - Notes + None + Notes ----- - The shape of the elements to be added must match the one of the data of - this :class:`ArraySequence` except for the first dimension. + The shape of the elements to be added must match the one of the + data of this :class:`ArraySequence` except for the first dimension. """ # If possible try pre-allocating memory. try: @@ -178,10 +181,10 @@ def unlist_streamlines(streamlines): curr_pos = 0 for (i, s) in enumerate(streamlines): - prev_pos = curr_pos - curr_pos += s.shape[0] - points[prev_pos:curr_pos] = s - offsets[i] = curr_pos + prev_pos = curr_pos + curr_pos += s.shape[0] + points[prev_pos:curr_pos] = s + offsets[i] = curr_pos return points, offsets @@ -276,36 +279,51 @@ def deform_streamlines(streamlines, return new_streamlines -def transform_streamlines(streamlines, mat): +def transform_streamlines(streamlines, mat, in_place=False): """ Apply affine transformation to streamlines Parameters ---------- - streamlines : list - List of 2D ndarrays of shape[-1]==3 + streamlines : Streamlines + Streamlines object mat : array, (4, 4) transformation matrix + in_place : bool + If True then change data in place. + Be careful: this changes the input streamlines. Returns ------- - new_streamlines : list - List of the transformed 2D ndarrays of shape[-1]==3 + new_streamlines : Streamlines + Sequence of transformed 2D ndarrays of shape[-1]==3 """ + # using new Streamlines API + if isinstance(streamlines, Streamlines): + if in_place: + streamlines._data = apply_affine(mat, streamlines._data) + return streamlines + new_streamlines = streamlines.copy() + new_streamlines._data = apply_affine(mat, new_streamlines._data) + return new_streamlines + # supporting old data structure of streamlines return [apply_affine(mat, s) for s in streamlines] -def select_random_set_of_streamlines(streamlines, select): +def select_random_set_of_streamlines(streamlines, select, rng=None): """ Select a random set of streamlines Parameters ---------- - streamlines : list - List of 2D ndarrays of shape[-1]==3 + streamlines : Streamlines + Object of 2D ndarrays of shape[-1]==3 select : int Number of streamlines to select. If there are fewer streamlines than ``select`` then ``select=len(streamlines)``. + rng : RandomState + Random generator to use; if None (default), a new one is created. + Returns ------- selected_streamlines : list @@ -315,7 +333,11 @@ def select_random_set_of_streamlines(streamlines, select): The same streamline will not be selected twice.
""" len_s = len(streamlines) - index = np.random.choice(len_s, min(select, len_s), replace=False) + if rng is None: + rng = np.random.RandomState() + index = rng.choice(len_s, min(select, len_s), replace=False) + if isinstance(streamlines, Streamlines): + return streamlines[index] return [streamlines[i] for i in index] @@ -448,6 +470,83 @@ def select_by_rois(streamlines, rois, include, mode=None, affine=None, yield sl +def cluster_confidence(streamlines, max_mdf=5, subsample=12, power=1, + override=False): + """ Computes the cluster confidence index (cci), which is an + estimation of the support a set of streamlines gives to + a particular pathway. + + Ex: A single streamline with no others in the dataset + following a similar pathway has a low cci. A streamline + in a bundle of 100 streamlines that follow similar + pathways has a high cci. + + See: Jordan et al. 2017 + (Based on streamline MDF distance from Garyfallidis et al. 2012) + + Parameters + ---------- + streamlines : list of 2D (N, 3) arrays + A sequence of streamlines of length N (# streamlines) + max_mdf : int + The maximum MDF distance (mm) that will be considered a + "supporting" streamline and included in cci calculation + subsample: int + The number of points that are considered for each streamline + in the calculation. To save on calculation time, each + streamline is subsampled to subsampleN points. + power: int + The power to which the MDF distance for each streamline + will be raised to determine how much it contributes to + the cci. High values of power make the contribution value + degrade much faster. Example: a streamline with 5mm MDF + similarity contributes 1/5 to the cci if power is 1, but + only contributes 1/5^2 = 1/25 if power is 2. + override: bool, False by default + override means that the cci calculation will still occur even + though there are short streamlines in the dataset that may alter + expected behaviour. + + Returns + ------- + Returns an array of CCI scores + + References + ---------- + [Jordan17] Jordan K. Et al., Cluster Confidence Index: A Streamline-Wise + Pathway Reproducibility Metric for Diffusion-Weighted MRI Tractography, + Journal of Neuroimaging, vol 28, no 1, 2017. + + [Garyfallidis12] Garyfallidis E. et al., QuickBundles a method for + tractography simplification, Frontiers in Neuroscience, + vol 6, no 175, 2012. + + """ + + # error if any streamlines are shorter than 20mm + lengths = list(length(streamlines)) + if min(lengths) < 20 and not override: + raise ValueError('Short streamlines found. We recommend removing them.' + ' To continue without removing short streamlines set' + ' override=True') + + # calculate the pairwise MDF distance between all streamlines in dataset + subsamp_sls = set_number_of_points(streamlines, subsample) + + cci_score_mtrx = np.zeros([len(subsamp_sls)]) + + for i, sl in enumerate(subsamp_sls): + mdf_mx = bundles_distances_mdf([subsamp_sls[i]], subsamp_sls) + if (mdf_mx == 0).sum() > 1: + raise ValueError('Identical streamlines. 
CCI calculation invalid') + mdf_mx_oi = (mdf_mx > 0) & (mdf_mx < max_mdf) & ~ np.isnan(mdf_mx) + mdf_mx_oi_only = mdf_mx[mdf_mx_oi] + cci_score = np.sum(np.divide(1, np.power(mdf_mx_oi_only, power))) + cci_score_mtrx[i] = cci_score + + return cci_score_mtrx + + def _orient_generator(out, roi1, roi2): """ Helper function to `orient_by_rois` @@ -482,7 +581,9 @@ def _orient_list(out, roi1, roi2): min1 = np.argmin(dist1, 0) min2 = np.argmin(dist2, 0) if min1[0] > min2[0]: - out[idx] = sl[::-1] + out[idx][:, 0] = sl[::-1][:, 0] + out[idx][:, 1] = sl[::-1][:, 1] + out[idx][:, 2] = sl[::-1][:, 2] return out @@ -553,7 +654,7 @@ def orient_by_rois(streamlines, roi1, roi2, in_place=False, # If it's a generator on input, we may as well generate it # here and now: if isinstance(streamlines, types.GeneratorType): - out = list(streamlines) + out = Streamlines(streamlines) elif in_place: out = streamlines @@ -597,7 +698,8 @@ def _extract_vals(data, streamlines, affine=None, threedvec=False): """ data = data.astype(np.float) if (isinstance(streamlines, list) or - isinstance(streamlines, types.GeneratorType)): + isinstance(streamlines, types.GeneratorType) or + isinstance(streamlines, Streamlines)): if affine is not None: streamlines = ut.move_streamlines(streamlines, np.linalg.inv(affine)) @@ -605,11 +707,11 @@ def _extract_vals(data, streamlines, affine=None, threedvec=False): vals = [] for sl in streamlines: if threedvec: - vals.append(list(vfu.interpolate_vector_3d(data, - sl.astype(np.float))[0])) + vals.append(list(vfu.interpolate_vector_3d( + data, sl.astype(np.float))[0])) else: - vals.append(list(vfu.interpolate_scalar_3d(data, - sl.astype(np.float))[0])) + vals.append(list(vfu.interpolate_scalar_3d( + data, sl.astype(np.float))[0])) elif isinstance(streamlines, np.ndarray): sl_shape = streamlines.shape @@ -677,11 +779,11 @@ def values_from_volume(data, streamlines, affine=None): return _extract_vals(data, streamlines, affine=affine, threedvec=True) if isinstance(streamlines, types.GeneratorType): - streamlines = list(streamlines) + streamlines = Streamlines(streamlines) vals = [] for ii in range(data.shape[-1]): vals.append(_extract_vals(data[..., ii], streamlines, - affine=affine)) + affine=affine)) if isinstance(vals[-1], np.ndarray): return np.swapaxes(np.array(vals), 2, 1).T diff --git a/dipy/tracking/tests/test_distances.py b/dipy/tracking/tests/test_distances.py index f89893a1aa..8291dccddf 100644 --- a/dipy/tracking/tests/test_distances.py +++ b/dipy/tracking/tests/test_distances.py @@ -61,7 +61,7 @@ def test_LSCv2(): print(t2-t1) print(len(C5)) - from dipy.data import get_data + from dipy.data import get_fnames from nibabel import trackvis as tv try: from dipy.viz import window, actor @@ -69,7 +69,7 @@ def test_LSCv2(): raise nose.plugins.skip.SkipTest( 'Fails to import dipy.viz due to %s' % str(e)) - streams, hdr = tv.read(get_data('fornix')) + streams, hdr = tv.read(get_fnames('fornix')) T3 = [tm.downsample(s[0], 6) for s in streams] print('lenT3', len(T3)) diff --git a/dipy/tracking/tests/test_life.py b/dipy/tracking/tests/test_life.py index 1258cfa900..79528b13a7 100644 --- a/dipy/tracking/tests/test_life.py +++ b/dipy/tracking/tests/test_life.py @@ -18,6 +18,7 @@ import dipy.core.ndindex as nd import dipy.core.gradients as grad import dipy.reconst.dti as dti +from dipy.io.gradients import read_bvals_bvecs THIS_DIR = op.dirname(__file__) @@ -65,7 +66,7 @@ def test_streamline_tensors(): def test_streamline_signal(): - data_file, bval_file, bvec_file = dpd.get_data('small_64D') + 
data_file, bval_file, bvec_file = dpd.get_fnames('small_64D') gtab = dpg.gradient_table(bval_file, bvec_file) evals = [0.0015, 0.0005, 0.0005] streamline1 = [[[1, 2, 3], [4, 5, 3], [5, 6, 3], [6, 7, 3]], @@ -102,9 +103,9 @@ def test_voxel2streamline(): def test_FiberModel_init(): # Get some small amount of data: - data_file, bval_file, bvec_file = dpd.get_data('small_64D') + data_file, bval_file, bvec_file = dpd.get_fnames('small_64D') data_ni = nib.load(data_file) - bvals, bvecs = (np.load(f) for f in (bval_file, bvec_file)) + bvals, bvecs = read_bvals_bvecs(bval_file, bvec_file) gtab = dpg.gradient_table(bvals, bvecs) FM = life.FiberModel(gtab) @@ -124,10 +125,11 @@ def test_FiberModel_init(): def test_FiberFit(): - data_file, bval_file, bvec_file = dpd.get_data('small_64D') + data_file, bval_file, bvec_file = dpd.get_fnames('small_64D') data_ni = nib.load(data_file) data = data_ni.get_data() - bvals, bvecs = (np.load(f) for f in (bval_file, bvec_file)) + data_aff = data_ni.affine + bvals, bvecs = read_bvals_bvecs(bval_file, bvec_file) gtab = dpg.gradient_table(bvals, bvecs) FM = life.FiberModel(gtab) evals = [0.0015, 0.0005, 0.0005] @@ -161,7 +163,7 @@ def test_FiberFit(): fit.data) def test_fit_data(): - fdata, fbval, fbvec = dpd.get_data('small_25') + fdata, fbval, fbvec = dpd.get_fnames('small_25') gtab = grad.gradient_table(fbval, fbvec) ni_data = nib.load(fdata) data = ni_data.get_data() diff --git a/dipy/tracking/tests/test_localtrack.py b/dipy/tracking/tests/test_localtrack.py deleted file mode 100644 index 3f0c75918e..0000000000 --- a/dipy/tracking/tests/test_localtrack.py +++ /dev/null @@ -1,23 +0,0 @@ -import numpy as np -import numpy.testing as npt - -from dipy.tracking.local.tissue_classifier import ThresholdTissueClassifier -from dipy.data import default_sphere -from dipy.direction import peaks_from_model - -def test_ThresholdTissueClassifier(): - a = np.random.random((3, 5, 7)) - mid = np.sort(a.ravel())[(3 * 5 * 7) // 2] - - ttc = ThresholdTissueClassifier(a, mid) - for i in range(3): - for j in range(5): - for k in range(7): - tissue = ttc.check_point(np.array([i, j, k], dtype=float)) - if a[i, j, k] > mid: - npt.assert_equal(tissue, 1) - else: - npt.assert_equal(tissue, 2) - - - diff --git a/dipy/tracking/tests/test_metrics.py b/dipy/tracking/tests/test_metrics.py index 24149f0264..b5af60855e 100644 --- a/dipy/tracking/tests/test_metrics.py +++ b/dipy/tracking/tests/test_metrics.py @@ -145,10 +145,10 @@ def test_downsample(): assert_equal(np.sum(res), 0) """ - from dipy.data import get_data + from dipy.data import get_fnames from nibabel import trackvis as tv - streams, hdr = tv.read(get_data('fornix')) + streams, hdr = tv.read(get_fnames('fornix')) Td = [tm.downsample(s[0], pts) for s in streams] T = [s[0] for s in streams] diff --git a/dipy/tracking/tests/test_propagation.py b/dipy/tracking/tests/test_propagation.py index d01beb9f8c..efa44e530b 100644 --- a/dipy/tracking/tests/test_propagation.py +++ b/dipy/tracking/tests/test_propagation.py @@ -2,7 +2,7 @@ import numpy as np import numpy.testing -from dipy.data import get_data, get_sphere +from dipy.data import get_fnames, get_sphere from dipy.core.gradients import gradient_table from dipy.reconst.gqi import GeneralizedQSamplingModel from dipy.reconst.dti import TensorModel, quantize_evecs @@ -68,7 +68,7 @@ def test_eudx_further(): """ Cause we love testin.. 
;-) """ - fimg, fbvals, fbvecs = get_data('small_101D') + fimg, fbvals, fbvecs = get_fnames('small_101D') img = ni.load(fimg) data = img.get_data() @@ -123,7 +123,7 @@ def random_affine(seeds): def test_eudx_bad_seed(): """Test passing a bad seed to eudx""" - fimg, fbvals, fbvecs = get_data('small_101D') + fimg, fbvals, fbvecs = get_fnames('small_101D') img = ni.load(fimg) data = img.get_data() diff --git a/dipy/tracking/tests/test_streamline.py b/dipy/tracking/tests/test_streamline.py index bfaf2e5f05..8fc037fecc 100644 --- a/dipy/tracking/tests/test_streamline.py +++ b/dipy/tracking/tests/test_streamline.py @@ -25,7 +25,8 @@ select_by_rois, orient_by_rois, values_from_volume, - deform_streamlines) + deform_streamlines, + cluster_confidence) streamline = np.array([[82.20181274, 91.36505890, 43.15737152], @@ -317,8 +318,8 @@ def test_set_number_of_points(): len(streamlines_readonly)) # Test if nb_points is less than 2 - assert_raises(ValueError, set_number_of_points, [np.ones((10, 3)), - np.ones((10, 3))], nb_points=1) + assert_raises(ValueError, set_number_of_points, [ + np.ones((10, 3)), np.ones((10, 3))], nb_points=1) def test_set_number_of_points_memory_leaks(): @@ -720,9 +721,9 @@ def test_compress_streamlines(): # Make sure Cython and Python versions are the same. cstreamline_python = compress_streamlines_python( - special_streamline, - tol_error=tol_error+1e-4, - max_segment_length=np.inf) + special_streamline, + tol_error=tol_error+1e-4, + max_segment_length=np.inf) assert_equal(len(cspecial_streamline), len(cstreamline_python)) assert_array_almost_equal(cspecial_streamline, cstreamline_python) @@ -800,13 +801,13 @@ def test_select_by_rois(): tol=1) assert_arrays_equal(list(selection), [streamlines[0], - streamlines[1]]) + streamlines[1]]) selection = select_by_rois(streamlines, [mask1, mask2], [True, True], tol=1) assert_arrays_equal(list(selection), [streamlines[0], - streamlines[1]]) + streamlines[1]]) selection = select_by_rois(streamlines, [mask1, mask2], [True, False]) @@ -835,7 +836,7 @@ def test_select_by_rois(): selection = select_by_rois(streamlines, [mask1], [True], tol=1.0) assert_arrays_equal(list(selection), [streamlines[0], - streamlines[1]]) + streamlines[1]]) # Use different modes: selection = select_by_rois(streamlines, [mask1, mask2, mask3], @@ -869,16 +870,16 @@ def test_select_by_rois(): selection = select_by_rois(generate_sl(streamlines), [mask1], [True], tol=1.0) assert_arrays_equal(list(selection), [streamlines[0], - streamlines[1]]) + streamlines[1]]) def test_orient_by_rois(): - streamlines = [np.array([[0, 0., 0], - [1, 0., 0.], - [2, 0., 0.]]), - np.array([[2, 0., 0.], - [1, 0., 0], - [0, 0, 0.]])] + streamlines = Streamlines([np.array([[0, 0., 0], + [1, 0., 0.], + [2, 0., 0.]]), + np.array([[2, 0., 0.], + [1, 0., 0], + [0, 0, 0.]])]) # Make two ROIs: mask1_vol = np.zeros((4, 4, 4), dtype=bool) @@ -892,28 +893,29 @@ def test_orient_by_rois(): affine = np.eye(4) affine[:, 3] = [-1, 100, -20, 1] # Transform the streamlines: - x_streamlines = [sl + affine[:3, 3] for sl in streamlines] + x_streamlines = Streamlines([sl + affine[:3, 3] for sl in streamlines]) # After reorientation, this should be the answer: - flipped_sl = [streamlines[0], streamlines[1][::-1]] + flipped_sl = Streamlines([streamlines[0], streamlines[1][::-1]]) new_streamlines = orient_by_rois(streamlines, mask1_vol, mask2_vol, in_place=False, affine=None, as_generator=False) - npt.assert_equal(new_streamlines, flipped_sl) + npt.assert_array_equal(new_streamlines, flipped_sl) + 
npt.assert_(new_streamlines is not streamlines) # Test with affine: - x_flipped_sl = [s + affine[:3, 3] for s in flipped_sl] + x_flipped_sl = Streamlines([s + affine[:3, 3] for s in flipped_sl]) new_streamlines = orient_by_rois(x_streamlines, mask1_vol, mask2_vol, in_place=False, affine=affine, as_generator=False) - npt.assert_equal(new_streamlines, x_flipped_sl) + npt.assert_array_equal(new_streamlines, x_flipped_sl) npt.assert_(new_streamlines is not x_streamlines) # Test providing coord ROIs instead of vol ROIs: @@ -923,7 +925,7 @@ def test_orient_by_rois(): in_place=False, affine=affine, as_generator=False) - npt.assert_equal(new_streamlines, x_flipped_sl) + npt.assert_array_equal(new_streamlines, x_flipped_sl) # Test with as_generator set to True new_streamlines = orient_by_rois(streamlines, @@ -934,8 +936,8 @@ def test_orient_by_rois(): as_generator=True) npt.assert_(isinstance(new_streamlines, types.GeneratorType)) - ll = list(new_streamlines) - npt.assert_equal(ll, flipped_sl) + ll = Streamlines(new_streamlines) + npt.assert_array_equal(ll, flipped_sl) # Test with as_generator set to True and with the affine new_streamlines = orient_by_rois(x_streamlines, @@ -946,8 +948,8 @@ def test_orient_by_rois(): as_generator=True) npt.assert_(isinstance(new_streamlines, types.GeneratorType)) - ll = list(new_streamlines) - npt.assert_equal(ll, x_flipped_sl) + ll = Streamlines(new_streamlines) + npt.assert_array_equal(ll, x_flipped_sl) # Test with generator input: new_streamlines = orient_by_rois(generate_sl(streamlines), @@ -958,8 +960,8 @@ def test_orient_by_rois(): as_generator=True) npt.assert_(isinstance(new_streamlines, types.GeneratorType)) - ll = list(new_streamlines) - npt.assert_equal(ll, flipped_sl) + ll = Streamlines(new_streamlines) + npt.assert_array_equal(ll, flipped_sl) # Generator output cannot take a True `in_place` kwarg: npt.assert_raises(ValueError, orient_by_rois, *[generate_sl(streamlines), @@ -978,7 +980,7 @@ def test_orient_by_rois(): as_generator=False) npt.assert_(not isinstance(new_streamlines, types.GeneratorType)) - npt.assert_equal(new_streamlines, flipped_sl) + npt.assert_array_equal(new_streamlines, flipped_sl) # Modify in-place: new_streamlines = orient_by_rois(streamlines, @@ -988,7 +990,7 @@ def test_orient_by_rois(): affine=None, as_generator=False) - npt.assert_equal(new_streamlines, flipped_sl) + npt.assert_array_equal(new_streamlines, flipped_sl) # The two objects are one and the same: npt.assert_(new_streamlines is streamlines) @@ -1027,6 +1029,9 @@ def test_values_from_volume(): vv = values_from_volume(data, np.array(sl1)) npt.assert_almost_equal(vv, ans1, decimal=decimal) + vv = values_from_volume(data, Streamlines(sl1)) + npt.assert_almost_equal(vv, ans1, decimal=decimal) + affine = np.eye(4) affine[:, 3] = [-100, 10, 1, 1] x_sl1 = ut.move_streamlines(sl1, affine) @@ -1095,5 +1100,88 @@ def test_streamlines_generator(): npt.assert_equal(len(streamlines_generator), 0) +def test_cluster_confidence(): + mysl = np.array([np.arange(10)] * 3, 'float').T + + # a short streamline (<20 mm) should raise an error unless override=True + test_streamlines = Streamlines() + test_streamlines.append(mysl) + assert_raises(ValueError, cluster_confidence, test_streamlines) + cci = cluster_confidence(test_streamlines, override=True) + + # two identical streamlines should raise an error + test_streamlines = Streamlines() + test_streamlines.append(mysl, cache_build=True) + test_streamlines.append(mysl) + test_streamlines.finalize_append() + assert_raises(ValueError, 
cluster_confidence, test_streamlines) + + # 3 offset collinear streamlines + test_streamlines = Streamlines() + test_streamlines.append(mysl, cache_build=True) + test_streamlines.append(mysl+1) + test_streamlines.append(mysl+2) + test_streamlines.finalize_append() + cci = cluster_confidence(test_streamlines, override=True) + assert_equal(cci[0], cci[2]) + assert_true(cci[1] > cci[0]) + + # 3 parallel streamlines + mysl = np.zeros([10, 3]) + mysl[:, 0] = np.arange(10) + mysl2 = mysl.copy() + mysl2[:, 1] = 1 + mysl3 = mysl.copy() + mysl3[:, 1] = 2 + mysl4 = mysl.copy() + mysl4[:, 1] = 4 + mysl5 = mysl.copy() + mysl5[:, 1] = 5000 + + test_streamlines_p1 = Streamlines() + test_streamlines_p1.append(mysl, cache_build=True) + test_streamlines_p1.append(mysl2) + test_streamlines_p1.append(mysl3) + test_streamlines_p1.finalize_append() + test_streamlines_p2 = Streamlines() + test_streamlines_p2.append(mysl, cache_build=True) + test_streamlines_p2.append(mysl3) + test_streamlines_p2.append(mysl4) + test_streamlines_p2.finalize_append() + test_streamlines_p3 = Streamlines() + test_streamlines_p3.append(mysl, cache_build=True) + test_streamlines_p3.append(mysl2) + test_streamlines_p3.append(mysl3) + test_streamlines_p3.append(mysl5) + test_streamlines_p3.finalize_append() + + cci_p1 = cluster_confidence(test_streamlines_p1, override=True) + cci_p2 = cluster_confidence(test_streamlines_p2, override=True) + + # test relative distance + assert_array_equal(cci_p1, cci_p2*2) + + # test simple cci calculation + expected_p1 = np.array([1./1+1./2, 1./1+1./1, 1./1+1./2]) + expected_p2 = np.array([1./2+1./4, 1./2+1./2, 1./2+1./4]) + assert_array_equal(expected_p1, cci_p1) + assert_array_equal(expected_p2, cci_p2) + + # test power variable calculation (dropoff with distance) + cci_p1_pow2 = cluster_confidence(test_streamlines_p1, power=2, + override=True) + expected_p1_pow2 = np.array([np.power(1./1, 2)+np.power(1./2, 2), + np.power(1./1, 2)+np.power(1./1, 2), + np.power(1./1, 2)+np.power(1./2, 2)]) + + assert_array_equal(cci_p1_pow2, expected_p1_pow2) + + # test max distance (ignore distant sls) + cci_dist = cluster_confidence(test_streamlines_p3, + max_mdf=5, override=True) + expected_cci_dist = np.concatenate([cci_p1, np.zeros(1)]) + assert_array_equal(cci_dist, expected_cci_dist) + + if __name__ == '__main__': npt.run_module_suite() diff --git a/dipy/tracking/tests/test_utils.py b/dipy/tracking/tests/test_utils.py index 694d58638f..1ec21fbff8 100644 --- a/dipy/tracking/tests/test_utils.py +++ b/dipy/tracking/tests/test_utils.py @@ -63,12 +63,12 @@ def test_density_map(): # Test passing affine affine = np.diag([2, 2, 2, 1.]) - affine[:3, 3] = 1. + affine[: 3, 3] = 1. dm = density_map(streamlines, shape, affine=affine) assert_array_equal(dm, expected) # Shift the image by 2 voxels, ie 4mm - affine[:3, 3] -= 4. + affine[: 3, 3] -= 4. 
expected_old = expected new_shape = [i + 2 for i in shape] expected = np.zeros(new_shape) @@ -269,8 +269,10 @@ def _target(target_f, streamlines, voxel_both_true, voxel_one_true, assert_raises(ValueError, list, new) # Test smaller voxels - affine = np.random.random((4, 4)) - .5 - affine[3] = [0, 0, 0, 1] + affine = np.array([[.3, 0, 0, 0], + [0, .2, 0, 0], + [0, 0, .4, 0], + [0, 0, 0, 1]]) streamlines = list(move_streamlines(streamlines, affine)) new = list(target_f(streamlines, mask, affine=affine)) assert_equal(len(new), 1) @@ -537,6 +539,24 @@ def test_random_seeds_from_mask(): assert_equal(100, len(seeds)) assert_true(np.all((seeds > 1.5) & (seeds < 2.5))) + mask = np.zeros((15, 15, 15)) + mask[2:14, 2:14, 2:14] = 1 + seeds_npv_2 = random_seeds_from_mask(mask, seeds_count=2, + seed_count_per_voxel=True, + random_seed=0)[:150] + seeds_npv_3 = random_seeds_from_mask(mask, seeds_count=3, + seed_count_per_voxel=True, + random_seed=0)[:150] + assert_true(np.all(seeds_npv_2 == seeds_npv_3)) + + seeds_nt_150 = random_seeds_from_mask(mask, seeds_count=150, + seed_count_per_voxel=False, + random_seed=0)[:150] + seeds_nt_500 = random_seeds_from_mask(mask, seeds_count=500, + seed_count_per_voxel=False, + random_seed=0)[:150] + assert_true(np.all(seeds_nt_150 == seeds_nt_500)) + def test_connectivity_matrix_shape(): # Labels: z-planes have labels 0,1,2 @@ -626,8 +646,6 @@ def test_get_flexi_tvis_affine(): assert_array_almost_equal(origin[:3], np.multiply(tvis_hdr['dim'], vsz) - vsz / 2) - - # grid_affine = tvis_hdr['voxel_order'] = 'ASL' vsz = tvis_hdr['voxel_size'] = np.array([3, 4, 2.]) affine = get_flexi_tvis_affine(tvis_hdr, grid_affine) diff --git a/dipy/tracking/utils.py b/dipy/tracking/utils.py index f32b32cf68..5841f44ecc 100644 --- a/dipy/tracking/utils.py +++ b/dipy/tracking/utils.py @@ -134,7 +134,7 @@ def connectivity_matrix(streamlines, label_volume, voxel_size=None, This argument is deprecated. affine : array_like (4, 4) The mapping from voxel coordinates to streamline coordinates. - symmetric : bool, False by default + symmetric : bool, True by default Symmetric means we don't distinguish between start and end points. If symmetric is True, ``matrix[i, j] == matrix[j, i]``. return_mapping : bool, False by default @@ -413,7 +413,7 @@ def seeds_from_mask(mask, density=[1, 1, 1], voxel_size=None, affine=None): def random_seeds_from_mask(mask, seeds_count=1, seed_count_per_voxel=True, - affine=None): + affine=None, random_seed=None): """Creates randomly placed seeds for fiber tracking from a binary mask. Seed points are randomly distributed in the voxels of ``mask`` @@ -421,8 +421,7 @@ def random_seeds_from_mask(mask, seeds_count=1, seed_count_per_voxel=True, If ``seed_count_per_voxel`` is ``True``, this function is similar to ``seeds_from_mask()``, with the difference that instead of evenly distributing the seeds, it randomly places the seeds within the - voxels specified by the ``mask``. The initial random conditions can be set - using ``numpy.random.seed(...)``, prior to calling this function. + voxels specified by the ``mask``. Parameters ---------- @@ -439,6 +438,8 @@ def random_seeds_from_mask(mask, seeds_count=1, seed_count_per_voxel=True, The mapping between voxel indices and the point space for seeds. A seed point at the center of the voxel ``[i, j, k]`` will be represented as ``[x, y, z]`` where ``[x, y, z, 1] == np.dot(affine, [i, j, k , 1])``. + random_seed : int + The seed for the random number generator (numpy.random.seed).
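Because every (voxel, pass) pair in the seeding loop below gets its own hashed seed, asking for more seeds only appends new ones and leaves the earlier ones unchanged, which is exactly what the new tests above assert. A short sketch of that prefix property, assuming this patched `random_seeds_from_mask`::

    import numpy as np

    from dipy.tracking.utils import random_seeds_from_mask

    mask = np.zeros((5, 5, 5), dtype=bool)
    mask[2, 2, 2] = True

    s2 = random_seeds_from_mask(mask, seeds_count=2,
                                seed_count_per_voxel=True, random_seed=1)
    s3 = random_seeds_from_mask(mask, seeds_count=3,
                                seed_count_per_voxel=True, random_seed=1)
    # The 2-seed result is a prefix of the 3-seed result.
    assert np.allclose(s2, s3[:len(s2)])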
See Also -------- @@ -453,28 +454,37 @@ def random_seeds_from_mask(mask, seeds_count=1, seed_count_per_voxel=True, -------- >>> mask = np.zeros((3,3,3), 'bool') >>> mask[0,0,0] = 1 - >>> np.random.seed(1) - >>> random_seeds_from_mask(mask, seeds_count=1, seed_count_per_voxel=True) - array([[-0.082978 , 0.22032449, -0.49988563]]) - >>> random_seeds_from_mask(mask, seeds_count=6, seed_count_per_voxel=True) - array([[-0.19766743, -0.35324411, -0.40766141], - [-0.31373979, -0.15443927, -0.10323253], - [ 0.03881673, -0.08080549, 0.1852195 ], - [-0.29554775, 0.37811744, -0.47261241], - [ 0.17046751, -0.0826952 , 0.05868983], - [-0.35961306, -0.30189851, 0.30074457]]) + >>> random_seeds_from_mask(mask, seeds_count=1, seed_count_per_voxel=True, + ... random_seed=1) + array([[-0.0640051 , -0.47407377, 0.04966248]]) + >>> random_seeds_from_mask(mask, seeds_count=6, seed_count_per_voxel=True, + ... random_seed=1) + array([[-0.0640051 , -0.47407377, 0.04966248], + [ 0.0507979 , 0.20814782, -0.20909526], + [ 0.46702984, 0.04723225, 0.47268436], + [-0.27800683, 0.37073231, -0.29328084], + [ 0.39286015, -0.16802019, 0.32122912], + [-0.42369171, 0.27991879, -0.06159077]]) >>> mask[0,1,2] = 1 - >>> random_seeds_from_mask(mask, seeds_count=2, seed_count_per_voxel=True) - array([[ 0.46826158, -0.18657582, 0.19232262], - [ 0.37638915, 0.39460666, -0.41495579], - [-0.46094522, 0.66983042, 2.3781425 ], - [-0.40165317, 0.92110763, 2.45788953]]) + >>> random_seeds_from_mask(mask, seeds_count=2, seed_count_per_voxel=True, + ... random_seed=1) + array([[-0.0640051 , -0.47407377, 0.04966248], + [-0.27800683, 1.37073231, 1.70671916], + [ 0.0507979 , 0.20814782, -0.20909526], + [-0.48962585, 1.00187459, 1.99577329]]) """ mask = np.array(mask, dtype=bool, copy=False, ndmin=3) if mask.ndim != 3: raise ValueError('mask cannot be more than 3d') - where = np.argwhere(mask) + # Randomize the voxels + np.random.seed(random_seed) + shape = mask.shape + mask = mask.flatten() + indices = np.arange(len(mask)) + np.random.shuffle(indices) + + where = [np.unravel_index(i, shape) for i in indices if mask[i] == 1] num_voxels = len(where) if not seed_count_per_voxel: @@ -483,16 +493,23 @@ def random_seeds_from_mask(mask, seeds_count=1, seed_count_per_voxel=True, else: seeds_per_voxel = seeds_count - # Generate as many random triplets as the number of seeds needed - grid = np.random.random([seeds_per_voxel * num_voxels, 3]) - # Repeat elements of 'where' so that it can be added to grid - where = np.repeat(where, seeds_per_voxel, axis=0) - seeds = where + grid - .5 + seeds = [] + for i in range(1, seeds_per_voxel + 1): + for s in where: + # Set the random seed with the current seed, the current value of + # seeds per voxel and the global random seed. 
+ if random_seed is not None: + s_random_seed = hash((np.sum(s) + 1) * i + random_seed) \ + % (2**32 - 1) + np.random.seed(s_random_seed) + # Generate random triplet + grid = np.random.random(3) + seed = s + grid - .5 + seeds.append(seed) seeds = asarray(seeds) if not seed_count_per_voxel: - # Randomize the seeds and select the requested amount - np.random.shuffle(seeds) + # Select the requested amount seeds = seeds[:seeds_count] # Apply the spatial transform @@ -1006,7 +1023,7 @@ def flexi_tvis_affine(sl_vox_order, grid_affine, dim, voxel_size): sl_ornt = orientation_from_string(str(sl_vox_order)) grid_ornt = nib.io_orientation(grid_affine) reorder_grid = reorder_voxels_affine( - grid_ornt, sl_ornt, np.array(dim)-1, np.array([1,1,1])) + grid_ornt, sl_ornt, np.array(dim)-1, np.array([1, 1, 1])) tvis_aff = affine_for_trackvis(voxel_size) @@ -1023,9 +1040,11 @@ def get_flexi_tvis_affine(tvis_hdr, nii_aff): ---------- tvis_hdr : header from a trackvis file nii_aff : array (4, 4), - An affine matrix describing the current space of the grid in relation to RAS+ scanner space + An affine matrix describing the current space of the grid in relation + to RAS+ scanner space nii_data : nd array - 3D array, each with shape (x, y, z) corresponding to the shape of the brain volume. + 3D array, each with shape (x, y, z) corresponding to the shape of the + brain volume. Returns ------- @@ -1056,6 +1075,7 @@ def _min_at(a, index, value): a[tuple(index)] = np.minimum(a[tuple(index)], value) + try: minimum_at = np.minimum.at except AttributeError: diff --git a/dipy/viz/__init__.py b/dipy/viz/__init__.py index de522fe725..dbf0891698 100644 --- a/dipy/viz/__init__.py +++ b/dipy/viz/__init__.py @@ -1,16 +1,23 @@ # Init file for visualization package from __future__ import division, print_function, absolute_import -# We make the visualization requirements optional imports: -try: - import matplotlib - has_mpl = True -except ImportError: - e_s = "You do not have Matplotlib installed. Some visualization functions" - e_s += " might not work for you." - print(e_s) - has_mpl = False +from dipy.utils.optpkg import optional_package +# Allow import, but disable doctests if we don't have fury +fury, have_fury, _ = optional_package('fury') + + +if have_fury: + from fury import actor, window, widget, colormap, interactor, ui, utils + from fury.window import vtk + from fury.data import (fetch_viz_icons, read_viz_icons, + DATA_DIR as FURY_DATA_DIR) + +# We make the visualization requirements optional imports: +_, has_mpl, _ = optional_package('matplotlib', + "You do not have Matplotlib installed. Some" + " visualization functions might not work for" + " you") if has_mpl: from . 
import projections diff --git a/dipy/viz/actor.py b/dipy/viz/actor.py deleted file mode 100644 index a7e8cb7abb..0000000000 --- a/dipy/viz/actor.py +++ /dev/null @@ -1,1387 +0,0 @@ -from __future__ import division, print_function, absolute_import - -import numpy as np -from nibabel.affines import apply_affine - -from dipy.viz.colormap import colormap_lookup_table, create_colormap -from dipy.viz.utils import lines_to_vtk_polydata -from dipy.viz.utils import set_input - -# Conditional import machinery for vtk -from dipy.utils.optpkg import optional_package - -# Allow import, but disable doctests if we don't have vtk -vtk, have_vtk, setup_module = optional_package('vtk') -colors, have_vtk_colors, _ = optional_package('vtk.util.colors') -numpy_support, have_ns, _ = optional_package('vtk.util.numpy_support') - -if have_vtk: - - version = vtk.vtkVersion.GetVTKSourceVersion().split(' ')[-1] - major_version = vtk.vtkVersion.GetVTKMajorVersion() - - -def slicer(data, affine=None, value_range=None, opacity=1., - lookup_colormap=None, interpolation='linear', picking_tol=0.025): - """ Cuts 3D scalar or rgb volumes into 2D images - - Parameters - ---------- - data : array, shape (X, Y, Z) or (X, Y, Z, 3) - A grayscale or rgb 4D volume as a numpy array. - affine : array, shape (4, 4) - Grid to space (usually RAS 1mm) transformation matrix. Default is None. - If None then the identity matrix is used. - value_range : None or tuple (2,) - If None then the values will be interpolated from (data.min(), - data.max()) to (0, 255). Otherwise from (value_range[0], - value_range[1]) to (0, 255). - opacity : float, optional - Opacity of 0 means completely transparent and 1 completely visible. - lookup_colormap : vtkLookupTable - If None (default) then a grayscale map is created. - interpolation : string - If 'linear' (default) then linear interpolation is used on the final - texture mapping. If 'nearest' then nearest neighbor interpolation is - used on the final texture mapping. - picking_tol : float - The tolerance for the vtkCellPicker, specified as a fraction of - rendering window size. - - Returns - ------- - image_actor : ImageActor - An object that is capable of displaying different parts of the volume - as slices. The key method of this object is ``display_extent`` where - one can input grid coordinates and display the slice in space (or grid) - coordinates as calculated by the affine parameter. - - """ - if data.ndim != 3: - if data.ndim == 4: - if data.shape[3] != 3: - raise ValueError('Only RGB 3D arrays are currently supported.') - else: - nb_components = 3 - else: - raise ValueError('Only 3D arrays are currently supported.') - else: - nb_components = 1 - - if value_range is None: - vol = np.interp(data, xp=[data.min(), data.max()], fp=[0, 255]) - else: - vol = np.interp(data, xp=[value_range[0], value_range[1]], fp=[0, 255]) - vol = vol.astype('uint8') - - im = vtk.vtkImageData() - if major_version <= 5: - im.SetScalarTypeToUnsignedChar() - I, J, K = vol.shape[:3] - im.SetDimensions(I, J, K) - voxsz = (1., 1., 1.) 
- # im.SetOrigin(0,0,0) - im.SetSpacing(voxsz[2], voxsz[0], voxsz[1]) - if major_version <= 5: - im.AllocateScalars() - im.SetNumberOfScalarComponents(nb_components) - else: - im.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, nb_components) - - # copy data - # what I do below is the same as what is commented here but much faster - # for index in ndindex(vol.shape): - # i, j, k = index - # im.SetScalarComponentFromFloat(i, j, k, 0, vol[i, j, k]) - vol = np.swapaxes(vol, 0, 2) - vol = np.ascontiguousarray(vol) - - if nb_components == 1: - vol = vol.ravel() - else: - vol = np.reshape(vol, [np.prod(vol.shape[:3]), vol.shape[3]]) - - uchar_array = numpy_support.numpy_to_vtk(vol, deep=0) - im.GetPointData().SetScalars(uchar_array) - - if affine is None: - affine = np.eye(4) - - # Set the transform (identity if none given) - transform = vtk.vtkTransform() - transform_matrix = vtk.vtkMatrix4x4() - transform_matrix.DeepCopy(( - affine[0][0], affine[0][1], affine[0][2], affine[0][3], - affine[1][0], affine[1][1], affine[1][2], affine[1][3], - affine[2][0], affine[2][1], affine[2][2], affine[2][3], - affine[3][0], affine[3][1], affine[3][2], affine[3][3])) - transform.SetMatrix(transform_matrix) - transform.Inverse() - - # Set the reslicing - image_resliced = vtk.vtkImageReslice() - set_input(image_resliced, im) - image_resliced.SetResliceTransform(transform) - image_resliced.AutoCropOutputOn() - - # Adding this will allow to support anisotropic voxels - # and also gives the opportunity to slice per voxel coordinates - RZS = affine[:3, :3] - zooms = np.sqrt(np.sum(RZS * RZS, axis=0)) - image_resliced.SetOutputSpacing(*zooms) - - image_resliced.SetInterpolationModeToLinear() - image_resliced.Update() - - ex1, ex2, ey1, ey2, ez1, ez2 = image_resliced.GetOutput().GetExtent() - - class ImageActor(vtk.vtkImageActor): - def __init__(self): - self.picker = vtk.vtkCellPicker() - - def input_connection(self, output): - if vtk.VTK_MAJOR_VERSION <= 5: - self.SetInput(output.GetOutput()) - else: - self.GetMapper().SetInputConnection(output.GetOutputPort()) - self.output = output - self.shape = (ex2 + 1, ey2 + 1, ez2 + 1) - - def display_extent(self, x1, x2, y1, y2, z1, z2): - self.SetDisplayExtent(x1, x2, y1, y2, z1, z2) - if vtk.VTK_MAJOR_VERSION > 5: - self.Update() - - def display(self, x=None, y=None, z=None): - if x is None and y is None and z is None: - self.display_extent(ex1, ex2, ey1, ey2, ez2//2, ez2//2) - if x is not None: - self.display_extent(x, x, ey1, ey2, ez1, ez2) - if y is not None: - self.display_extent(ex1, ex2, y, y, ez1, ez2) - if z is not None: - self.display_extent(ex1, ex2, ey1, ey2, z, z) - - def opacity(self, value): - if vtk.VTK_MAJOR_VERSION <= 5: - self.SetOpacity(value) - else: - self.GetProperty().SetOpacity(value) - - def tolerance(self, value): - self.picker.SetTolerance(value) - - def copy(self): - im_actor = ImageActor() - im_actor.input_connection(self.output) - im_actor.SetDisplayExtent(*self.GetDisplayExtent()) - im_actor.opacity(self.GetOpacity()) - im_actor.tolerance(self.picker.GetTolerance()) - if interpolation == 'nearest': - im_actor.SetInterpolate(False) - else: - im_actor.SetInterpolate(True) - if major_version >= 6: - im_actor.GetMapper().BorderOn() - return im_actor - - image_actor = ImageActor() - if nb_components == 1: - lut = lookup_colormap - if lookup_colormap is None: - # Create a black/white lookup table. 
- lut = colormap_lookup_table((0, 255), (0, 0), (0, 0), (0, 1)) - - plane_colors = vtk.vtkImageMapToColors() - plane_colors.SetLookupTable(lut) - plane_colors.SetInputConnection(image_resliced.GetOutputPort()) - plane_colors.Update() - image_actor.input_connection(plane_colors) - else: - image_actor.input_connection(image_resliced) - image_actor.display() - image_actor.opacity(opacity) - image_actor.tolerance(picking_tol) - - if interpolation == 'nearest': - image_actor.SetInterpolate(False) - else: - image_actor.SetInterpolate(True) - - if major_version >= 6: - image_actor.GetMapper().BorderOn() - - return image_actor - - -def contour_from_roi(data, affine=None, - color=np.array([1, 0, 0]), opacity=1): - """Generates surface actor from a binary ROI. - - The color and opacity of the surface can be customized. - - Parameters - ---------- - data : array, shape (X, Y, Z) - An ROI file that will be binarized and displayed. - affine : array, shape (4, 4) - Grid to space (usually RAS 1mm) transformation matrix. Default is None. - If None then the identity matrix is used. - color : (1, 3) ndarray - RGB values in [0,1]. - opacity : float - Opacity of surface between 0 and 1. - - Returns - ------- - contour_assembly : vtkAssembly - ROI surface object displayed in space - coordinates as calculated by the affine parameter. - - """ - - if data.ndim != 3: - raise ValueError('Only 3D arrays are currently supported.') - else: - nb_components = 1 - - data = (data > 0) * 1 - vol = np.interp(data, xp=[data.min(), data.max()], fp=[0, 255]) - vol = vol.astype('uint8') - - im = vtk.vtkImageData() - if major_version <= 5: - im.SetScalarTypeToUnsignedChar() - di, dj, dk = vol.shape[:3] - im.SetDimensions(di, dj, dk) - voxsz = (1., 1., 1.) - # im.SetOrigin(0,0,0) - im.SetSpacing(voxsz[2], voxsz[0], voxsz[1]) - if major_version <= 5: - im.AllocateScalars() - im.SetNumberOfScalarComponents(nb_components) - else: - im.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, nb_components) - - # copy data - vol = np.swapaxes(vol, 0, 2) - vol = np.ascontiguousarray(vol) - - if nb_components == 1: - vol = vol.ravel() - else: - vol = np.reshape(vol, [np.prod(vol.shape[:3]), vol.shape[3]]) - - uchar_array = numpy_support.numpy_to_vtk(vol, deep=0) - im.GetPointData().SetScalars(uchar_array) - - if affine is None: - affine = np.eye(4) - - # Set the transform (identity if none given) - transform = vtk.vtkTransform() - transform_matrix = vtk.vtkMatrix4x4() - transform_matrix.DeepCopy(( - affine[0][0], affine[0][1], affine[0][2], affine[0][3], - affine[1][0], affine[1][1], affine[1][2], affine[1][3], - affine[2][0], affine[2][1], affine[2][2], affine[2][3], - affine[3][0], affine[3][1], affine[3][2], affine[3][3])) - transform.SetMatrix(transform_matrix) - transform.Inverse() - - # Set the reslicing - image_resliced = vtk.vtkImageReslice() - set_input(image_resliced, im) - image_resliced.SetResliceTransform(transform) - image_resliced.AutoCropOutputOn() - - # Adding this will allow to support anisotropic voxels - # and also gives the opportunity to slice per voxel coordinates - - rzs = affine[:3, :3] - zooms = np.sqrt(np.sum(rzs * rzs, axis=0)) - image_resliced.SetOutputSpacing(*zooms) - - image_resliced.SetInterpolationModeToLinear() - image_resliced.Update() - - skin_extractor = vtk.vtkContourFilter() - if major_version <= 5: - skin_extractor.SetInput(image_resliced.GetOutput()) - else: - skin_extractor.SetInputData(image_resliced.GetOutput()) - - skin_extractor.SetValue(0, 1) - - skin_normals = vtk.vtkPolyDataNormals() - 
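# A minimal sketch of ``contour_from_roi`` above, assuming dipy with VTK
# installed; the binary mask and color are illustrative.
import numpy as np
from dipy.viz import actor, window

mask = np.zeros((20, 20, 20))
mask[5:15, 5:15, 5:15] = 1  # a cubic ROI
roi_actor = actor.contour_from_roi(mask, affine=np.eye(4),
                                   color=np.array([0., 1., 0.]), opacity=0.5)
ren = window.Renderer()
ren.add(roi_actor)
# window.show(ren)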
skin_normals.SetInputConnection(skin_extractor.GetOutputPort()) - skin_normals.SetFeatureAngle(60.0) - - skin_mapper = vtk.vtkPolyDataMapper() - skin_mapper.SetInputConnection(skin_normals.GetOutputPort()) - skin_mapper.ScalarVisibilityOff() - - skin_actor = vtk.vtkActor() - - skin_actor.SetMapper(skin_mapper) - skin_actor.GetProperty().SetOpacity(opacity) - - skin_actor.GetProperty().SetColor(color[0], color[1], color[2]) - - return skin_actor - - -def streamtube(lines, colors=None, opacity=1, linewidth=0.1, tube_sides=9, - lod=True, lod_points=10 ** 4, lod_points_size=3, - spline_subdiv=None, lookup_colormap=None): - """ Uses streamtubes to visualize polylines - - Parameters - ---------- - lines : list - list of N curves represented as 2D ndarrays - - colors : array (N, 3), list of arrays, tuple (3,), array (K,), None - If None then a standard orientation colormap is used for every line. - If one tuple of color is used. Then all streamlines will have the same - colour. - If an array (N, 3) is given, where N is equal to the number of lines. - Then every line is coloured with a different RGB color. - If a list of RGB arrays is given then every point of every line takes - a different color. - If an array (K, ) is given, where K is the number of points of all - lines then these are considered as the values to be used by the - colormap. - If an array (L, ) is given, where L is the number of streamlines then - these are considered as the values to be used by the colormap per - streamline. - If an array (X, Y, Z) or (X, Y, Z, 3) is given then the values for the - colormap are interpolated automatically using trilinear interpolation. - - opacity : float - Takes values from 0 (fully transparent) to 1 (opaque). Default is 1. - linewidth : float - Default is 0.01. - tube_sides : int - Default is 9. - lod : bool - Use vtkLODActor(level of detail) rather than vtkActor. Default is True. - Level of detail actors do not render the full geometry when the - frame rate is low. - lod_points : int - Number of points to be used when LOD is in effect. Default is 10000. - lod_points_size : int - Size of points when lod is in effect. Default is 3. - spline_subdiv : int - Number of splines subdivision to smooth streamtubes. Default is None. - lookup_colormap : vtkLookupTable - Add a default lookup table to the colormap. Default is None which calls - :func:`dipy.viz.actor.colormap_lookup_table`. - - Examples - -------- - >>> import numpy as np - >>> from dipy.viz import actor, window - >>> ren = window.Renderer() - >>> lines = [np.random.rand(10, 3), np.random.rand(20, 3)] - >>> colors = np.random.rand(2, 3) - >>> c = actor.streamtube(lines, colors) - >>> ren.add(c) - >>> #window.show(ren) - - Notes - ----- - Streamtubes can be heavy on GPU when loading many streamlines and - therefore, you may experience slow rendering time depending on system GPU. - A solution to this problem is to reduce the number of points in each - streamline. In Dipy we provide an algorithm that will reduce the number of - points on the straighter parts of the streamline but keep more points on - the curvier parts. This can be used in the following way:: - - from dipy.tracking.distances import approx_polygon_track - lines = [approx_polygon_track(line, 0.2) for line in lines] - - Alternatively we suggest using the ``line`` actor which is much more - efficient. 
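# A sketch of the (K,) ``colors`` form described above: one scalar per point,
# mapped through a lookup table. Assumes dipy with VTK; the data is random.
import numpy as np
from dipy.viz import actor, window
from dipy.viz.colormap import colormap_lookup_table

lines = [np.random.rand(10, 3), np.random.rand(20, 3)]
values = np.random.rand(30)  # one value per point over all lines
lut = colormap_lookup_table(scale_range=(0, 1))
st = actor.streamtube(lines, colors=values, linewidth=0.2,
                      lookup_colormap=lut)
ren = window.Renderer()
ren.add(st)
# window.show(ren)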
- - See Also - -------- - :func:`dipy.viz.actor.line` - """ - # Poly data with lines and colors - poly_data, is_colormap = lines_to_vtk_polydata(lines, colors) - next_input = poly_data - - # Set Normals - poly_normals = set_input(vtk.vtkPolyDataNormals(), next_input) - poly_normals.ComputeCellNormalsOn() - poly_normals.ComputePointNormalsOn() - poly_normals.ConsistencyOn() - poly_normals.AutoOrientNormalsOn() - poly_normals.Update() - next_input = poly_normals.GetOutputPort() - - # Spline interpolation - if (spline_subdiv is not None) and (spline_subdiv > 0): - spline_filter = set_input(vtk.vtkSplineFilter(), next_input) - spline_filter.SetSubdivideToSpecified() - spline_filter.SetNumberOfSubdivisions(spline_subdiv) - spline_filter.Update() - next_input = spline_filter.GetOutputPort() - - # Add thickness to the resulting lines - tube_filter = set_input(vtk.vtkTubeFilter(), next_input) - tube_filter.SetNumberOfSides(tube_sides) - tube_filter.SetRadius(linewidth) - # TODO using the line above we will be able to visualize - # streamtubes of varying radius - # tube_filter.SetVaryRadiusToVaryRadiusByScalar() - tube_filter.CappingOn() - tube_filter.Update() - next_input = tube_filter.GetOutputPort() - - # Poly mapper - poly_mapper = set_input(vtk.vtkPolyDataMapper(), next_input) - poly_mapper.ScalarVisibilityOn() - poly_mapper.SetScalarModeToUsePointFieldData() - poly_mapper.SelectColorArray("Colors") - - # Enable only for OpenGL1 rendering backend - if vtk.VTK_MAJOR_VERSION <= 6: - poly_mapper.GlobalImmediateModeRenderingOn() - - poly_mapper.Update() - - # Color Scale with a lookup table - if is_colormap: - if lookup_colormap is None: - lookup_colormap = colormap_lookup_table() - poly_mapper.SetLookupTable(lookup_colormap) - poly_mapper.UseLookupTableScalarRangeOn() - poly_mapper.Update() - - # Set Actor - if lod: - actor = vtk.vtkLODActor() - actor.SetNumberOfCloudPoints(lod_points) - actor.GetProperty().SetPointSize(lod_points_size) - else: - actor = vtk.vtkActor() - - actor.SetMapper(poly_mapper) - - # Use different defaults for OpenGL1 rendering backend - if vtk.VTK_MAJOR_VERSION <= 6: - actor.GetProperty().SetAmbient(0.1) - actor.GetProperty().SetDiffuse(0.15) - actor.GetProperty().SetSpecular(0.05) - actor.GetProperty().SetSpecularPower(6) - - actor.GetProperty().SetInterpolationToPhong() - actor.GetProperty().BackfaceCullingOn() - actor.GetProperty().SetOpacity(opacity) - - return actor - - -def line(lines, colors=None, opacity=1, linewidth=1, - spline_subdiv=None, lod=True, lod_points=10 ** 4, lod_points_size=3, - lookup_colormap=None): - """ Create an actor for one or more lines. - - Parameters - ------------ - lines : list of arrays - - colors : array (N, 3), list of arrays, tuple (3,), array (K,), None - If None then a standard orientation colormap is used for every line. - If one tuple of color is used. Then all streamlines will have the same - colour. - If an array (N, 3) is given, where N is equal to the number of lines. - Then every line is coloured with a different RGB color. - If a list of RGB arrays is given then every point of every line takes - a different color. - If an array (K, ) is given, where K is the number of points of all - lines then these are considered as the values to be used by the - colormap. - If an array (L, ) is given, where L is the number of streamlines then - these are considered as the values to be used by the colormap per - streamline. 
- If an array (X, Y, Z) or (X, Y, Z, 3) is given then the values for the - colormap are interpolated automatically using trilinear interpolation. - - opacity : float, optional - Takes values from 0 (fully transparent) to 1 (opaque). Default is 1. - - linewidth : float, optional - Line thickness. Default is 1. - spline_subdiv : int, optional - Number of splines subdivision to smooth streamtubes. Default is None - which means no subdivision. - lod : bool - Use vtkLODActor(level of detail) rather than vtkActor. Default is True. - Level of detail actors do not render the full geometry when the - frame rate is low. - lod_points : int - Number of points to be used when LOD is in effect. Default is 10000. - lod_points_size : int - Size of points when lod is in effect. Default is 3. - lookup_colormap : bool, optional - Add a default lookup table to the colormap. Default is None which calls - :func:`dipy.viz.actor.colormap_lookup_table`. - - Returns - ---------- - v : vtkActor or vtkLODActor object - Line. - - Examples - ---------- - >>> from dipy.viz import actor, window - >>> ren = window.Renderer() - >>> lines = [np.random.rand(10, 3), np.random.rand(20, 3)] - >>> colors = np.random.rand(2, 3) - >>> c = actor.line(lines, colors) - >>> ren.add(c) - >>> #window.show(ren) - """ - # Poly data with lines and colors - poly_data, is_colormap = lines_to_vtk_polydata(lines, colors) - next_input = poly_data - - # use spline interpolation - if (spline_subdiv is not None) and (spline_subdiv > 0): - spline_filter = set_input(vtk.vtkSplineFilter(), next_input) - spline_filter.SetSubdivideToSpecified() - spline_filter.SetNumberOfSubdivisions(spline_subdiv) - spline_filter.Update() - next_input = spline_filter.GetOutputPort() - - poly_mapper = set_input(vtk.vtkPolyDataMapper(), next_input) - poly_mapper.ScalarVisibilityOn() - poly_mapper.SetScalarModeToUsePointFieldData() - poly_mapper.SelectColorArray("Colors") - poly_mapper.Update() - - # Color Scale with a lookup table - if is_colormap: - - if lookup_colormap is None: - lookup_colormap = colormap_lookup_table() - - poly_mapper.SetLookupTable(lookup_colormap) - poly_mapper.UseLookupTableScalarRangeOn() - poly_mapper.Update() - - # Set Actor - if lod: - actor = vtk.vtkLODActor() - actor.SetNumberOfCloudPoints(lod_points) - actor.GetProperty().SetPointSize(lod_points_size) - else: - actor = vtk.vtkActor() - - # actor = vtk.vtkActor() - actor.SetMapper(poly_mapper) - actor.GetProperty().SetLineWidth(linewidth) - actor.GetProperty().SetOpacity(opacity) - - return actor - - -def scalar_bar(lookup_table=None, title=" "): - """ Default scalar bar actor for a given colormap (colorbar) - - Parameters - ---------- - lookup_table : vtkLookupTable or None - If None then ``colormap_lookup_table`` is called with default options. - title : str - - Returns - ------- - scalar_bar : vtkScalarBarActor - - See Also - -------- - :func:`dipy.viz.actor.colormap_lookup_table` - - """ - lookup_table_copy = vtk.vtkLookupTable() - if lookup_table is None: - lookup_table = colormap_lookup_table() - # Deepcopy the lookup_table because sometimes vtkPolyDataMapper deletes it - lookup_table_copy.DeepCopy(lookup_table) - scalar_bar = vtk.vtkScalarBarActor() - scalar_bar.SetTitle(title) - scalar_bar.SetLookupTable(lookup_table_copy) - scalar_bar.SetNumberOfLabels(6) - - return scalar_bar - - -def _arrow(pos=(0, 0, 0), color=(1, 0, 0), scale=(1, 1, 1), opacity=1): - """ Internal function for generating arrow actors. 
- """ - arrow = vtk.vtkArrowSource() - # arrow.SetTipLength(length) - - arrowm = vtk.vtkPolyDataMapper() - - if major_version <= 5: - arrowm.SetInput(arrow.GetOutput()) - else: - arrowm.SetInputConnection(arrow.GetOutputPort()) - - arrowa = vtk.vtkActor() - arrowa.SetMapper(arrowm) - - arrowa.GetProperty().SetColor(color) - arrowa.GetProperty().SetOpacity(opacity) - arrowa.SetScale(scale) - - return arrowa - - -def axes(scale=(1, 1, 1), colorx=(1, 0, 0), colory=(0, 1, 0), colorz=(0, 0, 1), - opacity=1): - """ Create an actor with the coordinate's system axes where - red = x, green = y, blue = z. - - Parameters - ---------- - scale : tuple (3,) - Axes size e.g. (100, 100, 100). Default is (1, 1, 1). - colorx : tuple (3,) - x-axis color. Default red (1, 0, 0). - colory : tuple (3,) - y-axis color. Default green (0, 1, 0). - colorz : tuple (3,) - z-axis color. Default blue (0, 0, 1). - opacity : float, optional - Takes values from 0 (fully transparent) to 1 (opaque). Default is 1. - - Returns - ------- - vtkAssembly - """ - - arrowx = _arrow(color=colorx, scale=scale, opacity=opacity) - arrowy = _arrow(color=colory, scale=scale, opacity=opacity) - arrowz = _arrow(color=colorz, scale=scale, opacity=opacity) - - arrowy.RotateZ(90) - arrowz.RotateY(-90) - - ass = vtk.vtkAssembly() - ass.AddPart(arrowx) - ass.AddPart(arrowy) - ass.AddPart(arrowz) - - return ass - - -def odf_slicer(odfs, affine=None, mask=None, sphere=None, scale=2.2, - norm=True, radial_scale=True, opacity=1., - colormap='plasma', global_cm=False): - """ Slice spherical fields in native or world coordinates - - Parameters - ---------- - odfs : ndarray - 4D array of spherical functions - affine : array - 4x4 transformation array from native coordinates to world coordinates - mask : ndarray - 3D mask - sphere : Sphere - a sphere - scale : float - Distance between spheres. - norm : bool - Normalize `sphere_values`. - radial_scale : bool - Scale sphere points according to odf values. - opacity : float - Takes values from 0 (fully transparent) to 1 (opaque). Default is 1. - colormap : None or str - If None then white color is used. Otherwise the name of colormap is - given. Matplotlib colormaps are supported (e.g., 'inferno'). - global_cm : bool - If True the colormap will be applied in all ODFs. If False - it will be applied individually at each voxel (default False). 
- - Returns - --------- - actor : vtkActor - Spheres - """ - - if mask is None: - mask = np.ones(odfs.shape[:3], dtype=np.bool) - else: - mask = mask.astype(np.bool) - - szx, szy, szz = odfs.shape[:3] - - class OdfSlicerActor(vtk.vtkLODActor): - - def display_extent(self, x1, x2, y1, y2, z1, z2): - tmp_mask = np.zeros(odfs.shape[:3], dtype=np.bool) - tmp_mask[x1:x2 + 1, y1:y2 + 1, z1:z2 + 1] = True - tmp_mask = np.bitwise_and(tmp_mask, mask) - - self.mapper = _odf_slicer_mapper(odfs=odfs, - affine=affine, - mask=tmp_mask, - sphere=sphere, - scale=scale, - norm=norm, - radial_scale=radial_scale, - opacity=opacity, - colormap=colormap, - global_cm=global_cm) - self.SetMapper(self.mapper) - - def display(self, x=None, y=None, z=None): - if x is None and y is None and z is None: - self.display_extent(0, szx - 1, 0, szy - 1, - int(np.floor(szz/2)), int(np.floor(szz/2))) - if x is not None: - self.display_extent(x, x, 0, szy - 1, 0, szz - 1) - if y is not None: - self.display_extent(0, szx - 1, y, y, 0, szz - 1) - if z is not None: - self.display_extent(0, szx - 1, 0, szy - 1, z, z) - - odf_actor = OdfSlicerActor() - odf_actor.display_extent(0, szx - 1, 0, szy - 1, - int(np.floor(szz/2)), int(np.floor(szz/2))) - - return odf_actor - - -def _odf_slicer_mapper(odfs, affine=None, mask=None, sphere=None, scale=2.2, - norm=True, radial_scale=True, opacity=1., - colormap='plasma', global_cm=False): - """ Helper function for slicing spherical fields - - Parameters - ---------- - odfs : ndarray - 4D array of spherical functions - affine : array - 4x4 transformation array from native coordinates to world coordinates - mask : ndarray - 3D mask - sphere : Sphere - a sphere - scale : float - Distance between spheres. - norm : bool - Normalize `sphere_values`. - radial_scale : bool - Scale sphere points according to odf values. - opacity : float - Takes values from 0 (fully transparent) to 1 (opaque) - colormap : None or str - If None then white color is used. Otherwise the name of colormap is - given. Matplotlib colormaps are supported (e.g., 'inferno'). - global_cm : bool - If True the colormap will be applied in all ODFs. If False - it will be applied individually at each voxel (default False). 
- - Returns - --------- - mapper : vtkPolyDataMapper - Spheres mapper - """ - if mask is None: - mask = np.ones(odfs.shape[:3]) - - ijk = np.ascontiguousarray(np.array(np.nonzero(mask)).T) - - if len(ijk) == 0: - return None - - if affine is not None: - ijk = np.ascontiguousarray(apply_affine(affine, ijk)) - - faces = np.asarray(sphere.faces, dtype=int) - vertices = sphere.vertices - - all_xyz = [] - all_faces = [] - all_ms = [] - for (k, center) in enumerate(ijk): - - m = odfs[tuple(center.astype(np.int))].copy() - - if norm: - m /= np.abs(m).max() - - if radial_scale: - xyz = vertices * m[:, None] - else: - xyz = vertices.copy() - - all_xyz.append(scale * xyz + center) - all_faces.append(faces + k * xyz.shape[0]) - all_ms.append(m) - - all_xyz = np.ascontiguousarray(np.concatenate(all_xyz)) - all_xyz_vtk = numpy_support.numpy_to_vtk(all_xyz, deep=True) - - all_faces = np.concatenate(all_faces) - all_faces = np.hstack((3 * np.ones((len(all_faces), 1)), - all_faces)) - ncells = len(all_faces) - - all_faces = np.ascontiguousarray(all_faces.ravel(), dtype='i8') - all_faces_vtk = numpy_support.numpy_to_vtkIdTypeArray(all_faces, - deep=True) - if global_cm: - all_ms = np.ascontiguousarray( - np.concatenate(all_ms), dtype='f4') - - points = vtk.vtkPoints() - points.SetData(all_xyz_vtk) - - cells = vtk.vtkCellArray() - cells.SetCells(ncells, all_faces_vtk) - - if colormap is not None: - if global_cm: - cols = create_colormap(all_ms.ravel(), colormap) - else: - cols = np.zeros((ijk.shape[0],) + sphere.vertices.shape, - dtype='f4') - for k in range(ijk.shape[0]): - tmp = create_colormap(all_ms[k].ravel(), colormap) - cols[k] = tmp.copy() - - cols = np.ascontiguousarray( - np.reshape(cols, (cols.shape[0] * cols.shape[1], - cols.shape[2])), dtype='f4') - - vtk_colors = numpy_support.numpy_to_vtk( - np.asarray(255 * cols), - deep=True, - array_type=vtk.VTK_UNSIGNED_CHAR) - - vtk_colors.SetName("Colors") - - polydata = vtk.vtkPolyData() - polydata.SetPoints(points) - polydata.SetPolys(cells) - - if colormap is not None: - polydata.GetPointData().SetScalars(vtk_colors) - - mapper = vtk.vtkPolyDataMapper() - if major_version <= 5: - mapper.SetInput(polydata) - else: - mapper.SetInputData(polydata) - - return mapper - - -def _makeNd(array, ndim): - """Pads as many 1s at the beginning of array's shape as are need to give - array ndim dimensions.""" - new_shape = (1,) * (ndim - array.ndim) + array.shape - return array.reshape(new_shape) - - -def tensor_slicer(evals, evecs, affine=None, mask=None, sphere=None, scale=2.2, - norm=True, opacity=1., scalar_colors=None): - """ Slice many tensors as ellipsoids in native or world coordinates - - Parameters - ---------- - evals : (3,) or (X, 3) or (X, Y, 3) or (X, Y, Z, 3) ndarray - eigenvalues - evecs : (3, 3) or (X, 3, 3) or (X, Y, 3, 3) or (X, Y, Z, 3, 3) ndarray - eigenvectors - affine : array - 4x4 transformation array from native coordinates to world coordinates* - mask : ndarray - 3D mask - sphere : Sphere - a sphere - scale : float - Distance between spheres. - norm : bool - Normalize `sphere_values`. - opacity : float - Takes values from 0 (fully transparent) to 1 (opaque). Default is 1. 
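# A minimal ``tensor_slicer`` sketch, assuming dipy with VTK; the constant
# eigenvalue/eigenvector field below is synthetic.
import numpy as np
from dipy.data import get_sphere
from dipy.viz import actor, window

evals = np.tile(np.array([1.4, 0.35, 0.35]) * 1e-3, (3, 3, 3, 1))
evecs = np.tile(np.eye(3), (3, 3, 3, 1, 1))
sphere = get_sphere('symmetric724')
tensor_actor = actor.tensor_slicer(evals, evecs, sphere=sphere, scale=0.3)
ren = window.Renderer()
ren.add(tensor_actor)
# window.show(ren)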
- scalar_colors : (3,) or (X, 3) or (X, Y, 3) or (X, Y, Z, 3) ndarray - RGB colors used to show the tensors - Default None, color the ellipsoids using ``color_fa`` - - Returns - --------- - actor : vtkActor - Ellipsoid - """ - - if mask is None: - mask = np.ones(evals.shape[:3], dtype=np.bool) - else: - mask = mask.astype(np.bool) - - szx, szy, szz = evals.shape[:3] - - class TensorSlicerActor(vtk.vtkLODActor): - - def display_extent(self, x1, x2, y1, y2, z1, z2): - tmp_mask = np.zeros(evals.shape[:3], dtype=np.bool) - tmp_mask[x1:x2 + 1, y1:y2 + 1, z1:z2 + 1] = True - tmp_mask = np.bitwise_and(tmp_mask, mask) - - self.mapper = _tensor_slicer_mapper(evals=evals, - evecs=evecs, - affine=affine, - mask=tmp_mask, - sphere=sphere, - scale=scale, - norm=norm, - opacity=opacity, - scalar_colors=scalar_colors) - self.SetMapper(self.mapper) - - def display(self, x=None, y=None, z=None): - if x is None and y is None and z is None: - self.display_extent(0, szx - 1, 0, szy - 1, - int(np.floor(szz/2)), int(np.floor(szz/2))) - if x is not None: - self.display_extent(x, x, 0, szy - 1, 0, szz - 1) - if y is not None: - self.display_extent(0, szx - 1, y, y, 0, szz - 1) - if z is not None: - self.display_extent(0, szx - 1, 0, szy - 1, z, z) - - tensor_actor = TensorSlicerActor() - tensor_actor.display_extent(0, szx - 1, 0, szy - 1, - int(np.floor(szz/2)), int(np.floor(szz/2))) - - return tensor_actor - - -def _tensor_slicer_mapper(evals, evecs, affine=None, mask=None, sphere=None, scale=2.2, - norm=True, opacity=1., scalar_colors=None): - """ Helper function for slicing tensor fields - - Parameters - ---------- - evals : (3,) or (X, 3) or (X, Y, 3) or (X, Y, Z, 3) ndarray - eigenvalues - evecs : (3, 3) or (X, 3, 3) or (X, Y, 3, 3) or (X, Y, Z, 3, 3) ndarray - eigenvectors - affine : array - 4x4 transformation array from native coordinates to world coordinates - mask : ndarray - 3D mask - sphere : Sphere - a sphere - scale : float - Distance between spheres. - norm : bool - Normalize `sphere_values`. - opacity : float - Takes values from 0 (fully transparent) to 1 (opaque) - scalar_colors : (3,) or (X, 3) or (X, Y, 3) or (X, Y, Z, 3) ndarray - RGB colors used to show the tensors - Default None, color the ellipsoids using ``color_fa`` - - Returns - --------- - mapper : vtkPolyDataMapper - Ellipsoid mapper - """ - if mask is None: - mask = np.ones(evals.shape[:3]) - - ijk = np.ascontiguousarray(np.array(np.nonzero(mask)).T) - if len(ijk) == 0: - return None - - if affine is not None: - ijk = np.ascontiguousarray(apply_affine(affine, ijk)) - - faces = np.asarray(sphere.faces, dtype=int) - vertices = sphere.vertices - - if scalar_colors is None: - from dipy.reconst.dti import color_fa, fractional_anisotropy - cfa = color_fa(fractional_anisotropy(evals), evecs) - else: - cfa = _makeNd(scalar_colors, 4) - - cols = np.zeros((ijk.shape[0],) + sphere.vertices.shape, - dtype='f4') - - all_xyz = [] - all_faces = [] - for (k, center) in enumerate(ijk): - ea = evals[tuple(center.astype(np.int))] - if norm: - ea /= ea.max() - ea = np.diag(ea.copy()) - - ev = evecs[tuple(center.astype(np.int))].copy() - xyz = np.dot(ev, np.dot(ea, vertices.T)) - - xyz = xyz.T - all_xyz.append(scale * xyz + center) - all_faces.append(faces + k * xyz.shape[0]) - - cols[k, ...] 
= np.interp(cfa[tuple(center.astype(np.int))], [0, 1], [0, 255]).astype('ubyte') - - all_xyz = np.ascontiguousarray(np.concatenate(all_xyz)) - all_xyz_vtk = numpy_support.numpy_to_vtk(all_xyz, deep=True) - - all_faces = np.concatenate(all_faces) - all_faces = np.hstack((3 * np.ones((len(all_faces), 1)), - all_faces)) - ncells = len(all_faces) - - all_faces = np.ascontiguousarray(all_faces.ravel(), dtype='i8') - all_faces_vtk = numpy_support.numpy_to_vtkIdTypeArray(all_faces, - deep=True) - - points = vtk.vtkPoints() - points.SetData(all_xyz_vtk) - - cells = vtk.vtkCellArray() - cells.SetCells(ncells, all_faces_vtk) - - cols = np.ascontiguousarray( - np.reshape(cols, (cols.shape[0] * cols.shape[1], - cols.shape[2])), dtype='f4') - - vtk_colors = numpy_support.numpy_to_vtk( - cols, - deep=True, - array_type=vtk.VTK_UNSIGNED_CHAR) - - vtk_colors.SetName("Colors") - - polydata = vtk.vtkPolyData() - polydata.SetPoints(points) - polydata.SetPolys(cells) - polydata.GetPointData().SetScalars(vtk_colors) - - mapper = vtk.vtkPolyDataMapper() - if major_version <= 5: - mapper.SetInput(polydata) - else: - mapper.SetInputData(polydata) - - return mapper - - -def peak_slicer(peaks_dirs, peaks_values=None, mask=None, affine=None, - colors=(1, 0, 0), opacity=1., linewidth=1, - lod=False, lod_points=10 ** 4, lod_points_size=3): - """ Visualize peak directions as given from ``peaks_from_model`` - - Parameters - ---------- - peaks_dirs : ndarray - Peak directions. The shape of the array can be (M, 3) or (X, M, 3) or - (X, Y, M, 3) or (X, Y, Z, M, 3) - peaks_values : ndarray - Peak values. The shape of the array can be (M, ) or (X, M) or - (X, Y, M) or (X, Y, Z, M) - affine : array - 4x4 transformation array from native coordinates to world coordinates - mask : ndarray - 3D mask - colors : tuple or None - Default red color. If None then every peak gets an orientation color - in similarity to a DEC map. - - opacity : float, optional - Takes values from 0 (fully transparent) to 1 (opaque) - - linewidth : float, optional - Line thickness. Default is 1. - - lod : bool - Use vtkLODActor(level of detail) rather than vtkActor. - Default is False. Level of detail actors do not render the full - geometry when the frame rate is low. - lod_points : int - Number of points to be used when LOD is in effect. Default is 10000. - lod_points_size : int - Size of points when lod is in effect. Default is 3. 
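# A minimal ``peak_slicer`` sketch, assuming dipy with VTK; random unit
# vectors stand in for peaks from ``peaks_from_model``.
import numpy as np
from dipy.viz import actor, window

peak_dirs = np.random.rand(3, 3, 3, 5, 3) - 0.5
peak_dirs /= np.linalg.norm(peak_dirs, axis=-1, keepdims=True)
peak_vals = np.random.rand(3, 3, 3, 5)
peak_actor = actor.peak_slicer(peak_dirs, peak_vals, colors=None)
peak_actor.display(z=1)  # show a single axial slice of peaks
ren = window.Renderer()
ren.add(peak_actor)
# window.show(ren)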
- - Returns - ------- - vtkActor - - See Also - -------- - dipy.viz.actor.odf_slicer - - """ - peaks_dirs = np.asarray(peaks_dirs) - if peaks_dirs.ndim > 5: - raise ValueError("Wrong shape") - - peaks_dirs = _makeNd(peaks_dirs, 5) - if peaks_values is not None: - peaks_values = _makeNd(peaks_values, 4) - - grid_shape = np.array(peaks_dirs.shape[:3]) - - if mask is None: - mask = np.ones(grid_shape).astype(np.bool) - - class PeakSlicerActor(vtk.vtkLODActor): - - def display_extent(self, x1, x2, y1, y2, z1, z2): - - tmp_mask = np.zeros(grid_shape, dtype=np.bool) - tmp_mask[x1:x2 + 1, y1:y2 + 1, z1:z2 + 1] = True - tmp_mask = np.bitwise_and(tmp_mask, mask) - - ijk = np.ascontiguousarray(np.array(np.nonzero(tmp_mask)).T) - if len(ijk) == 0: - self.SetMapper(None) - return - if affine is not None: - ijk_trans = np.ascontiguousarray(apply_affine(affine, ijk)) - list_dirs = [] - for index, center in enumerate(ijk): - # center = tuple(center) - if affine is None: - xyz = center[:, None] - else: - xyz = ijk_trans[index][:, None] - xyz = xyz.T - for i in range(peaks_dirs[tuple(center)].shape[-2]): - - if peaks_values is not None: - pv = peaks_values[tuple(center)][i] - else: - pv = 1. - symm = np.vstack((-peaks_dirs[tuple(center)][i] * pv + xyz, - peaks_dirs[tuple(center)][i] * pv + xyz)) - list_dirs.append(symm) - - self.mapper = line(list_dirs, colors=colors, - opacity=opacity, linewidth=linewidth, - lod=lod, lod_points=lod_points, - lod_points_size=lod_points_size).GetMapper() - self.SetMapper(self.mapper) - - def display(self, x=None, y=None, z=None): - if x is None and y is None and z is None: - self.display_extent(0, szx - 1, 0, szy - 1, - int(np.floor(szz/2)), int(np.floor(szz/2))) - if x is not None: - self.display_extent(x, x, 0, szy - 1, 0, szz - 1) - if y is not None: - self.display_extent(0, szx - 1, y, y, 0, szz - 1) - if z is not None: - self.display_extent(0, szx - 1, 0, szy - 1, z, z) - - peak_actor = PeakSlicerActor() - - szx, szy, szz = grid_shape - peak_actor.display_extent(0, szx - 1, 0, szy - 1, - int(np.floor(szz / 2)), int(np.floor(szz / 2))) - - return peak_actor - - -def dots(points, color=(1, 0, 0), opacity=1, dot_size=5): - """ Create one or more 3d points - - Parameters - ---------- - points : ndarray, (N, 3) - color : tuple (3,) - opacity : float, optional - Takes values from 0 (fully transparent) to 1 (opaque) - dot_size : int - - Returns - -------- - vtkActor - - See Also - --------- - dipy.viz.actor.point - - """ - - if points.ndim == 2: - points_no = points.shape[0] - else: - points_no = 1 - - polyVertexPoints = vtk.vtkPoints() - polyVertexPoints.SetNumberOfPoints(points_no) - aPolyVertex = vtk.vtkPolyVertex() - aPolyVertex.GetPointIds().SetNumberOfIds(points_no) - - cnt = 0 - if points.ndim > 1: - for point in points: - polyVertexPoints.InsertPoint(cnt, point[0], point[1], point[2]) - aPolyVertex.GetPointIds().SetId(cnt, cnt) - cnt += 1 - else: - polyVertexPoints.InsertPoint(cnt, points[0], points[1], points[2]) - aPolyVertex.GetPointIds().SetId(cnt, cnt) - cnt += 1 - - aPolyVertexGrid = vtk.vtkUnstructuredGrid() - aPolyVertexGrid.Allocate(1, 1) - aPolyVertexGrid.InsertNextCell(aPolyVertex.GetCellType(), - aPolyVertex.GetPointIds()) - - aPolyVertexGrid.SetPoints(polyVertexPoints) - aPolyVertexMapper = vtk.vtkDataSetMapper() - if major_version <= 5: - aPolyVertexMapper.SetInput(aPolyVertexGrid) - else: - aPolyVertexMapper.SetInputData(aPolyVertexGrid) - aPolyVertexActor = vtk.vtkActor() - aPolyVertexActor.SetMapper(aPolyVertexMapper) - - 
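# A minimal ``dots`` sketch, assuming dipy with VTK; the point cloud is
# random.
import numpy as np
from dipy.viz import actor, window

pts = 10 * np.random.rand(100, 3)
dots_actor = actor.dots(pts, color=(0, 0, 1), opacity=0.8, dot_size=3)
ren = window.Renderer()
ren.add(dots_actor)
# window.show(ren)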
aPolyVertexActor.GetProperty().SetColor(color) - aPolyVertexActor.GetProperty().SetOpacity(opacity) - aPolyVertexActor.GetProperty().SetPointSize(dot_size) - return aPolyVertexActor - - -def point(points, colors, opacity=1., point_radius=0.1, theta=8, phi=8): - """ Visualize points as sphere glyphs - - Parameters - ---------- - points : ndarray, shape (N, 3) - colors : ndarray (N,3) or tuple (3,) - point_radius : float - theta : int - phi : int - opacity : float, optional - Takes values from 0 (fully transparent) to 1 (opaque) - - Returns - ------- - vtkActor - - Examples - -------- - >>> from dipy.viz import window, actor - >>> ren = window.Renderer() - >>> pts = np.random.rand(5, 3) - >>> point_actor = actor.point(pts, window.colors.coral) - >>> ren.add(point_actor) - >>> #window.show(ren) - """ - - if np.array(colors).ndim == 1: - # return dots(points,colors,opacity) - colors = np.tile(colors, (len(points), 1)) - - scalars = vtk.vtkUnsignedCharArray() - scalars.SetNumberOfComponents(3) - - pts = vtk.vtkPoints() - cnt_colors = 0 - - for p in points: - - pts.InsertNextPoint(p[0], p[1], p[2]) - scalars.InsertNextTuple3( - round(255 * colors[cnt_colors][0]), - round(255 * colors[cnt_colors][1]), - round(255 * colors[cnt_colors][2])) - cnt_colors += 1 - - src = vtk.vtkSphereSource() - src.SetRadius(point_radius) - src.SetThetaResolution(theta) - src.SetPhiResolution(phi) - - polyData = vtk.vtkPolyData() - polyData.SetPoints(pts) - polyData.GetPointData().SetScalars(scalars) - - glyph = vtk.vtkGlyph3D() - glyph.SetSourceConnection(src.GetOutputPort()) - if major_version <= 5: - glyph.SetInput(polyData) - else: - glyph.SetInputData(polyData) - glyph.SetColorModeToColorByScalar() - glyph.SetScaleModeToDataScalingOff() - glyph.Update() - - mapper = vtk.vtkPolyDataMapper() - if major_version <= 5: - mapper.SetInput(glyph.GetOutput()) - else: - mapper.SetInputData(glyph.GetOutput()) - actor = vtk.vtkActor() - actor.SetMapper(mapper) - actor.GetProperty().SetOpacity(opacity) - - return actor - - -def label(text='Origin', pos=(0, 0, 0), scale=(0.2, 0.2, 0.2), - color=(1, 1, 1)): - """ Create a label actor. - - This actor will always face the camera - - Parameters - ---------- - text : str - Text for the label. - pos : (3,) array_like, optional - Left down position of the label. - scale : (3,) array_like - Changes the size of the label. - color : (3,) array_like - Label color as ``(r,g,b)`` tuple. - - Returns - ------- - l : vtkActor object - Label. 
- - Examples - -------- - >>> from dipy.viz import window, actor - >>> ren = window.Renderer() - >>> l = actor.label(text='Hello') - >>> ren.add(l) - >>> #window.show(ren) - """ - - atext = vtk.vtkVectorText() - atext.SetText(text) - - textm = vtk.vtkPolyDataMapper() - if major_version <= 5: - textm.SetInput(atext.GetOutput()) - else: - textm.SetInputData(atext.GetOutput()) - - texta = vtk.vtkFollower() - texta.SetMapper(textm) - texta.SetScale(scale) - - texta.GetProperty().SetColor(color) - texta.SetPosition(pos) - - return texta diff --git a/dipy/viz/colormap.py b/dipy/viz/colormap.py deleted file mode 100644 index 1313080c29..0000000000 --- a/dipy/viz/colormap.py +++ /dev/null @@ -1,319 +0,0 @@ -import numpy as np - -# Conditional import machinery for vtk -from dipy.utils.optpkg import optional_package - -# Allow import, but disable doctests if we don't have vtk -vtk, have_vtk, setup_module = optional_package('vtk') -cm, have_matplotlib, _ = optional_package('matplotlib.cm') - -if have_matplotlib: - get_cmap = cm.get_cmap -else: - from dipy.data import get_cmap -from warnings import warn - - -def colormap_lookup_table(scale_range=(0, 1), hue_range=(0.8, 0), - saturation_range=(1, 1), value_range=(0.8, 0.8)): - """ Lookup table for the colormap - - Parameters - ---------- - scale_range : tuple - It can be anything e.g. (0, 1) or (0, 255). Usually it is the mininum - and maximum value of your data. Default is (0, 1). - hue_range : tuple of floats - HSV values (min 0 and max 1). Default is (0.8, 0). - saturation_range : tuple of floats - HSV values (min 0 and max 1). Default is (1, 1). - value_range : tuple of floats - HSV value (min 0 and max 1). Default is (0.8, 0.8). - - Returns - ------- - lookup_table : vtkLookupTable - - """ - lookup_table = vtk.vtkLookupTable() - lookup_table.SetRange(scale_range) - lookup_table.SetTableRange(scale_range) - - lookup_table.SetHueRange(hue_range) - lookup_table.SetSaturationRange(saturation_range) - lookup_table.SetValueRange(value_range) - - lookup_table.Build() - return lookup_table - - -def cc(na, nd): - return (na * np.cos(nd * np.pi / 180.0)) - - -def ss(na, nd): - return na * np.sin(nd * np.pi / 180.0) - - -def boys2rgb(v): - """ boys 2 rgb cool colormap - - Maps a given field of undirected lines (line field) to rgb - colors using Boy's Surface immersion of the real projective - plane. - Boy's Surface is one of the three possible surfaces - obtained by gluing a Mobius strip to the edge of a disk. - The other two are the crosscap and Roman surface, - Steiner surfaces that are homeomorphic to the real - projective plane (Pinkall 1986). The Boy's surface - is the only 3D immersion of the projective plane without - singularities. - Visit http://www.cs.brown.edu/~cad/rp2coloring for further details. - Cagatay Demiralp, 9/7/2008. - - Code was initially in matlab and was rewritten in Python for dipy by - the Dipy Team. Thank you Cagatay for putting this online. - - Parameters - ------------ - v : array, shape (N, 3) of unit vectors (e.g., principal eigenvectors of - tensor data) representing one of the two directions of the - undirected lines in a line field. - - Returns - --------- - c : array, shape (N, 3) matrix of rgb colors corresponding to the vectors - given in V. 
- - Examples - ---------- - - >>> from dipy.viz import colormap - >>> v = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - >>> c = colormap.boys2rgb(v) - """ - - if v.ndim == 1: - x = v[0] - y = v[1] - z = v[2] - - if v.ndim == 2: - x = v[:, 0] - y = v[:, 1] - z = v[:, 2] - - x2 = x ** 2 - y2 = y ** 2 - z2 = z ** 2 - - x3 = x * x2 - y3 = y * y2 - z3 = z * z2 - - z4 = z * z2 - - xy = x * y - xz = x * z - yz = y * z - - hh1 = .5 * (3 * z2 - 1) / 1.58 - hh2 = 3 * xz / 2.745 - hh3 = 3 * yz / 2.745 - hh4 = 1.5 * (x2 - y2) / 2.745 - hh5 = 6 * xy / 5.5 - hh6 = (1 / 1.176) * .125 * (35 * z4 - 30 * z2 + 3) - hh7 = 2.5 * x * (7 * z3 - 3 * z) / 3.737 - hh8 = 2.5 * y * (7 * z3 - 3 * z) / 3.737 - hh9 = ((x2 - y2) * 7.5 * (7 * z2 - 1)) / 15.85 - hh10 = ((2 * xy) * (7.5 * (7 * z2 - 1))) / 15.85 - hh11 = 105 * (4 * x3 * z - 3 * xz * (1 - z2)) / 59.32 - hh12 = 105 * (-4 * y3 * z + 3 * yz * (1 - z2)) / 59.32 - - s0 = -23.0 - s1 = 227.9 - s2 = 251.0 - s3 = 125.0 - - ss23 = ss(2.71, s0) - cc23 = cc(2.71, s0) - ss45 = ss(2.12, s1) - cc45 = cc(2.12, s1) - ss67 = ss(.972, s2) - cc67 = cc(.972, s2) - ss89 = ss(.868, s3) - cc89 = cc(.868, s3) - - X = 0.0 - - X = X + hh2 * cc23 - X = X + hh3 * ss23 - - X = X + hh5 * cc45 - X = X + hh4 * ss45 - - X = X + hh7 * cc67 - X = X + hh8 * ss67 - - X = X + hh10 * cc89 - X = X + hh9 * ss89 - - Y = 0.0 - - Y = Y + hh2 * -ss23 - Y = Y + hh3 * cc23 - - Y = Y + hh5 * -ss45 - Y = Y + hh4 * cc45 - - Y = Y + hh7 * -ss67 - Y = Y + hh8 * cc67 - - Y = Y + hh10 * -ss89 - Y = Y + hh9 * cc89 - - Z = 0.0 - - Z = Z + hh1 * -2.8 - Z = Z + hh6 * -0.5 - Z = Z + hh11 * 0.3 - Z = Z + hh12 * -2.5 - - # scale and normalize to fit - # in the rgb space - - w_x = 4.1925 - trl_x = -2.0425 - w_y = 4.0217 - trl_y = -1.8541 - w_z = 4.0694 - trl_z = -2.1899 - - if v.ndim == 2: - - N = len(x) - C = np.zeros((N, 3)) - - C[:, 0] = 0.9 * np.abs(((X - trl_x) / w_x)) + 0.05 - C[:, 1] = 0.9 * np.abs(((Y - trl_y) / w_y)) + 0.05 - C[:, 2] = 0.9 * np.abs(((Z - trl_z) / w_z)) + 0.05 - - if v.ndim == 1: - - C = np.zeros((3,)) - C[0] = 0.9 * np.abs(((X - trl_x) / w_x)) + 0.05 - C[1] = 0.9 * np.abs(((Y - trl_y) / w_y)) + 0.05 - C[2] = 0.9 * np.abs(((Z - trl_z) / w_z)) + 0.05 - - return C - - -def orient2rgb(v): - """ standard orientation 2 rgb colormap - - v : array, shape (N, 3) of vectors not necessarily normalized - - Returns - ------- - - c : array, shape (N, 3) matrix of rgb colors corresponding to the vectors - given in V. 
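# A sketch combining ``colormap_lookup_table`` with the ``scalar_bar`` actor,
# assuming dipy with VTK; the FA-like (0, 1) range and title are illustrative.
from dipy.viz import actor, window
from dipy.viz.colormap import colormap_lookup_table

lut = colormap_lookup_table(scale_range=(0, 1), hue_range=(0.4, 1),
                            saturation_range=(1, 1))
bar = actor.scalar_bar(lookup_table=lut, title='FA')
ren = window.Renderer()
ren.add(bar)
# window.show(ren)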
- 
- Examples
- --------
- 
- >>> from dipy.viz import colormap
- >>> v = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
- >>> c = colormap.orient2rgb(v)
- 
- """
- 
- if v.ndim == 1:
- orient = v
- orient = np.abs(orient / np.linalg.norm(orient))
- 
- if v.ndim == 2:
- orientn = np.sqrt(v[:, 0] ** 2 + v[:, 1] ** 2 + v[:, 2] ** 2)
- orientn.shape = orientn.shape + (1,)
- orient = np.abs(v / orientn)
- 
- return orient
- 
- 
-def line_colors(streamlines, cmap='rgb_standard'):
- """ Create colors for streamlines to be used in actor.line
- 
- Parameters
- ----------
- streamlines : sequence of ndarrays
- cmap : ('rgb_standard', 'boys_standard')
- 
- Returns
- -------
- colors : ndarray
- """
- 
- if cmap == 'rgb_standard':
- col_list = [orient2rgb(streamline[-1] - streamline[0])
- for streamline in streamlines]
- 
- if cmap == 'boys_standard':
- col_list = [boys2rgb(streamline[-1] - streamline[0])
- for streamline in streamlines]
- 
- return np.vstack(col_list)
- 
- 
-lowercase_cm_name = {'blues': 'Blues', 'accent': 'Accent'}
- 
- 
-def create_colormap(v, name='plasma', auto=True):
- """Create colors from a specific colormap and return them
- as an array of shape (N,3) where every row gives the corresponding
- r,g,b value. The colormaps we use are similar to those of matplotlib.
- 
- Parameters
- ----------
- v : (N,) array
- vector of values to be mapped in RGB colors according to colormap
- name : str
- Name of the colormap. Currently implemented: 'jet', 'blues',
- 'accent', 'bone' and matplotlib colormaps if you have matplotlib
- installed. For example, we suggest using 'plasma', 'viridis' or
- 'inferno'. 'jet' is popular but can often be misleading and we will
- deprecate it in the future.
- auto : bool,
- if auto is True then v is interpolated to [0, 1] from v.min()
- to v.max()
- 
- Notes
- -----
- Dipy supports a few colormaps for those who do not use Matplotlib; for
- more colormaps, consider downloading Matplotlib (see matplotlib.org).
- """
- 
- if name == 'jet':
- msg = 'Jet is a popular colormap but can often be misleading. '
- msg += 'Use instead plasma, viridis, hot or inferno.'
- warn(msg, DeprecationWarning)
- 
- if v.ndim > 1:
- msg = 'This function works only with 1d arrays. Use ravel()'
- raise ValueError(msg)
- 
- if auto:
- v = np.interp(v, [v.min(), v.max()], [0, 1])
- else:
- v = np.clip(v, 0, 1)
- 
- # For backwards compatibility with lowercase names
- newname = lowercase_cm_name.get(name) or name
- 
- colormap = get_cmap(newname)
- if colormap is None:
- e_s = "Colormap {} is not yet implemented".format(name)
- raise ValueError(e_s)
- 
- rgba = colormap(v)
- rgb = rgba[:, :3].copy()
- return rgb
diff --git a/dipy/viz/fvtk.py b/dipy/viz/fvtk.py
deleted file mode 100644
index 898f2adba3..0000000000
--- a/dipy/viz/fvtk.py
+++ /dev/null
@@ -1,933 +0,0 @@
-""" Fvtk module implements simple visualization functions using VTK.
-
-The main idea is the following:
-A window can have one or more renderers. A renderer can have none,
-one or more actors. Examples of actors are a sphere, line, point etc.
-You basically add actors to a renderer and in that way you can
-visualize the aforementioned objects e.g. sphere, line ...
-
-Examples
----------
->>> from dipy.viz import fvtk
->>> r=fvtk.ren()
->>> a=fvtk.axes()
->>> fvtk.add(r,a)
->>> #fvtk.show(r)
-
-For more information on VTK there are many neat examples in
-http://www.vtk.org/Wiki/VTK/Tutorials/External_Tutorials
-"""
-from __future__ import division, print_function, absolute_import
-from warnings import warn
-
-from dipy.utils.six.moves import xrange
-
-import numpy as np
-
-from dipy.core.ndindex import ndindex
-
-# Conditional import machinery for vtk
-from dipy.utils.optpkg import optional_package
-
-# Allow import, but disable doctests if we don't have vtk
-vtk, have_vtk, setup_module = optional_package('vtk')
-colors, have_vtk_colors, _ = optional_package('vtk.util.colors')
-
-cm, have_matplotlib, _ = optional_package('matplotlib.cm')
-
-if have_matplotlib:
- get_cmap = cm.get_cmap
-else:
- from dipy.data import get_cmap
-
-from dipy.viz.colormap import create_colormap
-
-# a track buffer used only with picking tracks
-track_buffer = []
-# indices buffer for the tracks
-ind_buffer = []
-# temporary renderer used only with picking tracks
-tmp_ren = None
-
-if have_vtk:
-
- major_version = vtk.vtkVersion.GetVTKMajorVersion()
-
- # Create a text mapper and actor to display the results of picking.
- textMapper = vtk.vtkTextMapper()
- tprop = textMapper.GetTextProperty()
- tprop.SetFontFamilyToArial()
- tprop.SetFontSize(10)
- # tprop.BoldOn()
- # tprop.ShadowOn()
- tprop.SetColor(1, 0, 0)
- textActor = vtk.vtkActor2D()
- textActor.VisibilityOff()
- textActor.SetMapper(textMapper)
- # Create a cell picker.
- picker = vtk.vtkCellPicker()
-
- from dipy.viz.window import (ren, renderer, add, clear, rm, rm_all,
- show, record, snapshot)
- from dipy.viz.actor import line, streamtube, slicer, axes, dots, point
-
- try:
- if major_version < 7:
- from vtk import vtkVolumeTextureMapper2D as VolumeMapper
- else:
- from vtk import vtkSmartVolumeMapper as VolumeMapper
- have_vtk_texture_mapper2D = True
- except Exception:
- have_vtk_texture_mapper2D = False
-
-else:
- ren, have_ren, _ = optional_package('dipy.viz.window.ren',
- 'Python VTK is not installed')
-
-
-
-deprecation_msg = ("Module 'dipy.viz.fvtk' is deprecated as of version"
- " 0.14 of dipy and will be removed in a future version."
- " Please use module 'dipy.viz.window' or"
- " 'dipy.viz.actor' instead.")
-warn(DeprecationWarning(deprecation_msg))
-
-
-def volume(vol, voxsz=(1.0, 1.0, 1.0), affine=None, center_origin=1,
- info=0, maptype=0, trilinear=1, iso=0, iso_thr=100,
- opacitymap=None, colormap=None):
- ''' Create a volume and return a volumetric actor using volumetric
- rendering.
-
- This function has many different interesting capabilities. The maptype,
- opacitymap and colormap are the most crucial parameters here.
-
- Parameters
- ----------
- vol : array, shape (N, M, K), dtype uint8
- An array representing the volumetric dataset that we want to visualize
- using volumetric rendering.
- voxsz : (3,) array_like
- Voxel size.
- affine : (4, 4) ndarray
- As given by volumeimages.
- center_origin : int {0,1}
- If 1, the center of the volume is considered to be the
- point ``(-vol.shape[0]/2.0+0.5,-vol.shape[1]/2.0+0.5,
- -vol.shape[2]/2.0+0.5)``.
- info : int {0,1}
- If 1 it prints out some info about the volume, the method and the
- dataset.
- trilinear : int {0,1}
- Use trilinear interpolation (default 1), which gives smoother
- rendering. If you want faster interpolation use 0 (nearest neighbor).
- maptype : int {0,1}
- The maptype is a very important parameter which affects the
- raycasting algorithm in use for the rendering.
- The options are:
- If 0 then vtkVolumeTextureMapper2D is used.
- If 1 then vtkVolumeRayCastFunction is used.
- iso : int {0,1}
- If iso is 1 and maptype is 1 then we use
- ``vtkVolumeRayCastIsosurfaceFunction`` which generates an isosurface at
- the predefined iso_thr value. If iso is 0 and maptype is 1
- ``vtkVolumeRayCastCompositeFunction`` is used.
- iso_thr : int
- If iso is 1 then this threshold in the volume defines the value
- which will be used to create the isosurface.
- opacitymap : (2, 2) ndarray
- The opacity map assigns a transparency coefficient to every point in
- the volume. The default value uses the histogram of the volume to
- calculate the opacitymap.
- colormap : (4, 4) ndarray
- The color map assigns a color value to every point in the volume.
- When None, a red-blue colormap computed from the volume's histogram
- is used.
-
- Returns
- -------
- v : vtkVolume
- Volume.
-
- Notes
- -----
- What is the difference between TextureMapper2D and RayCastFunction? Coming
- soon... See VTK user's guide [book] & The Visualization Toolkit [book] and
- VTK's online documentation & online docs.
-
- What is the difference between RayCastIsosurfaceFunction and
- RayCastCompositeFunction? Coming soon... See VTK user's guide [book] &
- The Visualization Toolkit [book] and VTK's online documentation &
- online docs.
-
- What about trilinear interpolation?
- Coming soon... well when time permits really ... :-)
-
- Examples
- --------
- First example, random points.
-
- >>> from dipy.viz import fvtk
- >>> import numpy as np
- >>> vol=100*np.random.rand(100,100,100)
- >>> vol=vol.astype('uint8')
- >>> vol.min(), vol.max()
- (0, 99)
- >>> r = fvtk.ren()
- >>> v = fvtk.volume(vol)
- >>> fvtk.add(r,v)
- >>> #fvtk.show(r)
-
- Second example, with a more complicated function.
-
- >>> from dipy.viz import fvtk
- >>> import numpy as np
- >>> x, y, z = np.ogrid[-10:10:20j, -10:10:20j, -10:10:20j]
- >>> s = np.sin(x*y*z)/(x*y*z)
- >>> r = fvtk.ren()
- >>> v = fvtk.volume(s)
- >>> fvtk.add(r,v)
- >>> #fvtk.show(r)
-
- If you find this function too complicated, you can always use mayavi.
- Please do not forget to use the -wthread switch in ipython if you are
- running mayavi.
- - from enthought.mayavi import mlab - import numpy as np - x, y, z = np.ogrid[-10:10:20j, -10:10:20j, -10:10:20j] - s = np.sin(x*y*z)/(x*y*z) - mlab.pipeline.volume(mlab.pipeline.scalar_field(s)) - mlab.show() - - More mayavi demos are available here: - - http://code.enthought.com/projects/mayavi/docs/development/html/mayavi/mlab.html - - ''' - if vol.ndim != 3: - raise ValueError('3d numpy arrays only please') - - if info: - print('Datatype', vol.dtype, 'converted to uint8') - - vol = np.interp(vol, [vol.min(), vol.max()], [0, 255]) - vol = vol.astype('uint8') - - if opacitymap is None: - - bin, res = np.histogram(vol.ravel()) - res2 = np.interp(res, [vol.min(), vol.max()], [0, 1]) - opacitymap = np.vstack((res, res2)).T - opacitymap = opacitymap.astype('float32') - - ''' - opacitymap=np.array([[ 0.0, 0.0], - [50.0, 0.9]]) - ''' - - if info: - print('opacitymap', opacitymap) - - if colormap is None: - - bin, res = np.histogram(vol.ravel()) - res2 = np.interp(res, [vol.min(), vol.max()], [0, 1]) - zer = np.zeros(res2.shape) - colormap = np.vstack((res, res2, zer, res2[::-1])).T - colormap = colormap.astype('float32') - - ''' - colormap=np.array([[0.0, 0.5, 0.0, 0.0], - [64.0, 1.0, 0.5, 0.5], - [128.0, 0.9, 0.2, 0.3], - [196.0, 0.81, 0.27, 0.1], - [255.0, 0.5, 0.5, 0.5]]) - ''' - - if info: - print('colormap', colormap) - - im = vtk.vtkImageData() - - if major_version <= 5: - im.SetScalarTypeToUnsignedChar() - im.SetDimensions(vol.shape[0], vol.shape[1], vol.shape[2]) - # im.SetOrigin(0,0,0) - # im.SetSpacing(voxsz[2],voxsz[0],voxsz[1]) - if major_version <= 5: - im.AllocateScalars() - else: - im.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 3) - - for i in range(vol.shape[0]): - for j in range(vol.shape[1]): - for k in range(vol.shape[2]): - - im.SetScalarComponentFromFloat(i, j, k, 0, vol[i, j, k]) - - if affine is not None: - - aff = vtk.vtkMatrix4x4() - aff.DeepCopy((affine[0, 0], affine[0, 1], affine[0, 2], - affine[0, 3], affine[1, 0], affine[1, 1], - affine[1, 2], affine[1, 3], affine[2, 0], - affine[2, 1], affine[2, 2], affine[2, 3], - affine[3, 0], affine[3, 1], affine[3, 2], - affine[3, 3])) - # aff.DeepCopy((affine[0,0],affine[0,1],affine[0,2],0,affine[1,0],affine[1,1],affine[1,2],0,affine[2,0],affine[2,1],affine[2,2],0,affine[3,0],affine[3,1],affine[3,2],1)) - # aff.DeepCopy((affine[0,0],affine[0,1],affine[0,2],127.5,affine[1,0],affine[1,1],affine[1,2],-127.5,affine[2,0],affine[2,1],affine[2,2],-127.5,affine[3,0],affine[3,1],affine[3,2],1)) - - reslice = vtk.vtkImageReslice() - if major_version <= 5: - reslice.SetInput(im) - else: - reslice.SetInputData(im) - # reslice.SetOutputDimensionality(2) - # reslice.SetOutputOrigin(127,-145,147) - - reslice.SetResliceAxes(aff) - # reslice.SetOutputOrigin(-127,-127,-127) - # reslice.SetOutputExtent(-127,128,-127,128,-127,128) - # reslice.SetResliceAxesOrigin(0,0,0) - # print 'Get Reslice Axes Origin ', reslice.GetResliceAxesOrigin() - # reslice.SetOutputSpacing(1.0,1.0,1.0) - - reslice.SetInterpolationModeToLinear() - # reslice.UpdateWholeExtent() - - # print 'reslice GetOutputOrigin', reslice.GetOutputOrigin() - # print 'reslice GetOutputExtent',reslice.GetOutputExtent() - # print 'reslice GetOutputSpacing',reslice.GetOutputSpacing() - - changeFilter = vtk.vtkImageChangeInformation() - if major_version <= 5: - changeFilter.SetInput(reslice.GetOutput()) - else: - changeFilter.SetInputData(reslice.GetOutput()) - # changeFilter.SetInput(im) - if center_origin: - changeFilter.SetOutputOrigin( - -vol.shape[0] / 2.0 + 0.5, - -vol.shape[1] / 2.0 + 
0.5,
- -vol.shape[2] / 2.0 + 0.5)
- print('ChangeFilter ', changeFilter.GetOutputOrigin())
- 
- opacity = vtk.vtkPiecewiseFunction()
- for i in range(opacitymap.shape[0]):
- opacity.AddPoint(opacitymap[i, 0], opacitymap[i, 1])
- 
- color = vtk.vtkColorTransferFunction()
- for i in range(colormap.shape[0]):
- color.AddRGBPoint(
- colormap[i, 0], colormap[i, 1], colormap[i, 2], colormap[i, 3])
- 
- if(maptype == 0):
- if not have_vtk_texture_mapper2D:
- raise ValueError("VolumeTextureMapper2D is not available in your "
- "version of VTK")
- 
- property = vtk.vtkVolumeProperty()
- property.SetColor(color)
- property.SetScalarOpacity(opacity)
- 
- if trilinear:
- property.SetInterpolationTypeToLinear()
- else:
- property.SetInterpolationTypeToNearest()
- 
- if info:
- print('mapper VolumeTextureMapper2D')
- mapper = VolumeMapper() # vtk.vtkVolumeTextureMapper2D()
- if affine is None:
- if major_version <= 5:
- mapper.SetInput(im)
- else:
- mapper.SetInputData(im)
- else:
- if major_version <= 5:
- mapper.SetInput(changeFilter.GetOutput())
- else:
- mapper.SetInputData(changeFilter.GetOutput())
- 
- if (maptype == 1):
- 
- property = vtk.vtkVolumeProperty()
- property.SetColor(color)
- property.SetScalarOpacity(opacity)
- property.ShadeOn()
- if trilinear:
- property.SetInterpolationTypeToLinear()
- else:
- property.SetInterpolationTypeToNearest()
- 
- if iso:
- isofunc = vtk.vtkVolumeRayCastIsosurfaceFunction()
- isofunc.SetIsoValue(iso_thr)
- else:
- compositeFunction = vtk.vtkVolumeRayCastCompositeFunction()
- 
- if info:
- print('mapper VolumeRayCastMapper')
- 
- mapper = vtk.vtkVolumeRayCastMapper()
- if iso:
- mapper.SetVolumeRayCastFunction(isofunc)
- if info:
- print('Isosurface')
- else:
- mapper.SetVolumeRayCastFunction(compositeFunction)
- 
- # mapper.SetMinimumImageSampleDistance(0.2)
- if info:
- print('Composite')
- 
- if affine is None:
- if major_version <= 5:
- mapper.SetInput(im)
- else:
- mapper.SetInputData(im)
- else:
- # mapper.SetInput(reslice.GetOutput())
- if major_version <= 5:
- mapper.SetInput(changeFilter.GetOutput())
- else:
- mapper.SetInputData(changeFilter.GetOutput())
- # Return mid position in world space
- # im2=reslice.GetOutput()
- # index=im2.FindPoint(vol.shape[0]/2.0,vol.shape[1]/2.0,vol.shape[2]/2.0)
- # print 'Image Getpoint ' , im2.GetPoint(index)
- 
- volum = vtk.vtkVolume()
- volum.SetMapper(mapper)
- volum.SetProperty(property)
- 
- if info:
- 
- print('Origin', volum.GetOrigin())
- print('Orientation', volum.GetOrientation())
- print('OrientationW', volum.GetOrientationWXYZ())
- print('Position', volum.GetPosition())
- print('Center', volum.GetCenter())
- print('Get XRange', volum.GetXRange())
- print('Get YRange', volum.GetYRange())
- print('Get ZRange', volum.GetZRange())
- print('Volume data type', vol.dtype)
- 
- return volum
- 
- 
-def contour(vol, voxsz=(1.0, 1.0, 1.0), affine=None, levels=[50],
- colors=[np.array([1.0, 0.0, 0.0])], opacities=[0.5]):
- """ Take a volume and draw surface contours for any number of
- thresholds (levels) where every contour has its own color and opacity
- 
- Parameters
- ----------
- vol : (N, M, K) ndarray
- An array representing the volumetric dataset for which we will draw
- some beautiful contours.
- voxsz : (3,) array_like
- Voxel size.
- affine : None
- Not used.
- levels : array_like
- Sequence of thresholds for the contours taken from image values; needs
- to be the same datatype as `vol`.
- colors : (N, 3) ndarray
- RGB values in [0,1].
- opacities : array_like
- Opacities of contours.
- - Returns - ------- - vtkAssembly - - Examples - -------- - >>> import numpy as np - >>> from dipy.viz import fvtk - >>> A=np.zeros((10,10,10)) - >>> A[3:-3,3:-3,3:-3]=1 - >>> r=fvtk.ren() - >>> fvtk.add(r,fvtk.contour(A,levels=[1])) - >>> #fvtk.show(r) - - """ - - im = vtk.vtkImageData() - if major_version <= 5: - im.SetScalarTypeToUnsignedChar() - - im.SetDimensions(vol.shape[0], vol.shape[1], vol.shape[2]) - # im.SetOrigin(0,0,0) - # im.SetSpacing(voxsz[2],voxsz[0],voxsz[1]) - if major_version <= 5: - im.AllocateScalars() - else: - im.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 3) - - for i in range(vol.shape[0]): - for j in range(vol.shape[1]): - for k in range(vol.shape[2]): - - im.SetScalarComponentFromFloat(i, j, k, 0, vol[i, j, k]) - - ass = vtk.vtkAssembly() - # ass=[] - - for (i, l) in enumerate(levels): - - # print levels - skinExtractor = vtk.vtkContourFilter() - if major_version <= 5: - skinExtractor.SetInput(im) - else: - skinExtractor.SetInputData(im) - skinExtractor.SetValue(0, l) - - skinNormals = vtk.vtkPolyDataNormals() - skinNormals.SetInputConnection(skinExtractor.GetOutputPort()) - skinNormals.SetFeatureAngle(60.0) - - skinMapper = vtk.vtkPolyDataMapper() - skinMapper.SetInputConnection(skinNormals.GetOutputPort()) - skinMapper.ScalarVisibilityOff() - - skin = vtk.vtkActor() - - skin.SetMapper(skinMapper) - skin.GetProperty().SetOpacity(opacities[i]) - - # print colors[i] - skin.GetProperty().SetColor(colors[i][0], colors[i][1], colors[i][2]) - # skin.Update() - ass.AddPart(skin) - - del skin - del skinMapper - del skinExtractor - - return ass - - -def _makeNd(array, ndim): - """Pads as many 1s at the beginning of array's shape as are need to give - array ndim dimensions.""" - new_shape = (1,) * (ndim - array.ndim) + array.shape - return array.reshape(new_shape) - - -def sphere_funcs(sphere_values, sphere, image=None, colormap='jet', - scale=2.2, norm=True, radial_scale=True): - """Plot many morphed spherical functions simultaneously. - - Parameters - ---------- - sphere_values : (M,) or (X, M) or (X, Y, M) or (X, Y, Z, M) ndarray - Values on the sphere. - sphere : Sphere - image : None, - Not yet supported. - colormap : None or 'jet' - If None then no color is used. - scale : float, - Distance between spheres. - norm : bool, - Normalize `sphere_values`. - radial_scale : bool, - Scale sphere points according to odf values. - - Returns - ------- - actor : vtkActor - Spheres. - - Examples - -------- - >>> from dipy.viz import fvtk - >>> r = fvtk.ren() - >>> odfs = np.ones((5, 5, 724)) - >>> odfs[..., 0] = 2. - >>> from dipy.data import get_sphere - >>> sphere = get_sphere('symmetric724') - >>> fvtk.add(r, fvtk.sphere_funcs(odfs, sphere)) - >>> #fvtk.show(r) - - """ - - sphere_values = np.asarray(sphere_values) - if sphere_values.ndim > 4: - raise ValueError("Wrong shape") - sphere_values = _makeNd(sphere_values, 4) - - grid_shape = np.array(sphere_values.shape[:3]) - faces = np.asarray(sphere.faces, dtype=int) - vertices = sphere.vertices - - if sphere_values.shape[-1] != sphere.vertices.shape[0]: - msg = 'Sphere.vertices.shape[0] should be the same as the ' - msg += 'last dimensions of sphere_values i.e. 
sphere_values.shape[-1]' - raise ValueError(msg) - - list_sq = [] - list_cols = [] - - for ijk in np.ndindex(*grid_shape): - m = sphere_values[ijk].copy() - - if norm: - m /= abs(m).max() - - if radial_scale: - xyz = vertices.T * m - else: - xyz = vertices.T.copy() - - xyz += scale * (ijk - grid_shape / 2.)[:, None] - - xyz = xyz.T - - list_sq.append(xyz) - if colormap is not None: - cols = create_colormap(m, colormap) - cols = np.interp(cols, [0, 1], [0, 255]).astype('ubyte') - list_cols.append(cols) - - points = vtk.vtkPoints() - triangles = vtk.vtkCellArray() - if colormap is not None: - colors = vtk.vtkUnsignedCharArray() - colors.SetNumberOfComponents(3) - colors.SetName("Colors") - - for k in xrange(len(list_sq)): - - xyz = list_sq[k] - if colormap is not None: - cols = list_cols[k] - - for i in xrange(xyz.shape[0]): - - points.InsertNextPoint(*xyz[i]) - if colormap is not None: - colors.InsertNextTuple3(*cols[i]) - - for j in xrange(faces.shape[0]): - - triangle = vtk.vtkTriangle() - triangle.GetPointIds().SetId(0, faces[j, 0] + k * xyz.shape[0]) - triangle.GetPointIds().SetId(1, faces[j, 1] + k * xyz.shape[0]) - triangle.GetPointIds().SetId(2, faces[j, 2] + k * xyz.shape[0]) - triangles.InsertNextCell(triangle) - del triangle - - polydata = vtk.vtkPolyData() - polydata.SetPoints(points) - polydata.SetPolys(triangles) - - if colormap is not None: - polydata.GetPointData().SetScalars(colors) - polydata.Modified() - - mapper = vtk.vtkPolyDataMapper() - if major_version <= 5: - mapper.SetInput(polydata) - else: - mapper.SetInputData(polydata) - - actor = vtk.vtkActor() - actor.SetMapper(mapper) - - return actor - - -def peaks(peaks_dirs, peaks_values=None, scale=2.2, colors=(1, 0, 0)): - """ Visualize peak directions as given from ``peaks_from_model`` - - Parameters - ---------- - peaks_dirs : ndarray - Peak directions. The shape of the array can be (M, 3) or (X, M, 3) or - (X, Y, M, 3) or (X, Y, Z, M, 3) - peaks_values : ndarray - Peak values. The shape of the array can be (M, ) or (X, M) or - (X, Y, M) or (X, Y, Z, M) - - scale : float - Distance between spheres - - colors : ndarray or tuple - Peak colors - - Returns - ------- - vtkActor - - See Also - -------- - dipy.viz.fvtk.sphere_funcs - - """ - peaks_dirs = np.asarray(peaks_dirs) - if peaks_dirs.ndim > 5: - raise ValueError("Wrong shape") - - peaks_dirs = _makeNd(peaks_dirs, 5) - if peaks_values is not None: - peaks_values = _makeNd(peaks_values, 4) - - grid_shape = np.array(peaks_dirs.shape[:3]) - - list_dirs = [] - - for ijk in np.ndindex(*grid_shape): - - xyz = scale * (ijk - grid_shape / 2.)[:, None] - - xyz = xyz.T - - for i in range(peaks_dirs.shape[-2]): - - if peaks_values is not None: - - pv = peaks_values[ijk][i] - - else: - - pv = 1. - - symm = np.vstack((-peaks_dirs[ijk][i] * pv + xyz, - peaks_dirs[ijk][i] * pv + xyz)) - - list_dirs.append(symm) - - return line(list_dirs, colors) - - -def tensor(evals, evecs, scalar_colors=None, - sphere=None, scale=2.2, norm=True): - """Plot many tensors as ellipsoids simultaneously. - - Parameters - ---------- - evals : (3,) or (X, 3) or (X, Y, 3) or (X, Y, Z, 3) ndarray - eigenvalues - evecs : (3, 3) or (X, 3, 3) or (X, Y, 3, 3) or (X, Y, Z, 3, 3) ndarray - eigenvectors - scalar_colors : (3,) or (X, 3) or (X, Y, 3) or (X, Y, Z, 3) ndarray - RGB colors used to show the tensors - Default None, color the ellipsoids using ``color_fa`` - sphere : Sphere, - this sphere will be transformed to the tensor ellipsoid - Default is None which uses a symmetric sphere with 724 points. 
- scale : float, - distance between ellipsoids. - norm : boolean, - Normalize `evals`. - - Returns - ------- - actor : vtkActor - Ellipsoids - - Examples - -------- - >>> from dipy.viz import fvtk - >>> r = fvtk.ren() - >>> evals = np.array([1.4, .35, .35]) * 10 ** (-3) - >>> evecs = np.eye(3) - >>> from dipy.data import get_sphere - >>> sphere = get_sphere('symmetric724') - >>> fvtk.add(r, fvtk.tensor(evals, evecs, sphere=sphere)) - >>> #fvtk.show(r) - - """ - - evals = np.asarray(evals) - if evals.ndim > 4: - raise ValueError("Wrong shape") - evals = _makeNd(evals, 4) - evecs = _makeNd(evecs, 5) - - grid_shape = np.array(evals.shape[:3]) - - if sphere is None: - from dipy.data import get_sphere - sphere = get_sphere('symmetric724') - faces = np.asarray(sphere.faces, dtype=int) - vertices = sphere.vertices - - colors = vtk.vtkUnsignedCharArray() - colors.SetNumberOfComponents(3) - colors.SetName("Colors") - - if scalar_colors is None: - from dipy.reconst.dti import color_fa, fractional_anisotropy - cfa = color_fa(fractional_anisotropy(evals), evecs) - else: - cfa = _makeNd(scalar_colors, 4) - - list_sq = [] - list_cols = [] - - for ijk in ndindex(grid_shape): - ea = evals[ijk] - if norm: - ea /= ea.max() - ea = np.diag(ea.copy()) - - ev = evecs[ijk].copy() - xyz = np.dot(ev, np.dot(ea, vertices.T)) - - xyz += scale * (ijk - grid_shape / 2.)[:, None] - - xyz = xyz.T - - list_sq.append(xyz) - - acolor = np.zeros(xyz.shape) - acolor[:, :] = np.interp(cfa[ijk], [0, 1], [0, 255]) - list_cols.append(acolor.astype('ubyte')) - - points = vtk.vtkPoints() - triangles = vtk.vtkCellArray() - - for k in xrange(len(list_sq)): - - xyz = list_sq[k] - - cols = list_cols[k] - - for i in xrange(xyz.shape[0]): - - points.InsertNextPoint(*xyz[i]) - colors.InsertNextTuple3(*cols[i]) - - for j in xrange(faces.shape[0]): - - triangle = vtk.vtkTriangle() - triangle.GetPointIds().SetId(0, faces[j, 0] + k * xyz.shape[0]) - triangle.GetPointIds().SetId(1, faces[j, 1] + k * xyz.shape[0]) - triangle.GetPointIds().SetId(2, faces[j, 2] + k * xyz.shape[0]) - triangles.InsertNextCell(triangle) - del triangle - - polydata = vtk.vtkPolyData() - polydata.SetPoints(points) - polydata.SetPolys(triangles) - - polydata.GetPointData().SetScalars(colors) - polydata.Modified() - - mapper = vtk.vtkPolyDataMapper() - if major_version <= 5: - mapper.SetInput(polydata) - else: - mapper.SetInputData(polydata) - - actor = vtk.vtkActor() - actor.SetMapper(mapper) - - return actor - - -def label(ren, text='Origin', pos=(0, 0, 0), scale=(0.2, 0.2, 0.2), - color=(1, 1, 1)): - """ Create a label actor. - This actor will always face the camera - Parameters - ---------- - ren : vtkRenderer() object - Renderer as returned by ``ren()``. - text : str - Text for the label. - pos : (3,) array_like, optional - Left down position of the label. - scale : (3,) array_like - Changes the size of the label. - color : (3,) array_like - Label color as ``(r,g,b)`` tuple. - Returns - ------- - l : vtkActor object - Label. 
-    Examples
-    --------
-    >>> from dipy.viz import fvtk
-    >>> r=fvtk.ren()
-    >>> l=fvtk.label(r)
-    >>> fvtk.add(r,l)
-    >>> #fvtk.show(r)
-    """
-    atext = vtk.vtkVectorText()
-    atext.SetText(text)
-
-    textm = vtk.vtkPolyDataMapper()
-    # Connect the mapper through the pipeline (valid on both VTK 5 and 6+)
-    # so the text polydata is generated before rendering.
-    textm.SetInputConnection(atext.GetOutputPort())
-
-    texta = vtk.vtkFollower()
-    texta.SetMapper(textm)
-    texta.SetScale(scale)
-
-    texta.GetProperty().SetColor(color)
-    texta.SetPosition(pos)
-
-    ren.AddActor(texta)
-    texta.SetCamera(ren.GetActiveCamera())
-
-    return texta
-
-
-def camera(ren, pos=None, focal=None, viewup=None, verbose=True):
-    """ Change the active camera
-
-    Parameters
-    ----------
-    ren : vtkRenderer
-    pos : tuple
-        (x, y, z) position of the camera
-    focal : tuple
-        (x, y, z) focal point
-    viewup : tuple
-        (x, y, z) viewup vector
-    verbose : bool
-        show information about the camera
-
-    Returns
-    -------
-    vtkCamera
-    """
-
-    msg = "This function is deprecated. "
-    msg += "Please use the window.Renderer class to get/set the active camera."
-    warn(DeprecationWarning(msg))
-
-    cam = ren.GetActiveCamera()
-    if verbose:
-        print('Camera Position (%.2f,%.2f,%.2f)' % cam.GetPosition())
-        print('Camera Focal Point (%.2f,%.2f,%.2f)' % cam.GetFocalPoint())
-        print('Camera View Up (%.2f,%.2f,%.2f)' % cam.GetViewUp())
-    if pos is not None:
-        ren.GetActiveCamera().SetPosition(*pos)
-    if focal is not None:
-        ren.GetActiveCamera().SetFocalPoint(*focal)
-    if viewup is not None:
-        ren.GetActiveCamera().SetViewUp(*viewup)
-
-    cam = ren.GetActiveCamera()
-    if pos is not None or focal is not None or viewup is not None:
-        if verbose:
-            print('-------------------------------------')
-            print('Camera New Position (%.2f,%.2f,%.2f)' % cam.GetPosition())
-            print('Camera New Focal Point (%.2f,%.2f,%.2f)' %
-                  cam.GetFocalPoint())
-            print('Camera New View Up (%.2f,%.2f,%.2f)' % cam.GetViewUp())
-
-    return cam
-
-
-if __name__ == "__main__":
-    pass
diff --git a/dipy/viz/interactor.py b/dipy/viz/interactor.py
deleted file mode 100644
index c9cd9d294c..0000000000
--- a/dipy/viz/interactor.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import numpy as np
-
-# Conditional import machinery for vtk
-from dipy.utils.optpkg import optional_package
-
-# Allow import, but disable doctests if we don't have vtk
-vtk, have_vtk, setup_module = optional_package('vtk')
-
-if have_vtk:
-    vtkInteractorStyleUser = vtk.vtkInteractorStyleUser
-    # version = vtk.vtkVersion.GetVTKSourceVersion().split(' ')[-1]
-    # major_version = vtk.vtkVersion.GetVTKMajorVersion()
-else:
-    vtkInteractorStyleUser = object
-
-
-class Event(object):
-    def __init__(self):
-        self.position = None
-        self.name = None
-        self.key = None
-        self._abort_flag = None
-
-    @property
-    def abort_flag(self):
-        return self._abort_flag
-
-    def update(self, event_name, interactor):
-        """ Updates current event information. """
-        self.name = event_name
-        self.position = np.asarray(interactor.GetEventPosition())
-        self.key = interactor.GetKeySym()
-        self._abort_flag = False  # Reset abort flag
-
-    def abort(self):
-        """ Aborts the event, i.e. does not propagate it any further. """
-        self._abort_flag = True
-
-    def reset(self):
-        """ Done with the current event. Reset the attributes. """
-        self.position = None
-        self.name = None
-        self.key = None
-        self._abort_flag = False
-
-
-class CustomInteractorStyle(vtkInteractorStyleUser):
-    """ Manipulate the camera and interact with objects in the scene.
-
-    This interactor style allows the user to interactively manipulate (pan,
-    rotate and zoom) the camera. It also allows the user to interact (click,
-    scroll, etc.) with objects in the scene.
-
-    Several event-handling methods from :class:`vtkInteractorStyleUser` have
-    been overloaded to allow the propagation of the events to the objects the
-    user is interacting with.
-
-    In summary, while interacting with the scene, the mouse events are as
-    follows:
-    - Left mouse button: rotates the camera
-    - Right mouse button: dollies the camera
-    - Mouse wheel: dollies the camera
-    - Middle mouse button: pans the camera
-    """
-    def __init__(self):
-        # Default interactor is responsible for moving the camera.
-        self.default_interactor = vtk.vtkInteractorStyleTrackballCamera()
-        # The picker allows us to know which object/actor is under the mouse.
-        self.picker = vtk.vtkPropPicker()
-        self.chosen_element = None
-        self.event = Event()
-
-        # Define some interaction states
-        self.left_button_down = False
-        self.right_button_down = False
-        self.middle_button_down = False
-        self.active_props = set()
-
-        self.selected_props = {"left_button": set(),
-                               "right_button": set(),
-                               "middle_button": set()}
-
-    def add_active_prop(self, prop):
-        self.active_props.add(prop)
-
-    def remove_active_prop(self, prop):
-        self.active_props.remove(prop)
-
-    def get_prop_at_event_position(self):
-        """ Returns the prop that lies at the event position. """
-        # TODO: return a list of items (i.e. each level of the assembly path).
-        event_pos = self.GetInteractor().GetEventPosition()
-        self.picker.Pick(event_pos[0], event_pos[1], 0,
-                         self.GetCurrentRenderer())
-
-        path = self.picker.GetPath()
-        if path is None:
-            return None
-
-        node = path.GetLastNode()
-        prop = node.GetViewProp()
-        return prop
-
-    def propagate_event(self, evt, *props):
-        for prop in props:
-            # Propagate event to the prop.
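-            # `InvokeEvent` runs, in priority order, every callback that was
-            # registered on this prop with `add_callback`; a callback may
-            # call `self.event.abort()` to keep the event from reaching the
-            # remaining props.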
-            prop.InvokeEvent(evt)
-
-            if self.event.abort_flag:
-                return
-
-    def on_left_button_down(self, obj, evt):
-        self.left_button_down = True
-        prop = self.get_prop_at_event_position()
-        if prop is not None:
-            self.selected_props["left_button"].add(prop)
-            self.propagate_event(evt, prop)
-
-        if not self.event.abort_flag:
-            self.default_interactor.OnLeftButtonDown()
-
-    def on_left_button_up(self, obj, evt):
-        self.left_button_down = False
-        self.propagate_event(evt, *self.selected_props["left_button"])
-        self.selected_props["left_button"].clear()
-        self.default_interactor.OnLeftButtonUp()
-
-    def on_right_button_down(self, obj, evt):
-        self.right_button_down = True
-        prop = self.get_prop_at_event_position()
-        if prop is not None:
-            self.selected_props["right_button"].add(prop)
-            self.propagate_event(evt, prop)
-
-        if not self.event.abort_flag:
-            self.default_interactor.OnRightButtonDown()
-
-    def on_right_button_up(self, obj, evt):
-        self.right_button_down = False
-        self.propagate_event(evt, *self.selected_props["right_button"])
-        self.selected_props["right_button"].clear()
-        self.default_interactor.OnRightButtonUp()
-
-    def on_middle_button_down(self, obj, evt):
-        self.middle_button_down = True
-        prop = self.get_prop_at_event_position()
-        if prop is not None:
-            self.selected_props["middle_button"].add(prop)
-            self.propagate_event(evt, prop)
-
-        if not self.event.abort_flag:
-            self.default_interactor.OnMiddleButtonDown()
-
-    def on_middle_button_up(self, obj, evt):
-        self.middle_button_down = False
-        self.propagate_event(evt, *self.selected_props["middle_button"])
-        self.selected_props["middle_button"].clear()
-        self.default_interactor.OnMiddleButtonUp()
-
-    def on_mouse_move(self, obj, evt):
-        # Only propagate events to active or selected props.
-        self.propagate_event(evt, *(self.active_props |
-                                    self.selected_props["left_button"] |
-                                    self.selected_props["right_button"] |
-                                    self.selected_props["middle_button"]))
-        self.default_interactor.OnMouseMove()
-
-    def on_mouse_wheel_forward(self, obj, evt):
-        # First, propagate the mouse wheel event to the prop underneath.
-        prop = self.get_prop_at_event_position()
-        if prop is not None:
-            self.propagate_event(evt, prop)
-
-        # Then, to the active props.
-        if not self.event.abort_flag:
-            self.propagate_event(evt, *self.active_props)
-
-        # Finally, to the default interactor.
-        if not self.event.abort_flag:
-            self.default_interactor.OnMouseWheelForward()
-
-        self.event.reset()
-
-    def on_mouse_wheel_backward(self, obj, evt):
-        # First, propagate the mouse wheel event to the prop underneath.
-        prop = self.get_prop_at_event_position()
-        if prop is not None:
-            self.propagate_event(evt, prop)
-
-        # Then, to the active props.
-        if not self.event.abort_flag:
-            self.propagate_event(evt, *self.active_props)
-
-        # Finally, to the default interactor.
-        if not self.event.abort_flag:
-            self.default_interactor.OnMouseWheelBackward()
-
-        self.event.reset()
-
-    def on_char(self, obj, evt):
-        self.propagate_event(evt, *self.active_props)
-
-    def on_key_press(self, obj, evt):
-        self.propagate_event(evt, *self.active_props)
-
-    def on_key_release(self, obj, evt):
-        self.propagate_event(evt, *self.active_props)
-
-    def SetInteractor(self, interactor):
-        # Internally, `InteractorStyle` objects need a handle to a
-        # `vtkWindowInteractor` object and this is done via `SetInteractor`.
-        # However, this has the side effect of directly adding all their
-        # observers to the `interactor`!
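-        # Let the trackball-camera style hook itself to the interactor first;
-        # the observers it installs as a side effect are removed right below.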
-        self.default_interactor.SetInteractor(interactor)
-
-        # Remove all observers *most likely* (cannot guarantee that the
-        # interactor didn't already have these observers) added by
-        # `vtkInteractorStyleTrackballCamera`, i.e. our `default_interactor`.
-        #
-        # Note: Be sure that no observer has been manually added to the
-        # `interactor` before setting the InteractorStyle.
-        interactor.RemoveObservers("TimerEvent")
-        interactor.RemoveObservers("EnterEvent")
-        interactor.RemoveObservers("LeaveEvent")
-        interactor.RemoveObservers("ExposeEvent")
-        interactor.RemoveObservers("ConfigureEvent")
-        interactor.RemoveObservers("CharEvent")
-        interactor.RemoveObservers("KeyPressEvent")
-        interactor.RemoveObservers("KeyReleaseEvent")
-        interactor.RemoveObservers("MouseMoveEvent")
-        interactor.RemoveObservers("LeftButtonPressEvent")
-        interactor.RemoveObservers("RightButtonPressEvent")
-        interactor.RemoveObservers("MiddleButtonPressEvent")
-        interactor.RemoveObservers("LeftButtonReleaseEvent")
-        interactor.RemoveObservers("RightButtonReleaseEvent")
-        interactor.RemoveObservers("MiddleButtonReleaseEvent")
-        interactor.RemoveObservers("MouseWheelForwardEvent")
-        interactor.RemoveObservers("MouseWheelBackwardEvent")
-
-        # This class is a `vtkClass` (instead of `object`), so `super()`
-        # cannot be used. Also, the method `SetInteractor` is not overridden
-        # in `vtkInteractorStyleUser`, so we have to call directly the one
-        # from `vtkInteractorStyle`. In addition to setting the interactor,
-        # the following line adds the necessary hooks to listen to this
-        # instance's observers.
-        vtk.vtkInteractorStyle.SetInteractor(self, interactor)
-
-        # Keyboard events.
-        self.AddObserver("CharEvent", self.on_char)
-        self.AddObserver("KeyPressEvent", self.on_key_press)
-        self.AddObserver("KeyReleaseEvent", self.on_key_release)
-
-        # Mouse events.
-        self.AddObserver("MouseMoveEvent", self.on_mouse_move)
-        self.AddObserver("LeftButtonPressEvent", self.on_left_button_down)
-        self.AddObserver("LeftButtonReleaseEvent", self.on_left_button_up)
-        self.AddObserver("RightButtonPressEvent", self.on_right_button_down)
-        self.AddObserver("RightButtonReleaseEvent", self.on_right_button_up)
-        self.AddObserver("MiddleButtonPressEvent", self.on_middle_button_down)
-        self.AddObserver("MiddleButtonReleaseEvent", self.on_middle_button_up)
-
-        # Windows and special events.
-        # TODO: if we ever find them useful we could support them.
-        # self.AddObserver("TimerEvent", self.on_timer)
-        # self.AddObserver("EnterEvent", self.on_enter)
-        # self.AddObserver("LeaveEvent", self.on_leave)
-        # self.AddObserver("ExposeEvent", self.on_expose)
-        # self.AddObserver("ConfigureEvent", self.on_configure)
-
-        # These observers need to be added directly to the interactor because
-        # `vtkInteractorStyleUser` does not support wheel events prior to 7.1.
-        # See https://github.com/Kitware/VTK/commit/373258ed21f0915c425eddb996ce6ac13404be28
-        interactor.AddObserver("MouseWheelForwardEvent",
-                               self.on_mouse_wheel_forward)
-        interactor.AddObserver("MouseWheelBackwardEvent",
-                               self.on_mouse_wheel_backward)
-
-    def force_render(self):
-        """ Causes the renderer to refresh. """
-        self.GetInteractor().GetRenderWindow().Render()
-
-    def add_callback(self, prop, event_type, callback, priority=0, args=[]):
-        """ Adds a callback associated with a specific event for a VTK prop.
-
-        Parameters
-        ----------
-        prop : vtkProp
-        event_type : event code
-        callback : function
-        priority : int
-        args : list
-            Extra arguments passed on to `callback`.
-        """
-
-        def _callback(obj, event_name):
-            # Update event information.
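-            # `self.event` is shared state: refreshing it here exposes the
-            # event name, mouse position and key symbol to the user callback
-            # via the interactor style's `event` attribute.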
- self.event.update(event_name, self.GetInteractor()) - callback(self, prop, *args) - - prop.AddObserver(event_type, _callback, priority) diff --git a/dipy/viz/tests/test_actors.py b/dipy/viz/tests/test_actors.py deleted file mode 100644 index 8e20841514..0000000000 --- a/dipy/viz/tests/test_actors.py +++ /dev/null @@ -1,747 +0,0 @@ -import os -import numpy as np - -from dipy.viz import actor, window - -import numpy.testing as npt -from nibabel.tmpdirs import TemporaryDirectory -from dipy.tracking.streamline import center_streamlines, transform_streamlines -from dipy.align.tests.test_streamlinear import fornix_streamlines -from dipy.reconst.dti import color_fa, fractional_anisotropy -from dipy.testing.decorators import xvfb_it -from dipy.data import get_sphere -from tempfile import mkstemp - - -use_xvfb = os.environ.get('TEST_WITH_XVFB', False) -if use_xvfb == 'skip': - skip_it = True -else: - skip_it = False - -run_test = (actor.have_vtk and - actor.have_vtk_colors and - window.have_imread and - not skip_it) - -if actor.have_vtk: - if actor.major_version == 5 and use_xvfb: - skip_slicer = True - else: - skip_slicer = False -else: - skip_slicer = False - - -@npt.dec.skipif(skip_slicer) -@npt.dec.skipif(not run_test) -@xvfb_it -def test_slicer(): - renderer = window.renderer() - data = (255 * np.random.rand(50, 50, 50)) - affine = np.eye(4) - slicer = actor.slicer(data, affine) - slicer.display(None, None, 25) - renderer.add(slicer) - - renderer.reset_camera() - renderer.reset_clipping_range() - # window.show(renderer) - - # copy pixels in numpy array directly - arr = window.snapshot(renderer, 'test_slicer.png', offscreen=True) - import scipy - print(scipy.__version__) - print(scipy.__file__) - - print(arr.sum()) - print(np.sum(arr == 0)) - print(np.sum(arr > 0)) - print(arr.shape) - print(arr.dtype) - - report = window.analyze_snapshot(arr, find_objects=True) - - npt.assert_equal(report.objects, 1) - # print(arr[..., 0]) - - # The slicer can cut directly a smaller part of the image - slicer.display_extent(10, 30, 10, 30, 35, 35) - renderer.ResetCamera() - - renderer.add(slicer) - - # save pixels in png file not a numpy array - with TemporaryDirectory() as tmpdir: - fname = os.path.join(tmpdir, 'slice.png') - # window.show(renderer) - window.snapshot(renderer, fname, offscreen=True) - report = window.analyze_snapshot(fname, find_objects=True) - npt.assert_equal(report.objects, 1) - - npt.assert_raises(ValueError, actor.slicer, np.ones(10)) - - renderer.clear() - - rgb = np.zeros((30, 30, 30, 3)) - rgb[..., 0] = 1. 
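-    # A (X, Y, Z, 3) volume is treated as RGB data; with only the red channel
-    # set, the snapshot below should contain a single pure-red object.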
- rgb_actor = actor.slicer(rgb) - - renderer.add(rgb_actor) - - renderer.reset_camera() - renderer.reset_clipping_range() - - arr = window.snapshot(renderer, offscreen=True) - report = window.analyze_snapshot(arr, colors=[(255, 0, 0)]) - npt.assert_equal(report.objects, 1) - npt.assert_equal(report.colors_found, [True]) - - lut = actor.colormap_lookup_table(scale_range=(0, 255), - hue_range=(0.4, 1.), - saturation_range=(1, 1.), - value_range=(0., 1.)) - renderer.clear() - slicer_lut = actor.slicer(data, lookup_colormap=lut) - - slicer_lut.display(10, None, None) - slicer_lut.display(None, 10, None) - slicer_lut.display(None, None, 10) - - slicer_lut.opacity(0.5) - slicer_lut.tolerance(0.03) - slicer_lut2 = slicer_lut.copy() - npt.assert_equal(slicer_lut2.GetOpacity(), 0.5) - npt.assert_equal(slicer_lut2.picker.GetTolerance(), 0.03) - slicer_lut2.opacity(1) - slicer_lut2.tolerance(0.025) - slicer_lut2.display(None, None, 10) - renderer.add(slicer_lut2) - - renderer.reset_clipping_range() - - arr = window.snapshot(renderer, offscreen=True) - report = window.analyze_snapshot(arr, find_objects=True) - npt.assert_equal(report.objects, 1) - - renderer.clear() - - data = (255 * np.random.rand(50, 50, 50)) - affine = np.diag([1, 3, 2, 1]) - slicer = actor.slicer(data, affine, interpolation='nearest') - slicer.display(None, None, 25) - - renderer.add(slicer) - renderer.reset_camera() - renderer.reset_clipping_range() - - arr = window.snapshot(renderer, offscreen=True) - report = window.analyze_snapshot(arr, find_objects=True) - npt.assert_equal(report.objects, 1) - npt.assert_equal(data.shape, slicer.shape) - - renderer.clear() - - data = (255 * np.random.rand(50, 50, 50)) - affine = np.diag([1, 3, 2, 1]) - - from dipy.align.reslice import reslice - - data2, affine2 = reslice(data, affine, zooms=(1, 3, 2), - new_zooms=(1, 1, 1)) - - slicer = actor.slicer(data2, affine2, interpolation='linear') - slicer.display(None, None, 25) - - renderer.add(slicer) - renderer.reset_camera() - renderer.reset_clipping_range() - - # window.show(renderer, reset_camera=False) - arr = window.snapshot(renderer, offscreen=True) - report = window.analyze_snapshot(arr, find_objects=True) - npt.assert_equal(report.objects, 1) - npt.assert_array_equal([1, 3, 2] * np.array(data.shape), - np.array(slicer.shape)) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_contour_from_roi(): - - # Render volume - renderer = window.renderer() - data = np.zeros((50, 50, 50)) - data[20:30, 25, 25] = 1. - data[25, 20:30, 25] = 1. - affine = np.eye(4) - surface = actor.contour_from_roi(data, affine, - color=np.array([1, 0, 1]), - opacity=.5) - renderer.add(surface) - - renderer.reset_camera() - renderer.reset_clipping_range() - # window.show(renderer) - - # Test binarization - renderer2 = window.renderer() - data2 = np.zeros((50, 50, 50)) - data2[20:30, 25, 25] = 1. - data2[35:40, 25, 25] = 1. 
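-    # The two segments along the first axis are disconnected, so the second
-    # snapshot is expected to contain two separate objects.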
- affine = np.eye(4) - surface2 = actor.contour_from_roi(data2, affine, - color=np.array([0, 1, 1]), - opacity=.5) - renderer2.add(surface2) - - renderer2.reset_camera() - renderer2.reset_clipping_range() - # window.show(renderer2) - - arr = window.snapshot(renderer, 'test_surface.png', offscreen=True) - arr2 = window.snapshot(renderer2, 'test_surface2.png', offscreen=True) - - report = window.analyze_snapshot(arr, find_objects=True) - report2 = window.analyze_snapshot(arr2, find_objects=True) - - npt.assert_equal(report.objects, 1) - npt.assert_equal(report2.objects, 2) - - # test on real streamlines using tracking example - from dipy.data import read_stanford_labels - from dipy.reconst.shm import CsaOdfModel - from dipy.data import default_sphere - from dipy.direction import peaks_from_model - from dipy.tracking.local import ThresholdTissueClassifier - from dipy.tracking import utils - from dipy.tracking.local import LocalTracking - from dipy.viz.colormap import line_colors - - hardi_img, gtab, labels_img = read_stanford_labels() - data = hardi_img.get_data() - labels = labels_img.get_data() - affine = hardi_img.get_affine() - - white_matter = (labels == 1) | (labels == 2) - - csa_model = CsaOdfModel(gtab, sh_order=6) - csa_peaks = peaks_from_model(csa_model, data, default_sphere, - relative_peak_threshold=.8, - min_separation_angle=45, - mask=white_matter) - - classifier = ThresholdTissueClassifier(csa_peaks.gfa, .25) - - seed_mask = labels == 2 - seeds = utils.seeds_from_mask(seed_mask, density=[1, 1, 1], affine=affine) - - # Initialization of LocalTracking. - # The computation happens in the next step. - streamlines = LocalTracking(csa_peaks, classifier, seeds, affine, - step_size=2) - - # Compute streamlines and store as a list. - streamlines = list(streamlines) - - # Prepare the display objects. - streamlines_actor = actor.line(streamlines, line_colors(streamlines)) - seedroi_actor = actor.contour_from_roi(seed_mask, affine, [0, 1, 1], 0.5) - - # Create the 3d display. 
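-    # Render the streamlines alone, then together with the seed ROI; if the
-    # affine were applied inconsistently, the ROI would drift away from the
-    # streamlines and change the number of detected objects.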
- r = window.ren() - r2 = window.ren() - r.add(streamlines_actor) - arr3 = window.snapshot(r, 'test_surface3.png', offscreen=True) - report3 = window.analyze_snapshot(arr3, find_objects=True) - r2.add(streamlines_actor) - r2.add(seedroi_actor) - arr4 = window.snapshot(r2, 'test_surface4.png', offscreen=True) - report4 = window.analyze_snapshot(arr4, find_objects=True) - - # assert that the seed ROI rendering is not far - # away from the streamlines (affine error) - npt.assert_equal(report3.objects, report4.objects) - # window.show(r) - # window.show(r2) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_streamtube_and_line_actors(): - renderer = window.renderer() - - line1 = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2.]]) - line2 = line1 + np.array([0.5, 0., 0.]) - - lines = [line1, line2] - colors = np.array([[1, 0, 0], [0, 0, 1.]]) - c = actor.line(lines, colors, linewidth=3) - window.add(renderer, c) - - c = actor.line(lines, colors, spline_subdiv=5, linewidth=3) - window.add(renderer, c) - - # create streamtubes of the same lines and shift them a bit - c2 = actor.streamtube(lines, colors, linewidth=.1) - c2.SetPosition(2, 0, 0) - window.add(renderer, c2) - - arr = window.snapshot(renderer) - - report = window.analyze_snapshot(arr, - colors=[(255, 0, 0), (0, 0, 255)], - find_objects=True) - - npt.assert_equal(report.objects, 4) - npt.assert_equal(report.colors_found, [True, True]) - - # as before with splines - c2 = actor.streamtube(lines, colors, spline_subdiv=5, linewidth=.1) - c2.SetPosition(2, 0, 0) - window.add(renderer, c2) - - arr = window.snapshot(renderer) - - report = window.analyze_snapshot(arr, - colors=[(255, 0, 0), (0, 0, 255)], - find_objects=True) - - npt.assert_equal(report.objects, 4) - npt.assert_equal(report.colors_found, [True, True]) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_bundle_maps(): - renderer = window.renderer() - bundle = fornix_streamlines() - bundle, shift = center_streamlines(bundle) - - mat = np.array([[1, 0, 0, 100], - [0, 1, 0, 100], - [0, 0, 1, 100], - [0, 0, 0, 1.]]) - - bundle = transform_streamlines(bundle, mat) - - # metric = np.random.rand(*(200, 200, 200)) - metric = 100 * np.ones((200, 200, 200)) - - # add lower values - metric[100, :, :] = 100 * 0.5 - - # create a nice orange-red colormap - lut = actor.colormap_lookup_table(scale_range=(0., 100.), - hue_range=(0., 0.1), - saturation_range=(1, 1), - value_range=(1., 1)) - - line = actor.line(bundle, metric, linewidth=0.1, lookup_colormap=lut) - window.add(renderer, line) - window.add(renderer, actor.scalar_bar(lut, ' ')) - - report = window.analyze_renderer(renderer) - - npt.assert_almost_equal(report.actors, 1) - # window.show(renderer) - - renderer.clear() - - nb_points = np.sum([len(b) for b in bundle]) - values = 100 * np.random.rand(nb_points) - # values[:nb_points/2] = 0 - - line = actor.streamtube(bundle, values, linewidth=0.1, lookup_colormap=lut) - renderer.add(line) - # window.show(renderer) - - report = window.analyze_renderer(renderer) - npt.assert_equal(report.actors_classnames[0], 'vtkLODActor') - - renderer.clear() - - colors = np.random.rand(nb_points, 3) - # values[:nb_points/2] = 0 - - line = actor.line(bundle, colors, linewidth=2) - renderer.add(line) - # window.show(renderer) - - report = window.analyze_renderer(renderer) - npt.assert_equal(report.actors_classnames[0], 'vtkLODActor') - # window.show(renderer) - - arr = window.snapshot(renderer) - report2 = window.analyze_snapshot(arr) - npt.assert_equal(report2.objects, 1) - - # try other input options for 
colors - renderer.clear() - actor.line(bundle, (1., 0.5, 0)) - actor.line(bundle, np.arange(len(bundle))) - actor.line(bundle) - colors = [np.random.rand(*b.shape) for b in bundle] - actor.line(bundle, colors=colors) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_odf_slicer(interactive=False): - - sphere = get_sphere('symmetric362') - - shape = (11, 11, 11, sphere.vertices.shape[0]) - - fid, fname = mkstemp(suffix='_odf_slicer.mmap') - print(fid) - print(fname) - - odfs = np.memmap(fname, dtype=np.float64, mode='w+', - shape=shape) - - odfs[:] = 1 - - affine = np.eye(4) - renderer = window.Renderer() - - mask = np.ones(odfs.shape[:3]) - mask[:4, :4, :4] = 0 - - odfs[..., 0] = 1 - - odf_actor = actor.odf_slicer(odfs, affine, - mask=mask, sphere=sphere, scale=.25, - colormap='jet') - fa = 0. * np.zeros(odfs.shape[:3]) - fa[:, 0, :] = 1. - fa[:, -1, :] = 1. - fa[0, :, :] = 1. - fa[-1, :, :] = 1. - fa[5, 5, 5] = 1 - - k = 5 - I, J, K = odfs.shape[:3] - - fa_actor = actor.slicer(fa, affine) - fa_actor.display_extent(0, I, 0, J, k, k) - renderer.add(odf_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - - odf_actor.display_extent(0, I, 0, J, k, k) - odf_actor.GetProperty().SetOpacity(1.0) - if interactive: - window.show(renderer, reset_camera=False) - - arr = window.snapshot(renderer) - report = window.analyze_snapshot(arr, find_objects=True) - npt.assert_equal(report.objects, 11 * 11) - - renderer.clear() - renderer.add(fa_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - if interactive: - window.show(renderer) - - mask[:] = 0 - mask[5, 5, 5] = 1 - fa[5, 5, 5] = 0 - fa_actor = actor.slicer(fa, None) - fa_actor.display(None, None, 5) - odf_actor = actor.odf_slicer(odfs, None, mask=mask, - sphere=sphere, scale=.25, - colormap='jet', - norm=False, global_cm=True) - renderer.clear() - renderer.add(fa_actor) - renderer.add(odf_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - if interactive: - window.show(renderer) - - renderer.clear() - renderer.add(odf_actor) - renderer.add(fa_actor) - odfs[:, :, :] = 1 - mask = np.ones(odfs.shape[:3]) - odf_actor = actor.odf_slicer(odfs, None, mask=mask, - sphere=sphere, scale=.25, - colormap='jet', - norm=False, global_cm=True) - - renderer.clear() - renderer.add(odf_actor) - renderer.add(fa_actor) - renderer.add(actor.axes((11, 11, 11))) - for i in range(11): - odf_actor.display(i, None, None) - fa_actor.display(i, None, None) - if interactive: - window.show(renderer) - for j in range(11): - odf_actor.display(None, j, None) - fa_actor.display(None, j, None) - if interactive: - window.show(renderer) - # with mask equal to zero everything should be black - mask = np.zeros(odfs.shape[:3]) - odf_actor = actor.odf_slicer(odfs, None, mask=mask, - sphere=sphere, scale=.25, - colormap='plasma', - norm=False, global_cm=True) - renderer.clear() - renderer.add(odf_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - if interactive: - window.show(renderer) - - report = window.analyze_renderer(renderer) - npt.assert_equal(report.actors, 1) - npt.assert_equal(report.actors_classnames[0], 'vtkLODActor') - - del odf_actor - odfs._mmap.close() - del odfs - os.close(fid) - - os.remove(fname) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_peak_slicer(interactive=False): - - _peak_dirs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype='f4') - # peak_dirs.shape = (1, 1, 1) + peak_dirs.shape - - peak_dirs = np.zeros((11, 11, 11, 3, 3)) - - peak_values = np.random.rand(11, 11, 11, 3) - - 
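-    # Broadcast the three orthogonal unit directions to every voxel of the
-    # 11 x 11 x 11 grid.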
peak_dirs[:, :, :] = _peak_dirs - - renderer = window.Renderer() - peak_actor = actor.peak_slicer(peak_dirs) - renderer.add(peak_actor) - renderer.add(actor.axes((11, 11, 11))) - if interactive: - window.show(renderer) - - renderer.clear() - renderer.add(peak_actor) - renderer.add(actor.axes((11, 11, 11))) - for k in range(11): - peak_actor.display_extent(0, 10, 0, 10, k, k) - - for j in range(11): - peak_actor.display_extent(0, 10, j, j, 0, 10) - - for i in range(11): - peak_actor.display(i, None, None) - - renderer.rm_all() - - peak_actor = actor.peak_slicer( - peak_dirs, - peak_values, - mask=None, - affine=np.diag([3, 2, 1, 1]), - colors=None, - opacity=1, - linewidth=3, - lod=True, - lod_points=10 ** 4, - lod_points_size=3) - - renderer.add(peak_actor) - renderer.add(actor.axes((11, 11, 11))) - if interactive: - window.show(renderer) - - report = window.analyze_renderer(renderer) - ex = ['vtkLODActor', 'vtkOpenGLActor', 'vtkOpenGLActor', 'vtkOpenGLActor'] - npt.assert_equal(report.actors_classnames, ex) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_tensor_slicer(interactive=False): - - evals = np.array([1.4, .35, .35]) * 10 ** (-3) - evecs = np.eye(3) - - mevals = np.zeros((3, 2, 4, 3)) - mevecs = np.zeros((3, 2, 4, 3, 3)) - - mevals[..., :] = evals - mevecs[..., :, :] = evecs - - from dipy.data import get_sphere - - sphere = get_sphere('symmetric724') - - affine = np.eye(4) - renderer = window.Renderer() - - tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine, - sphere=sphere, scale=.3) - I, J, K = mevals.shape[:3] - renderer.add(tensor_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - - tensor_actor.display_extent(0, 1, 0, J, 0, K) - tensor_actor.GetProperty().SetOpacity(1.0) - if interactive: - window.show(renderer, reset_camera=False) - - npt.assert_equal(renderer.GetActors().GetNumberOfItems(), 1) - - # Test extent - big_extent = renderer.GetActors().GetLastActor().GetBounds() - big_extent_x = abs(big_extent[1] - big_extent[0]) - tensor_actor.display(x=2) - - if interactive: - window.show(renderer, reset_camera=False) - - small_extent = renderer.GetActors().GetLastActor().GetBounds() - small_extent_x = abs(small_extent[1] - small_extent[0]) - npt.assert_equal(big_extent_x > small_extent_x, True) - - # Test empty mask - empty_actor = actor.tensor_slicer(mevals, mevecs, affine=affine, - mask=np.zeros(mevals.shape[:3]), - sphere=sphere, scale=.3) - npt.assert_equal(empty_actor.GetMapper(), None) - - # Test mask - mask = np.ones(mevals.shape[:3]) - mask[:2, :3, :3] = 0 - cfa = color_fa(fractional_anisotropy(mevals), mevecs) - tensor_actor = actor.tensor_slicer(mevals, mevecs, affine=affine, mask=mask, - scalar_colors=cfa, sphere=sphere, scale=.3) - renderer.clear() - renderer.add(tensor_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - - if interactive: - window.show(renderer, reset_camera=False) - - mask_extent = renderer.GetActors().GetLastActor().GetBounds() - mask_extent_x = abs(mask_extent[1] - mask_extent[0]) - npt.assert_equal(big_extent_x > mask_extent_x, True) - - # test display - tensor_actor.display() - current_extent = renderer.GetActors().GetLastActor().GetBounds() - current_extent_x = abs(current_extent[1] - current_extent[0]) - npt.assert_equal(big_extent_x > current_extent_x, True) - if interactive: - window.show(renderer, reset_camera=False) - - tensor_actor.display(y=1) - current_extent = renderer.GetActors().GetLastActor().GetBounds() - current_extent_y = abs(current_extent[3] - current_extent[2]) - 
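-    # Restricting the display to a single slice should shrink the actor's
-    # extent along the corresponding axis, which is checked per axis below.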
big_extent_y = abs(big_extent[3] - big_extent[2]) - npt.assert_equal(big_extent_y > current_extent_y, True) - if interactive: - window.show(renderer, reset_camera=False) - - tensor_actor.display(z=1) - current_extent = renderer.GetActors().GetLastActor().GetBounds() - current_extent_z = abs(current_extent[5] - current_extent[4]) - big_extent_z = abs(big_extent[5] - big_extent[4]) - npt.assert_equal(big_extent_z > current_extent_z, True) - if interactive: - window.show(renderer, reset_camera=False) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_dots(interactive=False): - points = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) - - dots_actor = actor.dots(points, color=(0, 255, 0)) - - renderer = window.Renderer() - renderer.add(dots_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - - if interactive: - window.show(renderer, reset_camera=False) - - npt.assert_equal(renderer.GetActors().GetNumberOfItems(), 1) - - extent = renderer.GetActors().GetLastActor().GetBounds() - npt.assert_equal(extent, (0.0, 1.0, 0.0, 1.0, 0.0, 0.0)) - - arr = window.snapshot(renderer) - report = window.analyze_snapshot(arr, - colors=(0, 255, 0)) - npt.assert_equal(report.objects, 3) - - # Test one point - points = np.array([0, 0, 0]) - dot_actor = actor.dots(points, color=(0, 0, 255)) - - renderer.clear() - renderer.add(dot_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - - arr = window.snapshot(renderer) - report = window.analyze_snapshot(arr, - colors=(0, 0, 255)) - npt.assert_equal(report.objects, 1) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_points(interactive=False): - points = np.array([[0, 0, 0], [0, 1, 0], [1, 0, 0]]) - colors = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) - - points_actor = actor.point(points, colors) - - renderer = window.Renderer() - renderer.add(points_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - - if interactive: - window.show(renderer, reset_camera=False) - - npt.assert_equal(renderer.GetActors().GetNumberOfItems(), 1) - - arr = window.snapshot(renderer) - report = window.analyze_snapshot(arr, - colors=colors) - npt.assert_equal(report.objects, 3) - - -@npt.dec.skipif(not run_test) -@xvfb_it -def test_labels(interactive=False): - - text_actor = actor.label("Hello") - - renderer = window.Renderer() - renderer.add(text_actor) - renderer.reset_camera() - renderer.reset_clipping_range() - - if interactive: - window.show(renderer, reset_camera=False) - - npt.assert_equal(renderer.GetActors().GetNumberOfItems(), 1) - - -if __name__ == "__main__": - npt.run_module_suite() diff --git a/dipy/viz/tests/test_fvtk.py b/dipy/viz/tests/test_fvtk.py deleted file mode 100644 index b72e434770..0000000000 --- a/dipy/viz/tests/test_fvtk.py +++ /dev/null @@ -1,148 +0,0 @@ -"""Testing visualization with fvtk.""" -import os -import warnings -import numpy as np -from distutils.version import LooseVersion - -from dipy.viz import fvtk -from dipy import data - -import numpy.testing as npt -from dipy.testing.decorators import xvfb_it -from dipy.utils.optpkg import optional_package - -use_xvfb = os.environ.get('TEST_WITH_XVFB', False) -if use_xvfb == 'skip': - skip_it = True -else: - skip_it = False - -cm, have_matplotlib, _ = optional_package('matplotlib.cm') - -if have_matplotlib: - import matplotlib - mpl_version = LooseVersion(matplotlib.__version__) - - -@npt.dec.skipif(not fvtk.have_vtk or not fvtk.have_vtk_colors or skip_it) -@xvfb_it -def test_fvtk_functions(): - # This tests will fail if any of the given actors changed inputs 
or do - # not exist - - # Create a renderer - r = fvtk.ren() - - # Create 2 lines with 2 different colors - lines = [np.random.rand(10, 3), np.random.rand(20, 3)] - colors = np.random.rand(2, 3) - c = fvtk.line(lines, colors) - fvtk.add(r, c) - - # create streamtubes of the same lines and shift them a bit - c2 = fvtk.streamtube(lines, colors) - c2.SetPosition(2, 0, 0) - fvtk.add(r, c2) - - # Create a volume and return a volumetric actor using volumetric rendering - vol = 100 * np.random.rand(100, 100, 100) - vol = vol.astype('uint8') - r = fvtk.ren() - v = fvtk.volume(vol) - fvtk.add(r, v) - - # Remove all objects - fvtk.rm_all(r) - - # Put some text - l = fvtk.label(r, text='Yes Men') - fvtk.add(r, l) - - # Slice the volume - slicer = fvtk.slicer(vol) - slicer.display(50, None, None) - fvtk.add(r, slicer) - - # Change the position of the active camera - fvtk.camera(r, pos=(0.6, 0, 0), verbose=False) - - fvtk.clear(r) - - # Peak directions - p = fvtk.peaks(np.random.rand(3, 3, 3, 5, 3)) - fvtk.add(r, p) - - p2 = fvtk.peaks(np.random.rand(3, 3, 3, 5, 3), - np.random.rand(3, 3, 3, 5), - colors=(0, 1, 0)) - fvtk.add(r, p2) - - -@npt.dec.skipif(not fvtk.have_vtk or not fvtk.have_vtk_colors or skip_it) -@xvfb_it -def test_fvtk_ellipsoid(): - - evals = np.array([1.4, .35, .35]) * 10 ** (-3) - evecs = np.eye(3) - - mevals = np.zeros((3, 2, 4, 3)) - mevecs = np.zeros((3, 2, 4, 3, 3)) - - mevals[..., :] = evals - mevecs[..., :, :] = evecs - - from dipy.data import get_sphere - - sphere = get_sphere('symmetric724') - - ren = fvtk.ren() - - fvtk.add(ren, fvtk.tensor(mevals, mevecs, sphere=sphere)) - - fvtk.add(ren, fvtk.tensor(mevals, mevecs, np.ones(mevals.shape), - sphere=sphere)) - - npt.assert_equal(ren.GetActors().GetNumberOfItems(), 2) - - -def test_colormap(): - v = np.linspace(0., .5) - map1 = fvtk.create_colormap(v, 'bone', auto=True) - map2 = fvtk.create_colormap(v, 'bone', auto=False) - npt.assert_(not np.allclose(map1, map2)) - - npt.assert_raises(ValueError, fvtk.create_colormap, np.ones((2, 3))) - npt.assert_raises(ValueError, fvtk.create_colormap, v, 'no such map') - - -@npt.dec.skipif(not fvtk.have_matplotlib) -def test_colormaps_matplotlib(): - v = np.random.random(1000) - # The "Accent" colormap is deprecated as of 0.12: - with warnings.catch_warnings(record=True) as w: - # Cause all warnings to always be triggered. 
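-        # Recording the warnings lets the test assert that fetching the
-        # deprecated 'Accent' colormap actually emitted a warning.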
- warnings.simplefilter("always") - data.get_cmap("Accent") - # Test that the deprecation warning was raised: - npt.assert_(len(w) > 0) - - names = ['jet', 'Blues', 'bone'] - - if have_matplotlib and mpl_version < "2": - names.append('Accent') - - for name in names: - with warnings.catch_warnings(record=True) as w: - # Matplotlib version of get_cmap - rgba1 = fvtk.get_cmap(name)(v) - # Dipy version of get_cmap - rgba2 = data.get_cmap(name)(v) - # dipy's colormaps are close to matplotlibs colormaps, but not - # perfect: - npt.assert_array_almost_equal(rgba1, rgba2, 1) - npt.assert_(len(w) == (1 if name == 'Accent' else 0)) - - - -if __name__ == "__main__": - npt.run_module_suite() diff --git a/dipy/viz/tests/test_interactor.py b/dipy/viz/tests/test_interactor.py deleted file mode 100644 index 79979b8f66..0000000000 --- a/dipy/viz/tests/test_interactor.py +++ /dev/null @@ -1,149 +0,0 @@ -import os -import numpy as np -from os.path import join as pjoin -from collections import defaultdict - -from dipy.viz import actor, window, interactor -from dipy.viz import utils as vtk_utils -from dipy.data import DATA_DIR -import numpy.testing as npt -from dipy.testing.decorators import xvfb_it - -# Conditional import machinery for vtk -from dipy.utils.optpkg import optional_package - -# Allow import, but disable doctests if we don't have vtk -vtk, have_vtk, setup_module = optional_package('vtk') - -use_xvfb = os.environ.get('TEST_WITH_XVFB', False) -if use_xvfb == 'skip': - skip_it = True -else: - skip_it = False - - -@npt.dec.skipif(not have_vtk or not actor.have_vtk_colors or skip_it) -@xvfb_it -def test_custom_interactor_style_events(recording=False): - print("Using VTK {}".format(vtk.vtkVersion.GetVTKVersion())) - filename = "test_custom_interactor_style_events.log.gz" - recording_filename = pjoin(DATA_DIR, filename) - renderer = window.Renderer() - - # the show manager allows to break the rendering process - # in steps so that the widgets can be added properly - interactor_style = interactor.CustomInteractorStyle() - show_manager = window.ShowManager(renderer, size=(800, 800), - reset_camera=False, - interactor_style=interactor_style) - - # Create a cursor, a circle that will follow the mouse. - polygon_source = vtk.vtkRegularPolygonSource() - polygon_source.GeneratePolygonOff() # Only the outline of the circle. - polygon_source.SetNumberOfSides(50) - polygon_source.SetRadius(10) - # polygon_source.SetRadius - polygon_source.SetCenter(0, 0, 0) - - mapper = vtk.vtkPolyDataMapper2D() - vtk_utils.set_input(mapper, polygon_source.GetOutputPort()) - - cursor = vtk.vtkActor2D() - cursor.SetMapper(mapper) - cursor.GetProperty().SetColor(1, 0.5, 0) - renderer.add(cursor) - - def follow_mouse(iren, obj): - obj.SetPosition(*iren.event.position) - iren.force_render() - - interactor_style.add_active_prop(cursor) - interactor_style.add_callback(cursor, "MouseMoveEvent", follow_mouse) - - # create some minimalistic streamlines - lines = [np.array([[-1, 0, 0.], [1, 0, 0.]]), - np.array([[-1, 1, 0.], [1, 1, 0.]])] - colors = np.array([[1., 0., 0.], [0.3, 0.7, 0.]]) - tube1 = actor.streamtube([lines[0]], colors[0]) - tube2 = actor.streamtube([lines[1]], colors[1]) - renderer.add(tube1) - renderer.add(tube2) - - # Define some counter callback. - states = defaultdict(lambda: 0) - - def counter(iren, obj): - states[iren.event.name] += 1 - - # Assign the counter callback to every possible event. 
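-    # Route every event type below to `counter` so that, when the recorded
-    # session is replayed, the observed counts can be compared to `expected`.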
- for event in ["CharEvent", "MouseMoveEvent", - "KeyPressEvent", "KeyReleaseEvent", - "LeftButtonPressEvent", "LeftButtonReleaseEvent", - "RightButtonPressEvent", "RightButtonReleaseEvent", - "MiddleButtonPressEvent", "MiddleButtonReleaseEvent"]: - interactor_style.add_callback(tube1, event, counter) - - # Add callback to scale up/down tube1. - def scale_up_obj(iren, obj): - counter(iren, obj) - scale = np.asarray(obj.GetScale()) + 0.1 - obj.SetScale(*scale) - iren.force_render() - iren.event.abort() # Stop propagating the event. - - def scale_down_obj(iren, obj): - counter(iren, obj) - scale = np.array(obj.GetScale()) - 0.1 - obj.SetScale(*scale) - iren.force_render() - iren.event.abort() # Stop propagating the event. - - interactor_style.add_callback(tube2, "MouseWheelForwardEvent", - scale_up_obj) - interactor_style.add_callback(tube2, "MouseWheelBackwardEvent", - scale_down_obj) - - # Add callback to hide/show tube1. - def toggle_visibility(iren, obj): - key = iren.event.key - if key.lower() == "v": - obj.SetVisibility(not obj.GetVisibility()) - iren.force_render() - - interactor_style.add_active_prop(tube1) - interactor_style.add_active_prop(tube2) - interactor_style.remove_active_prop(tube2) - interactor_style.add_callback(tube1, "CharEvent", toggle_visibility) - - if recording: - show_manager.record_events_to_file(recording_filename) - print(list(states.items())) - else: - show_manager.play_events_from_file(recording_filename) - msg = ("Wrong count for '{}'.") - expected = [('CharEvent', 6), - ('KeyPressEvent', 6), - ('KeyReleaseEvent', 6), - ('MouseMoveEvent', 1652), - ('LeftButtonPressEvent', 1), - ('RightButtonPressEvent', 1), - ('MiddleButtonPressEvent', 2), - ('LeftButtonReleaseEvent', 1), - ('MouseWheelForwardEvent', 3), - ('MouseWheelBackwardEvent', 1), - ('MiddleButtonReleaseEvent', 2), - ('RightButtonReleaseEvent', 1)] - - # Useful loop for debugging. - for event, count in expected: - if states[event] != count: - print("{}: {} vs. 
{} (expected)".format(event, - states[event], - count)) - - for event, count in expected: - npt.assert_equal(states[event], count, err_msg=msg.format(event)) - - -if __name__ == '__main__': - test_custom_interactor_style_events(recording=True) diff --git a/dipy/viz/tests/test_ui.py b/dipy/viz/tests/test_ui.py deleted file mode 100644 index 3ec275aa9c..0000000000 --- a/dipy/viz/tests/test_ui.py +++ /dev/null @@ -1,512 +0,0 @@ -import os -import sys -import pickle -import numpy as np - -from os.path import join as pjoin -import numpy.testing as npt - -from dipy.data import read_viz_icons, fetch_viz_icons -from dipy.viz import ui -from dipy.viz import window -from dipy.data import DATA_DIR -from nibabel.tmpdirs import InTemporaryDirectory - -from dipy.viz.ui import UI - -from dipy.testing.decorators import xvfb_it - -# Conditional import machinery for vtk -from dipy.utils.optpkg import optional_package - -# Allow import, but disable doctests if we don't have vtk -vtk, have_vtk, setup_module = optional_package('vtk') - -use_xvfb = os.environ.get('TEST_WITH_XVFB', False) -if use_xvfb == 'skip': - skip_it = True -else: - skip_it = False - -if have_vtk: - print("Using VTK {}".format(vtk.vtkVersion.GetVTKVersion())) - - -class EventCounter(object): - def __init__(self, events_names=["CharEvent", - "MouseMoveEvent", - "KeyPressEvent", - "KeyReleaseEvent", - "LeftButtonPressEvent", - "LeftButtonReleaseEvent", - "RightButtonPressEvent", - "RightButtonReleaseEvent", - "MiddleButtonPressEvent", - "MiddleButtonReleaseEvent"]): - # Events to count - self.events_counts = {name: 0 for name in events_names} - - def count(self, i_ren, obj, element): - """ Simple callback that counts events occurences. """ - self.events_counts[i_ren.event.name] += 1 - - def monitor(self, ui_component): - for event in self.events_counts: - for actor in ui_component.get_actors(): - ui_component.add_callback(actor, event, self.count) - - def save(self, filename): - with open(filename, 'wb') as f: - pickle.dump(self.events_counts, f, protocol=2) - - @classmethod - def load(cls, filename): - event_counter = cls() - with open(filename, 'rb') as f: - event_counter.events_counts = pickle.load(f) - - return event_counter - - def check_counts(self, expected): - npt.assert_equal(len(self.events_counts), - len(expected.events_counts)) - - # Useful loop for debugging. - msg = "{}: {} vs. {} (expected)" - for event, count in expected.events_counts.items(): - if self.events_counts[event] != count: - print(msg.format(event, self.events_counts[event], count)) - - msg = "Wrong count for '{}'." 
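-        # Fail with a per-event message as soon as an observed count differs
-        # from the stored recording.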
- for event, count in expected.events_counts.items(): - npt.assert_equal(self.events_counts[event], count, - err_msg=msg.format(event)) - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_broken_ui_component(): - class BrokenUI(UI): - def __init__(self): - self.actor = vtk.vtkActor() - super(BrokenUI, self).__init__() - - broken_ui = BrokenUI() - npt.assert_raises(NotImplementedError, broken_ui.get_actors) - npt.assert_raises(NotImplementedError, broken_ui.set_center, (1, 2)) - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_wrong_interactor_style(): - panel = ui.Panel2D(center=(440, 90), size=(300, 150)) - dummy_renderer = window.Renderer() - dummy_show_manager = window.ShowManager(dummy_renderer, - interactor_style='trackball') - npt.assert_raises(TypeError, panel.add_to_renderer, dummy_renderer) - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_rectangle_2d(): - window_size = (700, 700) - show_manager = window.ShowManager(size=window_size) - - rect = ui.Rectangle2D(size=(100, 50)) - rect.set_position((50, 80)) - npt.assert_equal(rect.position, (50, 80)) - - rect.color = (1, 0.5, 0) - npt.assert_equal(rect.color, (1, 0.5, 0)) - - rect.opacity = 0.5 - npt.assert_equal(rect.opacity, 0.5) - - # Check the rectangle is drawn at right place. - show_manager.ren.add(rect) - # Uncomment this to start the visualisation - # show_manager.start() - - colors = [rect.color] - arr = window.snapshot(show_manager.ren, size=window_size, offscreen=True) - report = window.analyze_snapshot(arr, colors=colors) - assert report.objects == 1 - assert report.colors_found - - # Test visibility off. - rect.set_visibility(False) - arr = window.snapshot(show_manager.ren, size=window_size, offscreen=True) - report = window.analyze_snapshot(arr) - assert report.objects == 0 - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_ui_button_panel(recording=False): - filename = "test_ui_button_panel" - recording_filename = pjoin(DATA_DIR, filename + ".log.gz") - expected_events_counts_filename = pjoin(DATA_DIR, filename + ".pkl") - - # Rectangle - rectangle_test = ui.Rectangle2D(size=(10, 10)) - rectangle_test.get_actors() - another_rectangle_test = ui.Rectangle2D(size=(1, 1)) - # /Rectangle - - # Button - fetch_viz_icons() - - icon_files = dict() - icon_files['stop'] = read_viz_icons(fname='stop2.png') - icon_files['play'] = read_viz_icons(fname='play3.png') - - button_test = ui.Button2D(icon_fnames=icon_files) - button_test.set_center((20, 20)) - - def make_invisible(i_ren, obj, button): - # i_ren: CustomInteractorStyle - # obj: vtkActor picked - # button: Button2D - button.set_visibility(False) - i_ren.force_render() - i_ren.event.abort() - - def modify_button_callback(i_ren, obj, button): - # i_ren: CustomInteractorStyle - # obj: vtkActor picked - # button: Button2D - button.next_icon() - i_ren.force_render() - - button_test.on_right_mouse_button_pressed = make_invisible - button_test.on_left_mouse_button_pressed = modify_button_callback - - button_test.scale((2, 2)) - button_color = button_test.color - button_test.color = button_color - # /Button - - # TextBlock - text_block_test = ui.TextBlock2D() - text_block_test.message = 'TextBlock' - text_block_test.color = (0, 0, 0) - - # Panel - panel = ui.Panel2D(center=(440, 90), size=(300, 150), - color=(1, 1, 1), align="right") - panel.add_element(rectangle_test, 'absolute', (580, 150)) - panel.add_element(button_test, 'relative', (0.2, 0.2)) - panel.add_element(text_block_test, 'relative', (0.7, 0.7)) - 
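-    # Placement modes other than the ones used above ('absolute' and
-    # 'relative') are expected to be rejected with a ValueError.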
npt.assert_raises(ValueError, panel.add_element, another_rectangle_test, - 'error_string', (1, 2)) - # /Panel - - # Assign the counter callback to every possible event. - event_counter = EventCounter() - event_counter.monitor(button_test) - event_counter.monitor(panel) - - current_size = (600, 600) - show_manager = window.ShowManager(size=current_size, title="DIPY Button") - - show_manager.ren.add(panel) - - if recording: - show_manager.record_events_to_file(recording_filename) - print(list(event_counter.events_counts.items())) - event_counter.save(expected_events_counts_filename) - - else: - show_manager.play_events_from_file(recording_filename) - expected = EventCounter.load(expected_events_counts_filename) - event_counter.check_counts(expected) - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_ui_textbox(recording=False): - filename = "test_ui_textbox" - recording_filename = pjoin(DATA_DIR, filename + ".log.gz") - expected_events_counts_filename = pjoin(DATA_DIR, filename + ".pkl") - - # TextBox - textbox_test = ui.TextBox2D(height=3, width=10, text="Text") - - another_textbox_test = ui.TextBox2D(height=3, width=10, text="Enter Text") - another_textbox_test.set_message("Enter Text") - another_textbox_test.set_center((10, 100)) - # /TextBox - - # Assign the counter callback to every possible event. - event_counter = EventCounter() - event_counter.monitor(textbox_test) - - current_size = (600, 600) - show_manager = window.ShowManager(size=current_size, title="DIPY TextBox") - - show_manager.ren.add(textbox_test) - - if recording: - show_manager.record_events_to_file(recording_filename) - print(list(event_counter.events_counts.items())) - event_counter.save(expected_events_counts_filename) - - else: - show_manager.play_events_from_file(recording_filename) - expected = EventCounter.load(expected_events_counts_filename) - event_counter.check_counts(expected) - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_text_block_2d(): - text_block = ui.TextBlock2D() - - def _check_property(obj, attr, values): - for value in values: - setattr(obj, attr, value) - npt.assert_equal(getattr(obj, attr), value) - - _check_property(text_block, "bold", [True, False]) - _check_property(text_block, "italic", [True, False]) - _check_property(text_block, "shadow", [True, False]) - _check_property(text_block, "font_size", range(100)) - _check_property(text_block, "message", ["", "Hello World", "Line\nBreak"]) - _check_property(text_block, "justification", ["left", "center", "right"]) - _check_property(text_block, "position", [(350, 350), (0.5, 0.5)]) - _check_property(text_block, "color", [(0., 0.5, 1.)]) - _check_property(text_block, "background_color", [(0., 0.5, 1.), None]) - _check_property(text_block, "vertical_justification", - ["top", "middle", "bottom"]) - _check_property(text_block, "font_family", ["Arial", "Courier"]) - - with npt.assert_raises(ValueError): - text_block.font_family = "Verdana" - - with npt.assert_raises(ValueError): - text_block.justification = "bottom" - - with npt.assert_raises(ValueError): - text_block.vertical_justification = "left" - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_text_block_2d_justification(): - window_size = (700, 700) - show_manager = window.ShowManager(size=window_size) - - # To help visualize the text positions. 
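-    # Thin red rectangles are drawn along the grid lines so the anchor point
-    # of each text block can be judged against them in the snapshot.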
- grid_size = (500, 500) - bottom, middle, top = 50, 300, 550 - left, center, right = 50, 300, 550 - line_color = (1, 0, 0) - - grid_top = (center, top), (grid_size[0], 1) - grid_bottom = (center, bottom), (grid_size[0], 1) - grid_left = (left, middle), (1, grid_size[1]) - grid_right = (right, middle), (1, grid_size[1]) - grid_middle = (center, middle), (grid_size[0], 1) - grid_center = (center, middle), (1, grid_size[1]) - grid_specs = [grid_top, grid_bottom, grid_left, grid_right, - grid_middle, grid_center] - for spec in grid_specs: - line = ui.Rectangle2D(center=spec[0], size=spec[1], color=line_color) - show_manager.ren.add(line) - - font_size = 60 - bg_color = (1, 1, 1) - texts = [] - texts += [ui.TextBlock2D("HH", position=(left, top), - font_size=font_size, - color=(1, 0, 0), bg_color=bg_color, - justification="left", - vertical_justification="top")] - texts += [ui.TextBlock2D("HH", position=(center, top), - font_size=font_size, - color=(0, 1, 0), bg_color=bg_color, - justification="center", - vertical_justification="top")] - texts += [ui.TextBlock2D("HH", position=(right, top), - font_size=font_size, - color=(0, 0, 1), bg_color=bg_color, - justification="right", - vertical_justification="top")] - - texts += [ui.TextBlock2D("HH", position=(left, middle), - font_size=font_size, - color=(1, 1, 0), bg_color=bg_color, - justification="left", - vertical_justification="middle")] - texts += [ui.TextBlock2D("HH", position=(center, middle), - font_size=font_size, - color=(0, 1, 1), bg_color=bg_color, - justification="center", - vertical_justification="middle")] - texts += [ui.TextBlock2D("HH", position=(right, middle), - font_size=font_size, - color=(1, 0, 1), bg_color=bg_color, - justification="right", - vertical_justification="middle")] - - texts += [ui.TextBlock2D("HH", position=(left, bottom), - font_size=font_size, - color=(0.5, 0, 1), bg_color=bg_color, - justification="left", - vertical_justification="bottom")] - texts += [ui.TextBlock2D("HH", position=(center, bottom), - font_size=font_size, - color=(1, 0.5, 0), bg_color=bg_color, - justification="center", - vertical_justification="bottom")] - texts += [ui.TextBlock2D("HH", position=(right, bottom), - font_size=font_size, - color=(0, 1, 0.5), bg_color=bg_color, - justification="right", - vertical_justification="bottom")] - - show_manager.ren.add(*texts) - - # Uncomment this to start the visualisation - # show_manager.start() - - arr = window.snapshot(show_manager.ren, size=window_size, offscreen=True) - if vtk.vtkVersion.GetVTKVersion() == "6.0.0": - expected = np.load(pjoin(DATA_DIR, "test_ui_text_block.npz")) - npt.assert_array_almost_equal(arr, expected["arr_0"]) - - -@npt.dec.skipif(not have_vtk or skip_it) -@xvfb_it -def test_ui_line_slider_2d(recording=False): - filename = "test_ui_line_slider_2d" - recording_filename = pjoin(DATA_DIR, filename + ".log.gz") - expected_events_counts_filename = pjoin(DATA_DIR, filename + ".pkl") - - line_slider_2d_test = ui.LineSlider2D(initial_value=-2, - min_value=-5, max_value=5) - line_slider_2d_test.set_center((300, 300)) - - # Assign the counter callback to every possible event. 
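(Before the counter is wired up below, a hedged aside: `LineSlider2D`, whose
implementation is removed further down in this diff, maps the disk position to
a value linearly, so the expected initial state can be computed by hand; the
names below are illustrative:)

    # ratio = (value - min_value) / (max_value - min_value); with the values
    # used above, the disk starts 30% of the way along the slider line.
    # (Illustrative sketch, derived from LineSlider2D.)
    min_value, max_value, initial_value = -5, 5, -2
    ratio = (initial_value - min_value) / float(max_value - min_value)  # 0.3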
-    event_counter = EventCounter()
-    event_counter.monitor(line_slider_2d_test)
-
-    current_size = (600, 600)
-    show_manager = window.ShowManager(size=current_size,
-                                      title="DIPY Line Slider")
-
-    show_manager.ren.add(line_slider_2d_test)
-
-    if recording:
-        show_manager.record_events_to_file(recording_filename)
-        print(list(event_counter.events_counts.items()))
-        event_counter.save(expected_events_counts_filename)
-
-    else:
-        show_manager.play_events_from_file(recording_filename)
-        expected = EventCounter.load(expected_events_counts_filename)
-        event_counter.check_counts(expected)
-
-
-@npt.dec.skipif(not have_vtk or skip_it)
-@xvfb_it
-def test_ui_disk_slider_2d(recording=False):
-    filename = "test_ui_disk_slider_2d"
-    recording_filename = pjoin(DATA_DIR, filename + ".log.gz")
-    expected_events_counts_filename = pjoin(DATA_DIR, filename + ".pkl")
-
-    disk_slider_2d_test = ui.DiskSlider2D()
-    disk_slider_2d_test.set_center((300, 300))
-    disk_slider_2d_test.value = 90
-
-    # Assign the counter callback to every possible event.
-    event_counter = EventCounter()
-    event_counter.monitor(disk_slider_2d_test)
-
-    current_size = (600, 600)
-    show_manager = window.ShowManager(size=current_size,
-                                      title="DIPY Disk Slider")
-
-    show_manager.ren.add(disk_slider_2d_test)
-
-    if recording:
-        # Record the following events:
-        # 1. Left-click on the disk and hold it.
-        # 2. Drag the disk to the left, making 1.5 turns.
-        # 3. Release the disk.
-        # 4. Left-click on the disk and hold it.
-        # 5. Drag the disk to the right, making 1 turn.
-        # 6. Release the disk.
-        show_manager.record_events_to_file(recording_filename)
-        print(list(event_counter.events_counts.items()))
-        event_counter.save(expected_events_counts_filename)
-
-    else:
-        show_manager.play_events_from_file(recording_filename)
-        expected = EventCounter.load(expected_events_counts_filename)
-        event_counter.check_counts(expected)
-
-
-@npt.dec.skipif(not have_vtk or skip_it)
-@xvfb_it
-def test_ui_file_select_menu_2d(recording=False):
-    filename = "test_ui_file_select_menu_2d"
-    recording_filename = pjoin(DATA_DIR, filename + ".log.gz")
-    expected_events_counts_filename = pjoin(DATA_DIR, filename + ".pkl")
-    with InTemporaryDirectory():
-        for i in range(10):
-            _ = open("test" + str(i) + ".txt", 'wt').write('some text')
-
-        file_select_menu = ui.FileSelectMenu2D(size=(500, 500),
-                                               position=(300, 300),
-                                               font_size=16,
-                                               extensions=["txt"],
-                                               directory_path=os.getcwd(),
-                                               parent=None)
-        file_select_menu.set_center((300, 300))
-
-        npt.assert_equal(file_select_menu.text_item_list[1].file_name[:4], "test")
-        npt.assert_equal(file_select_menu.text_item_list[5].file_name[:4], "test")
-
-        event_counter = EventCounter()
-        for event in event_counter.events_counts:
-            file_select_menu.add_callback(file_select_menu.buttons["up"].actor,
-                                          event, event_counter.count)
-            file_select_menu.add_callback(file_select_menu.buttons["down"].actor,
-                                          event, event_counter.count)
-            file_select_menu.menu.add_callback(file_select_menu.menu.panel.actor,
-                                               event, event_counter.count)
-            for text_ui in file_select_menu.text_item_list:
-                file_select_menu.add_callback(text_ui.text_actor.get_actors()[0],
-                                              event, event_counter.count)
-
-        current_size = (600, 600)
-        show_manager = window.ShowManager(size=current_size,
-                                          title="DIPY File Select Menu")
-        show_manager.ren.add(file_select_menu)
-
-        if recording:
-            show_manager.record_events_to_file(recording_filename)
-            print(list(event_counter.events_counts.items()))
-            event_counter.save(expected_events_counts_filename)
-
-        else:
show_manager.play_events_from_file(recording_filename) - expected = EventCounter.load(expected_events_counts_filename) - event_counter.check_counts(expected) - -if __name__ == "__main__": - if len(sys.argv) <= 1 or sys.argv[1] == "test_ui_button_panel": - test_ui_button_panel(recording=True) - - if len(sys.argv) <= 1 or sys.argv[1] == "test_ui_textbox": - test_ui_textbox(recording=True) - - if len(sys.argv) <= 1 or sys.argv[1] == "test_ui_line_slider_2d": - test_ui_line_slider_2d(recording=True) - - if len(sys.argv) <= 1 or sys.argv[1] == "test_ui_disk_slider_2d": - test_ui_disk_slider_2d(recording=True) - - if len(sys.argv) <= 1 or sys.argv[1] == "test_ui_file_select_menu_2d": - test_ui_file_select_menu_2d(recording=True) diff --git a/dipy/viz/tests/test_utils.py b/dipy/viz/tests/test_utils.py deleted file mode 100644 index 11873dace1..0000000000 --- a/dipy/viz/tests/test_utils.py +++ /dev/null @@ -1,72 +0,0 @@ -import numpy as np -import numpy.testing as npt -from dipy.viz.utils import map_coordinates_3d_4d - - -def trilinear_interp_numpy(input_array, indices): - """ Evaluate the input_array data at the given indices - """ - - if input_array.ndim <= 2 or input_array.ndim >= 5: - raise ValueError("Input array can only be 3d or 4d") - - x_indices = indices[:, 0] - y_indices = indices[:, 1] - z_indices = indices[:, 2] - - x0 = x_indices.astype(np.integer) - y0 = y_indices.astype(np.integer) - z0 = z_indices.astype(np.integer) - x1 = x0 + 1 - y1 = y0 + 1 - z1 = z0 + 1 - - # Check if xyz1 is beyond array boundary: - x1[np.where(x1 == input_array.shape[0])] = x0.max() - y1[np.where(y1 == input_array.shape[1])] = y0.max() - z1[np.where(z1 == input_array.shape[2])] = z0.max() - - if input_array.ndim == 3: - x = x_indices - x0 - y = y_indices - y0 - z = z_indices - z0 - - elif input_array.ndim == 4: - x = np.expand_dims(x_indices - x0, axis=1) - y = np.expand_dims(y_indices - y0, axis=1) - z = np.expand_dims(z_indices - z0, axis=1) - - output = (input_array[x0, y0, z0] * (1 - x) * (1 - y) * (1 - z) + - input_array[x1, y0, z0] * x * (1 - y) * (1 - z) + - input_array[x0, y1, z0] * (1 - x) * y * (1-z) + - input_array[x0, y0, z1] * (1 - x) * (1 - y) * z + - input_array[x1, y0, z1] * x * (1 - y) * z + - input_array[x0, y1, z1] * (1 - x) * y * z + - input_array[x1, y1, z0] * x * y * (1 - z) + - input_array[x1, y1, z1] * x * y * z) - - return output - - -def test_trilinear_interp(): - - A = np.zeros((5, 5, 5)) - A[2, 2, 2] = 1 - - indices = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2], [1.5, 1.5, 1.5]]) - - values = trilinear_interp_numpy(A, indices) - values2 = map_coordinates_3d_4d(A, indices) - npt.assert_almost_equal(values, values2) - - B = np.zeros((5, 5, 5, 3)) - B[2, 2, 2] = np.array([1, 1, 1]) - - values = trilinear_interp_numpy(B, indices) - values_4d = map_coordinates_3d_4d(B, indices) - npt.assert_almost_equal(values, values_4d) - - -if __name__ == '__main__': - - npt.run_module_suite() diff --git a/dipy/viz/tests/test_widgets.py b/dipy/viz/tests/test_widgets.py deleted file mode 100644 index 0011daa70d..0000000000 --- a/dipy/viz/tests/test_widgets.py +++ /dev/null @@ -1,199 +0,0 @@ -import os -import numpy as np -from os.path import join as pjoin - -from dipy.viz import actor, window, widget -from dipy.data import DATA_DIR -from dipy.data import fetch_viz_icons, read_viz_icons -import numpy.testing as npt -from dipy.testing.decorators import xvfb_it - -use_xvfb = os.environ.get('TEST_WITH_XVFB', False) -if use_xvfb == 'skip': - skip_it = True -else: - skip_it = False - - -@npt.dec.skipif(not 
actor.have_vtk or not actor.have_vtk_colors or skip_it) -@xvfb_it -def test_button_and_slider_widgets(): - recording = False - filename = "test_button_and_slider_widgets.log.gz" - recording_filename = pjoin(DATA_DIR, filename) - renderer = window.Renderer() - - # create some minimalistic streamlines - lines = [np.array([[-1, 0, 0.], [1, 0, 0.]]), - np.array([[-1, 1, 0.], [1, 1, 0.]])] - colors = np.array([[1., 0., 0.], [0.3, 0.7, 0.]]) - stream_actor = actor.streamtube(lines, colors) - - states = {'camera_button_count': 0, - 'plus_button_count': 0, - 'minus_button_count': 0, - 'slider_moved_count': 0, - } - - renderer.add(stream_actor) - - # the show manager allows to break the rendering process - # in steps so that the widgets can be added properly - show_manager = window.ShowManager(renderer, size=(800, 800)) - - if recording: - show_manager.initialize() - show_manager.render() - - def button_callback(obj, event): - print('Camera pressed') - states['camera_button_count'] += 1 - - def button_plus_callback(obj, event): - print('+ pressed') - states['plus_button_count'] += 1 - - def button_minus_callback(obj, event): - print('- pressed') - states['minus_button_count'] += 1 - - fetch_viz_icons() - button_png = read_viz_icons(fname='camera.png') - - button = widget.button(show_manager.iren, - show_manager.ren, - button_callback, - button_png, (.98, 1.), (80, 50)) - - button_png_plus = read_viz_icons(fname='plus.png') - button_plus = widget.button(show_manager.iren, - show_manager.ren, - button_plus_callback, - button_png_plus, (.98, .9), (120, 50)) - - button_png_minus = read_viz_icons(fname='minus.png') - button_minus = widget.button(show_manager.iren, - show_manager.ren, - button_minus_callback, - button_png_minus, (.98, .9), (50, 50)) - - def print_status(obj, event): - rep = obj.GetRepresentation() - stream_actor.SetPosition((rep.GetValue(), 0, 0)) - states['slider_moved_count'] += 1 - - slider = widget.slider(show_manager.iren, show_manager.ren, - callback=print_status, - min_value=-1, - max_value=1, - value=0., - label="X", - right_normalized_pos=(.98, 0.6), - size=(120, 0), label_format="%0.2lf") - - # This callback is used to update the buttons/sliders' position - # so they can stay on the right side of the window when the window - # is being resized. 
- - global size - size = renderer.GetSize() - - if recording: - show_manager.record_events_to_file(recording_filename) - print(states) - else: - show_manager.play_events_from_file(recording_filename) - npt.assert_equal(states["camera_button_count"], 7) - npt.assert_equal(states["plus_button_count"], 3) - npt.assert_equal(states["minus_button_count"], 4) - npt.assert_equal(states["slider_moved_count"], 116) - - if not recording: - button.Off() - slider.Off() - # Uncomment below to test the slider and button with analyze - # button.place(renderer) - # slider.place(renderer) - - arr = window.snapshot(renderer, size=(800, 800)) - report = window.analyze_snapshot(arr) - # import pylab as plt - # plt.imshow(report.labels, origin='lower') - # plt.show() - npt.assert_equal(report.objects, 4) - - report = window.analyze_renderer(renderer) - npt.assert_equal(report.actors, 1) - - -@npt.dec.skipif(not actor.have_vtk or not actor.have_vtk_colors or skip_it) -@xvfb_it -def test_text_widget(): - - interactive = False - - renderer = window.Renderer() - axes = actor.axes() - window.add(renderer, axes) - renderer.ResetCamera() - - show_manager = window.ShowManager(renderer, size=(900, 900)) - - if interactive: - show_manager.initialize() - show_manager.render() - - fetch_viz_icons() - button_png = read_viz_icons(fname='home3.png') - - def button_callback(obj, event): - print('Button Pressed') - - button = widget.button(show_manager.iren, - show_manager.ren, - button_callback, - button_png, (.8, 1.2), (100, 100)) - - global rulez - rulez = True - - def text_callback(obj, event): - - global rulez - print('Text selected') - if rulez: - obj.GetTextActor().SetInput("Diffusion Imaging Rulez!!") - rulez = False - else: - obj.GetTextActor().SetInput("Diffusion Imaging in Python") - rulez = True - show_manager.render() - - text = widget.text(show_manager.iren, - show_manager.ren, - text_callback, - message="Diffusion Imaging in Python", - left_down_pos=(0., 0.), - right_top_pos=(0.4, 0.05), - opacity=1., - border=False) - - if not interactive: - button.Off() - text.Off() - pass - - if interactive: - show_manager.render() - show_manager.start() - - arr = window.snapshot(renderer, size=(900, 900)) - report = window.analyze_snapshot(arr) - npt.assert_equal(report.objects, 3) - - # If you want to see the segmented objects after the analysis is finished - # you can use imshow(report.labels, origin='lower') - - -if __name__ == '__main__': - npt.run_module_suite() diff --git a/dipy/viz/tests/test_window.py b/dipy/viz/tests/test_window.py deleted file mode 100644 index 4907b00fc6..0000000000 --- a/dipy/viz/tests/test_window.py +++ /dev/null @@ -1,231 +0,0 @@ -import os -import numpy as np -from dipy.viz import actor, window -import numpy.testing as npt -from dipy.testing.decorators import xvfb_it - -use_xvfb = os.environ.get('TEST_WITH_XVFB', False) -if use_xvfb == 'skip': - skip_it = True -else: - skip_it = False - - -@npt.dec.skipif(not actor.have_vtk or not actor.have_vtk_colors or skip_it) -@xvfb_it -def test_renderer(): - - ren = window.Renderer() - - npt.assert_equal(ren.size(), (0, 0)) - - # background color for renderer (1, 0.5, 0) - # 0.001 added here to remove numerical errors when moving from float - # to int values - bg_float = (1, 0.501, 0) - - # that will come in the image in the 0-255 uint scale - bg_color = tuple((np.round(255 * np.array(bg_float))).astype('uint8')) - - ren.background(bg_float) - # window.show(ren) - arr = window.snapshot(ren) - - report = window.analyze_snapshot(arr, - bg_color=bg_color, - 
colors=[bg_color, (0, 127, 0)]) - npt.assert_equal(report.objects, 0) - npt.assert_equal(report.colors_found, [True, False]) - - axes = actor.axes() - ren.add(axes) - # window.show(ren) - - arr = window.snapshot(ren) - report = window.analyze_snapshot(arr, bg_color) - npt.assert_equal(report.objects, 1) - - ren.rm(axes) - arr = window.snapshot(ren) - report = window.analyze_snapshot(arr, bg_color) - npt.assert_equal(report.objects, 0) - - window.add(ren, axes) - arr = window.snapshot(ren) - report = window.analyze_snapshot(arr, bg_color) - npt.assert_equal(report.objects, 1) - - ren.rm_all() - arr = window.snapshot(ren) - report = window.analyze_snapshot(arr, bg_color) - npt.assert_equal(report.objects, 0) - - ren2 = window.renderer(bg_float) - ren2.background((0, 0, 0.)) - - report = window.analyze_renderer(ren2) - npt.assert_equal(report.bg_color, (0, 0, 0)) - - ren2.add(axes) - - report = window.analyze_renderer(ren2) - npt.assert_equal(report.actors, 3) - - window.rm(ren2, axes) - report = window.analyze_renderer(ren2) - npt.assert_equal(report.actors, 0) - - -@npt.dec.skipif(not actor.have_vtk or not actor.have_vtk_colors or skip_it) -@xvfb_it -def test_active_camera(): - renderer = window.Renderer() - renderer.add(actor.axes(scale=(1, 1, 1))) - - renderer.reset_camera() - renderer.reset_clipping_range() - - direction = renderer.camera_direction() - position, focal_point, view_up = renderer.get_camera() - - renderer.set_camera((0., 0., 1.), (0., 0., 0), view_up) - - position, focal_point, view_up = renderer.get_camera() - npt.assert_almost_equal(np.dot(direction, position), -1) - - renderer.zoom(1.5) - - new_position, _, _ = renderer.get_camera() - - npt.assert_array_almost_equal(position, new_position) - - renderer.zoom(1) - - # rotate around focal point - renderer.azimuth(90) - - position, _, _ = renderer.get_camera() - - npt.assert_almost_equal(position, (1.0, 0.0, 0)) - - arr = window.snapshot(renderer) - report = window.analyze_snapshot(arr, colors=[(255, 0, 0)]) - npt.assert_equal(report.colors_found, [True]) - - # rotate around camera's center - renderer.yaw(90) - - arr = window.snapshot(renderer) - report = window.analyze_snapshot(arr, colors=[(0, 0, 0)]) - npt.assert_equal(report.colors_found, [True]) - - renderer.yaw(-90) - renderer.elevation(90) - - arr = window.snapshot(renderer) - report = window.analyze_snapshot(arr, colors=(0, 255, 0)) - npt.assert_equal(report.colors_found, [True]) - - renderer.set_camera((0., 0., 1.), (0., 0., 0), view_up) - - # vertical rotation of the camera around the focal point - renderer.pitch(10) - renderer.pitch(-10) - - # rotate around the direction of projection - renderer.roll(90) - - # inverted normalized distance from focal point along the direction - # of the camera - - position, _, _ = renderer.get_camera() - renderer.dolly(0.5) - new_position, _, _ = renderer.get_camera() - npt.assert_almost_equal(position[2], 0.5 * new_position[2]) - - -@npt.dec.skipif(not actor.have_vtk or not actor.have_vtk_colors or skip_it) -@xvfb_it -def test_parallel_projection(): - - ren = window.Renderer() - axes = actor.axes() - axes2 = actor.axes() - axes2.SetPosition((2, 0, 0)) - - # Add both axes. 
-    ren.add(axes, axes2)
-
-    # Put the camera at an angle so that it can show the difference
-    # between perspective and parallel projection.
-    ren.set_camera((1.5, 1.5, 1.5))
-    ren.GetActiveCamera().Zoom(2)
-
-    # window.show(ren, reset_camera=True)
-    ren.reset_camera()
-    arr = window.snapshot(ren)
-
-    ren.projection('parallel')
-    # window.show(ren, reset_camera=False)
-    arr2 = window.snapshot(ren)
-    # With parallel projection the two axes have the same size and
-    # therefore occupy more pixels than in perspective projection,
-    # where the farther axes appear smaller.
-    npt.assert_equal(np.sum(arr2 > 0) > np.sum(arr > 0), True)
-
-
-@npt.dec.skipif(not actor.have_vtk or not actor.have_vtk_colors or skip_it)
-@xvfb_it
-def test_order_transparent():
-
-    renderer = window.Renderer()
-
-    lines = [np.array([[-1, 0, 0.], [1, 0, 0.]]),
-             np.array([[-1, 1, 0.], [1, 1, 0.]])]
-    colors = np.array([[1., 0., 0.], [0., .5, 0.]])
-    stream_actor = actor.streamtube(lines, colors, linewidth=0.3, opacity=0.5)
-
-    renderer.add(stream_actor)
-
-    renderer.reset_camera()
-
-    # green in front
-    renderer.elevation(90)
-    renderer.camera().OrthogonalizeViewUp()
-    renderer.reset_clipping_range()
-
-    renderer.reset_camera()
-
-    not_xvfb = os.environ.get("TEST_WITH_XVFB", False)
-
-    if not_xvfb:
-        arr = window.snapshot(renderer, fname='green_front.png',
-                              offscreen=True, order_transparent=False)
-    else:
-        arr = window.snapshot(renderer, fname='green_front.png',
-                              offscreen=False, order_transparent=False)
-
-    # therefore the green component must have a higher value (in RGB terms)
-    npt.assert_equal(arr[150, 150][1] > arr[150, 150][0], True)
-
-    # red in front
-    renderer.elevation(-180)
-    renderer.camera().OrthogonalizeViewUp()
-    renderer.reset_clipping_range()
-
-    if not_xvfb:
-        arr = window.snapshot(renderer, fname='red_front.png',
-                              offscreen=True, order_transparent=True)
-    else:
-        arr = window.snapshot(renderer, fname='red_front.png',
-                              offscreen=False, order_transparent=True)
-
-    # therefore the red component must have a higher value (in RGB terms)
-    npt.assert_equal(arr[150, 150][0] > arr[150, 150][1], True)
-
-
-if __name__ == '__main__':
-
-    npt.run_module_suite()
diff --git a/dipy/viz/ui.py b/dipy/viz/ui.py
deleted file mode 100644
index 82da27bca4..0000000000
--- a/dipy/viz/ui.py
+++ /dev/null
@@ -1,2637 +0,0 @@
-from __future__ import division
-from warnings import warn
-
-import os
-import glob
-import numpy as np
-
-from dipy.data import read_viz_icons
-from dipy.viz.interactor import CustomInteractorStyle
-
-from dipy.utils.optpkg import optional_package
-
-# Allow import, but disable doctests if we don't have vtk.
-vtk, have_vtk, setup_module = optional_package('vtk')
-
-if have_vtk:
-    version = vtk.vtkVersion.GetVTKVersion()
-    major_version = vtk.vtkVersion.GetVTKMajorVersion()
-
-TWO_PI = 2 * np.pi
-
-
-class UI(object):
-    """ An umbrella class for all UI elements.
-
-    While adding UI elements to the renderer, we go over all the sub-elements
-    that come with it and add those to the renderer automatically.
-
-    Attributes
-    ----------
-    ui_param : object
-        This is an attribute that can be passed to the UI object by the
-        interactor.
-    ui_list : list of :class:`UI`
-        This is used when there is more than one UI element inside
-        a UI element. They're all automatically added to the renderer at the
-        same time as this one.
-    parent_ui: UI
-        Reference to the parent UI element. This is useful if there is a parent
-        UI element and its reference needs to be passed down to the child.
-    on_left_mouse_button_pressed: function
-        Callback function for when the left mouse button is pressed.
-    on_left_mouse_button_released: function
-        Callback function for when the left mouse button is released.
-    on_left_mouse_button_clicked: function
-        Callback function for when clicked using the left mouse button
-        (i.e. pressed -> released).
-    on_left_mouse_button_dragged: function
-        Callback function for when dragging using the left mouse button.
-    on_right_mouse_button_pressed: function
-        Callback function for when the right mouse button is pressed.
-    on_right_mouse_button_released: function
-        Callback function for when the right mouse button is released.
-    on_right_mouse_button_clicked: function
-        Callback function for when clicking using the right mouse button
-        (i.e. pressed -> released).
-    on_right_mouse_button_dragged: function
-        Callback function for when dragging using the right mouse button.
-
-    """
-
-    def __init__(self):
-        self.ui_param = None
-        self.ui_list = list()
-
-        self.parent_ui = None
-        self._callbacks = []
-
-        self.left_button_state = "released"
-        self.right_button_state = "released"
-
-        self.on_left_mouse_button_pressed = lambda i_ren, obj, element: None
-        self.on_left_mouse_button_dragged = lambda i_ren, obj, element: None
-        self.on_left_mouse_button_released = lambda i_ren, obj, element: None
-        self.on_left_mouse_button_clicked = lambda i_ren, obj, element: None
-        self.on_right_mouse_button_pressed = lambda i_ren, obj, element: None
-        self.on_right_mouse_button_released = lambda i_ren, obj, element: None
-        self.on_right_mouse_button_clicked = lambda i_ren, obj, element: None
-        self.on_right_mouse_button_dragged = lambda i_ren, obj, element: None
-        self.on_key_press = lambda i_ren, obj, element: None
-
-    def get_actors(self):
-        """ Returns the actors that compose this UI component.
-
-        """
-        msg = "Subclasses of UI must implement `get_actors(self)`."
-        raise NotImplementedError(msg)
-
-    def add_to_renderer(self, ren):
-        """ Allows UI objects to add their own props to the renderer.
-
-        Parameters
-        ----------
-        ren : renderer
-
-        """
-        ren.add(*self.get_actors())
-
-        # Get a hold on the current interactor style.
-        iren = ren.GetRenderWindow().GetInteractor().GetInteractorStyle()
-
-        for callback in self._callbacks:
-            if not isinstance(iren, CustomInteractorStyle):
-                msg = ("The ShowManager requires `CustomInteractorStyle` in"
-                       " order to use callbacks.")
-                raise TypeError(msg)
-
-            iren.add_callback(*callback, args=[self])
-
-    def add_callback(self, prop, event_type, callback, priority=0):
-        """ Adds a callback to a specific event for this UI component.
-
-        Parameters
-        ----------
-        prop : vtkProp
-            The prop to which the callback is to be added.
-        event_type : string
-            The event code.
-        callback : function
-            The callback function.
-        priority : int
-            A higher number means higher priority.
-
-        """
-        # Since an interactor style is required, the callback is actually
-        # registered only when this UI component is added to the renderer.
-        self._callbacks.append((prop, event_type, callback, priority))
-
-    def set_center(self, position):
-        """ Sets the center of the UI component.
-
-        Parameters
-        ----------
-        position : (float, float)
-            These are the x and y coordinates respectively, with the
-            origin at the bottom left.
-
-        """
-        msg = "Subclasses of UI must implement `set_center(self, position)`."
- raise NotImplementedError(msg) - - def set_visibility(self, visibility): - """ Sets visibility of this UI component and all its sub-components. - - """ - for actor in self.get_actors(): - actor.SetVisibility(visibility) - - def handle_events(self, actor): - self.add_callback(actor, "LeftButtonPressEvent", self.left_button_click_callback) - self.add_callback(actor, "LeftButtonReleaseEvent", self.left_button_release_callback) - self.add_callback(actor, "RightButtonPressEvent", self.right_button_click_callback) - self.add_callback(actor, "RightButtonReleaseEvent", self.right_button_release_callback) - self.add_callback(actor, "MouseMoveEvent", self.mouse_move_callback) - self.add_callback(actor, "KeyPressEvent", self.key_press_callback) - - @staticmethod - def left_button_click_callback(i_ren, obj, self): - self.left_button_state = "pressing" - self.on_left_mouse_button_pressed(i_ren, obj, self) - i_ren.event.abort() - - @staticmethod - def left_button_release_callback(i_ren, obj, self): - if self.left_button_state == "pressing": - self.on_left_mouse_button_clicked(i_ren, obj, self) - self.left_button_state = "released" - self.on_left_mouse_button_released(i_ren, obj, self) - - @staticmethod - def right_button_click_callback(i_ren, obj, self): - self.right_button_state = "pressing" - self.on_right_mouse_button_pressed(i_ren, obj, self) - i_ren.event.abort() - - @staticmethod - def right_button_release_callback(i_ren, obj, self): - if self.right_button_state == "pressing": - self.on_right_mouse_button_clicked(i_ren, obj, self) - self.right_button_state = "released" - self.on_right_mouse_button_released(i_ren, obj, self) - - @staticmethod - def mouse_move_callback(i_ren, obj, self): - if self.left_button_state == "pressing" or self.left_button_state == "dragging": - self.left_button_state = "dragging" - self.on_left_mouse_button_dragged(i_ren, obj, self) - elif self.right_button_state == "pressing" or self.right_button_state == "dragging": - self.right_button_state = "dragging" - self.on_right_mouse_button_dragged(i_ren, obj, self) - else: - pass - - @staticmethod - def key_press_callback(i_ren, obj, self): - self.on_key_press(i_ren, obj, self) - - -class Button2D(UI): - """ A 2D overlay button and is of type vtkTexturedActor2D. - Currently supports: - - Multiple icons. - - Switching between icons. - - Attributes - ---------- - size: (float, float) - Button size (width, height) in pixels. - - """ - - def __init__(self, icon_fnames, size=(30, 30)): - """ - Parameters - ---------- - size : 2-tuple of int, optional - Button size. - icon_fnames : dict - {iconname : filename, iconname : filename, ...} - - """ - super(Button2D, self).__init__() - self.icon_extents = dict() - self.icons = self.__build_icons(icon_fnames) - self.icon_names = list(self.icons.keys()) - self.current_icon_id = 0 - self.current_icon_name = self.icon_names[self.current_icon_id] - self.actor = self.build_actor(self.icons[self.current_icon_name]) - self.size = size - self.handle_events(self.actor) - - def __build_icons(self, icon_fnames): - """ Converts file names to vtkImageDataGeometryFilters. - - A pre-processing step to prevent re-read of file names during every - state change. - - Parameters - ---------- - icon_fnames : dict - {iconname: filename, iconname: filename, ...} - - Returns - ------- - icons : dict - A dictionary of corresponding vtkImageDataGeometryFilters. 
- - """ - icons = {} - for icon_name, icon_fname in icon_fnames.items(): - if icon_fname.split(".")[-1] not in ["png", "PNG"]: - error_msg = "A specified icon file is not in the PNG format. SKIPPING." - warn(Warning(error_msg)) - else: - png = vtk.vtkPNGReader() - png.SetFileName(icon_fname) - png.Update() - icons[icon_name] = png.GetOutput() - - return icons - - @property - def size(self): - """ Gets the button size. - - """ - return self._size - - @size.setter - def size(self, size): - """ Sets the button size. - - Parameters - ---------- - size : (float, float) - Button size (width, height) in pixels. - - """ - self._size = np.asarray(size) - - # Update actor. - self.texture_points.SetPoint(0, 0, 0, 0.0) - self.texture_points.SetPoint(1, size[0], 0, 0.0) - self.texture_points.SetPoint(2, size[0], size[1], 0.0) - self.texture_points.SetPoint(3, 0, size[1], 0.0) - self.texture_polydata.SetPoints(self.texture_points) - - @property - def color(self): - """ Gets the button's color. - - """ - color = self.actor.GetProperty().GetColor() - return np.asarray(color) - - @color.setter - def color(self, color): - """ Sets the button's color. - - Parameters - ---------- - color : (float, float, float) - RGB. Must take values in [0, 1]. - - """ - self.actor.GetProperty().SetColor(*color) - - def scale(self, size): - """ Scales the button. - - Parameters - ---------- - size : (float, float) - Scaling factor (width, height) in pixels. - - """ - self.size *= size - - def build_actor(self, icon): - """ Return an image as a 2D actor with a specific position. - - Parameters - ---------- - icon : :class:`vtkImageData` - - Returns - ------- - :class:`vtkTexturedActor2D` - - """ - # This is highly inspired by - # https://github.com/Kitware/VTK/blob/c3ec2495b183e3327820e927af7f8f90d34c3474\ - # /Interaction/Widgets/vtkBalloonRepresentation.cxx#L47 - - self.texture_polydata = vtk.vtkPolyData() - self.texture_points = vtk.vtkPoints() - self.texture_points.SetNumberOfPoints(4) - self.size = icon.GetExtent() - - polys = vtk.vtkCellArray() - polys.InsertNextCell(4) - polys.InsertCellPoint(0) - polys.InsertCellPoint(1) - polys.InsertCellPoint(2) - polys.InsertCellPoint(3) - self.texture_polydata.SetPolys(polys) - - tc = vtk.vtkFloatArray() - tc.SetNumberOfComponents(2) - tc.SetNumberOfTuples(4) - tc.InsertComponent(0, 0, 0.0) - tc.InsertComponent(0, 1, 0.0) - tc.InsertComponent(1, 0, 1.0) - tc.InsertComponent(1, 1, 0.0) - tc.InsertComponent(2, 0, 1.0) - tc.InsertComponent(2, 1, 1.0) - tc.InsertComponent(3, 0, 0.0) - tc.InsertComponent(3, 1, 1.0) - self.texture_polydata.GetPointData().SetTCoords(tc) - - texture_mapper = vtk.vtkPolyDataMapper2D() - if major_version <= 5: - texture_mapper.SetInput(self.texture_polydata) - else: - texture_mapper.SetInputData(self.texture_polydata) - - button = vtk.vtkTexturedActor2D() - button.SetMapper(texture_mapper) - - self.texture = vtk.vtkTexture() - button.SetTexture(self.texture) - - button_property = vtk.vtkProperty2D() - button_property.SetOpacity(1.0) - button.SetProperty(button_property) - - self.set_icon(icon) - return button - - def get_actors(self): - """ Returns the actors that compose this UI component. - - """ - return [self.actor] - - def set_icon(self, icon): - """ Modifies the icon used by the vtkTexturedActor2D. - - Parameters - ---------- - icon : imageDataGeometryFilter - - """ - if major_version <= 5: - self.texture.SetInput(icon) - else: - self.texture.SetInputData(icon) - - def next_icon_name(self): - """ Returns the next icon name while cycling through icons. 
- - """ - self.current_icon_id += 1 - if self.current_icon_id == len(self.icons): - self.current_icon_id = 0 - self.current_icon_name = self.icon_names[self.current_icon_id] - - def next_icon(self): - """ Increments the state of the Button. - - Also changes the icon. - - """ - self.next_icon_name() - self.set_icon(self.icons[self.current_icon_name]) - - def set_center(self, position): - """ Sets the icon center to position. - - Parameters - ---------- - position : (float, float) - The new center of the button (x, y). - - """ - new_position = np.asarray(position) - self.size / 2. - self.actor.SetPosition(*new_position) - - -class Rectangle2D(UI): - """ A 2D rectangle sub-classed from UI. - Uses vtkPolygon. - - Attributes - ---------- - size : (float, float) - The size of the rectangle (height, width) in pixels. - - """ - - def __init__(self, size, center=(0, 0), color=(1, 1, 1), opacity=1.0): - """ Initializes a rectangle. - - Parameters - ---------- - size : (float, float) - The size of the rectangle (height, width) in pixels. - center : (float, float) - The center of the rectangle (x, y). - color : (float, float, float) - Must take values in [0, 1]. - opacity : float - Must take values in [0, 1]. - - """ - super(Rectangle2D, self).__init__() - self.size = size - self.actor = self.build_actor(size=size) - self.color = color - self.set_center(center) - self.opacity = opacity - self.handle_events(self.actor) - - def get_actors(self): - """ Returns the actors that compose this UI component. - - """ - return [self.actor] - - def build_actor(self, size): - """ Builds the text actor. - - Parameters - ---------- - size : (float, float) - The size of the rectangle (height, width) in pixels. - - Returns - ------- - :class:`vtkActor2D` - - """ - # Setup four points - points = vtk.vtkPoints() - points.InsertNextPoint(0, 0, 0) - points.InsertNextPoint(size[0], 0, 0) - points.InsertNextPoint(size[0], size[1], 0) - points.InsertNextPoint(0, size[1], 0) - - # Create the polygon - polygon = vtk.vtkPolygon() - polygon.GetPointIds().SetNumberOfIds(4) # make a quad - polygon.GetPointIds().SetId(0, 0) - polygon.GetPointIds().SetId(1, 1) - polygon.GetPointIds().SetId(2, 2) - polygon.GetPointIds().SetId(3, 3) - - # Add the polygon to a list of polygons - polygons = vtk.vtkCellArray() - polygons.InsertNextCell(polygon) - - # Create a PolyData - polygonPolyData = vtk.vtkPolyData() - polygonPolyData.SetPoints(points) - polygonPolyData.SetPolys(polygons) - - # Create a mapper and actor - mapper = vtk.vtkPolyDataMapper2D() - if vtk.VTK_MAJOR_VERSION <= 5: - mapper.SetInput(polygonPolyData) - else: - mapper.SetInputData(polygonPolyData) - - actor = vtk.vtkActor2D() - actor.SetMapper(mapper) - - return actor - - def set_position(self, position): - self.actor.SetPosition(*position) - - def set_center(self, position): - """ Sets the center to position. - - Parameters - ---------- - position : (float, float) - The new center of the rectangle (x, y). - - """ - self.actor.SetPosition(position[0] - self.size[0] / 2, - position[1] - self.size[1] / 2) - - @property - def color(self): - """ Gets the rectangle's color. - - """ - color = self.actor.GetProperty().GetColor() - return np.asarray(color) - - @color.setter - def color(self, color): - """ Sets the rectangle's color. - - Parameters - ---------- - color : (float, float, float) - RGB. Must take values in [0, 1]. - - """ - self.actor.GetProperty().SetColor(*color) - - @property - def opacity(self): - """ Gets the rectangle's opacity. 
- - """ - return self.actor.GetProperty().GetOpacity() - - @opacity.setter - def opacity(self, opacity): - """ Sets the rectangle's opacity. - - Parameters - ---------- - opacity : float - Degree of transparency. Must be between [0, 1]. - - """ - self.actor.GetProperty().SetOpacity(opacity) - - @property - def position(self): - """ Gets text actor position. - - Returns - ------- - (float, float) - The current actor position. (x, y) in pixels. - - """ - return self.actor.GetPosition() - - @position.setter - def position(self, position): - """ Set text actor position. - - Parameters - ---------- - position : (float, float) - The new position. (x, y) in pixels. - - """ - self.actor.SetPosition(*position) - - -class Panel2D(UI): - """ A 2D UI Panel. - - Can contain one or more UI elements. - - Attributes - ---------- - center : (float, float) - The center of the panel (x, y). - size : (float, float) - The size of the panel (width, height) in pixels. - alignment : [left, right] - Alignment of the panel with respect to the overall screen. - - """ - - def __init__(self, center, size, color=(0.1, 0.1, 0.1), opacity=0.7, align="left"): - """ - Parameters - ---------- - center : (float, float) - The center of the panel (x, y). - size : (float, float) - The size of the panel (width, height) in pixels. - color : (float, float, float) - Must take values in [0, 1]. - opacity : float - Must take values in [0, 1]. - align : [left, right] - Alignment of the panel with respect to the overall screen. - - """ - super(Panel2D, self).__init__() - self.center = center - self.size = size - self.lower_limits = (self.center[0] - self.size[0] / 2, - self.center[1] - self.size[1] / 2) - - self.panel = Rectangle2D(size=size, center=center, color=color, - opacity=opacity) - - self.element_positions = [] - self.element_positions.append([self.panel, 'relative', 0.5, 0.5]) - self.alignment = align - - self.handle_events(self.panel.actor) - - self.on_left_mouse_button_pressed = self.left_button_pressed - self.on_left_mouse_button_dragged = self.left_button_dragged - - def add_to_renderer(self, ren): - """ Allows UI objects to add their own props to the renderer. - - Here, we add only call add_to_renderer for the additional components. - - Parameters - ---------- - ren : renderer - - """ - super(Panel2D, self).add_to_renderer(ren) - for ui_item in self.ui_list: - ui_item.add_to_renderer(ren) - - def get_actors(self): - """ Returns the panel actor. - - """ - return [self.panel.actor] - - def add_element(self, element, position_type, position): - """ Adds an element to the panel. - - The center of the rectangular panel is its bottom lower position. - - Parameters - ---------- - element : UI - The UI item to be added. - position_type: string - 'absolute' or 'relative' - position : (float, float) - Absolute for absolute and relative for relative - - """ - self.ui_list.append(element) - if position_type == 'relative': - self.element_positions.append([element, position_type, position[0], position[1]]) - element.set_center((self.lower_limits[0] + position[0] * self.size[0], - self.lower_limits[1] + position[1] * self.size[1])) - elif position_type == 'absolute': - self.element_positions.append([element, position_type, position[0], position[1]]) - element.set_center((position[0], position[1])) - else: - raise ValueError("Position can only be absolute or relative") - - def set_center(self, position): - """ Sets the panel center to position. - - The center of the rectangular panel is its bottom lower position. 
- - Parameters - ---------- - position : (float, float) - The new center of the panel (x, y). - - """ - shift = [position[0] - self.center[0], position[1] - self.center[1]] - self.center = position - self.lower_limits = (position[0] - self.size[0] / 2, position[1] - self.size[1] / 2) - for ui_element in self.element_positions: - if ui_element[1] == 'relative': - ui_element[0].set_center((self.lower_limits[0] + ui_element[2] * self.size[0], - self.lower_limits[1] + ui_element[3] * self.size[1])) - elif ui_element[1] == 'absolute': - ui_element[2] += shift[0] - ui_element[3] += shift[1] - ui_element[0].set_center((ui_element[2], ui_element[3])) - - @staticmethod - def left_button_pressed(i_ren, obj, panel2d_object): - click_position = i_ren.event.position - panel2d_object.ui_param = (click_position[0] - - panel2d_object.panel.actor.GetPosition()[0] - - panel2d_object.panel.size[0] / 2, - click_position[1] - - panel2d_object.panel.actor.GetPosition()[1] - - panel2d_object.panel.size[1] / 2) - i_ren.event.abort() # Stop propagating the event. - - @staticmethod - def left_button_dragged(i_ren, obj, panel2d_object): - click_position = i_ren.event.position - if panel2d_object.ui_param is not None: - panel2d_object.set_center((click_position[0] - panel2d_object.ui_param[0], - click_position[1] - panel2d_object.ui_param[1])) - i_ren.force_render() - - def re_align(self, window_size_change): - """ Re-organises the elements in case the window size is changed. - - Parameters - ---------- - window_size_change : (int, int) - New window size (width, height) in pixels. - - """ - if self.alignment == "left": - pass - elif self.alignment == "right": - self.set_center((self.center[0] + window_size_change[0], - self.center[1] + window_size_change[1])) - else: - raise ValueError("You can only left-align or right-align objects in a panel.") - - -class TextBlock2D(UI): - """ Wraps over the default vtkTextActor and helps setting the text. - - Contains member functions for text formatting. - - Attributes - ---------- - actor : :class:`vtkTextActor` - The text actor. - message : str - The initial text while building the actor. - position : (float, float) - (x, y) in pixels. - color : (float, float, float) - RGB: Values must be between 0-1. - bg_color : (float, float, float) - RGB: Values must be between 0-1. - font_size : int - Size of the text font. - font_family : str - Currently only supports Arial. - justification : str - left, right or center. - vertical_justification : str - bottom, middle or top. - bold : bool - Makes text bold. - italic : bool - Makes text italicised. - shadow : bool - Adds text shadow. - """ - - def __init__(self, text="Text Block", font_size=18, font_family='Arial', - justification='left', vertical_justification="bottom", - bold=False, italic=False, shadow=False, - color=(1, 1, 1), bg_color=None, position=(0, 0)): - """ - Parameters - ---------- - text : str - The initial text while building the actor. - position : (float, float) - (x, y) in pixels. - color : (float, float, float) - RGB: Values must be between 0-1. - bg_color : (float, float, float) - RGB: Values must be between 0-1. - font_size : int - Size of the text font. - font_family : str - Currently only supports Arial. - justification : str - left, right or center. - vertical_justification : str - bottom, middle or top. - bold : bool - Makes text bold. - italic : bool - Makes text italicised. - shadow : bool - Adds text shadow. 
- """ - super(TextBlock2D, self).__init__() - self.actor = vtk.vtkTextActor() - - self._background = None # For VTK < 7 - self.position = position - self.color = color - self.background_color = bg_color - self.font_size = font_size - self.font_family = font_family - self.justification = justification - self.bold = bold - self.italic = italic - self.shadow = shadow - self.vertical_justification = vertical_justification - self.message = text - - def get_actor(self): - """ Returns the actor composing this element. - - Returns - ------- - :class:`vtkTextActor` - The actor composing this class. - """ - return self.actor - - def get_actors(self): - """ Returns the actors that compose this UI component. - - """ - if self._background is not None: - return [self._background, self.actor] - - return [self.actor] - - @property - def message(self): - """ Gets message from the text. - - Returns - ------- - str - The current text message. - - """ - return self.actor.GetInput() - - @message.setter - def message(self, text): - """ Sets the text message. - - Parameters - ---------- - text : str - The message to be set. - - """ - self.actor.SetInput(text) - - @property - def font_size(self): - """ Gets text font size. - - Returns - ---------- - int - Text font size. - - """ - return self.actor.GetTextProperty().GetFontSize() - - @font_size.setter - def font_size(self, size): - """ Sets font size. - - Parameters - ---------- - size : int - Text font size. - - """ - self.actor.GetTextProperty().SetFontSize(size) - - @property - def font_family(self): - """ Gets font family. - - Returns - ---------- - str - Text font family. - - """ - return self.actor.GetTextProperty().GetFontFamilyAsString() - - @font_family.setter - def font_family(self, family='Arial'): - """ Sets font family. - - Currently Arial and Courier are supported. - - Parameters - ---------- - family : str - The font family. - - """ - if family == 'Arial': - self.actor.GetTextProperty().SetFontFamilyToArial() - elif family == 'Courier': - self.actor.GetTextProperty().SetFontFamilyToCourier() - else: - raise ValueError("Font not supported yet: {}.".format(family)) - - @property - def justification(self): - """ Gets text justification. - - Returns - ------- - str - Text justification. - - """ - justification = self.actor.GetTextProperty().GetJustificationAsString() - if justification == 'Left': - return "left" - elif justification == 'Centered': - return "center" - elif justification == 'Right': - return "right" - - @justification.setter - def justification(self, justification): - """ Justifies text. - - Parameters - ---------- - justification : str - Possible values are left, right, center. - - """ - text_property = self.actor.GetTextProperty() - if justification == 'left': - text_property.SetJustificationToLeft() - elif justification == 'center': - text_property.SetJustificationToCentered() - elif justification == 'right': - text_property.SetJustificationToRight() - else: - raise ValueError("Text can only be justified left, right and center.") - - @property - def vertical_justification(self): - """ Gets text vertical justification. - - Returns - ------- - str - Text vertical justification. 
- - """ - text_property = self.actor.GetTextProperty() - vjustification = text_property.GetVerticalJustificationAsString() - if vjustification == 'Bottom': - return "bottom" - elif vjustification == 'Centered': - return "middle" - elif vjustification == 'Top': - return "top" - - @vertical_justification.setter - def vertical_justification(self, vertical_justification): - """ Justifies text vertically. - - Parameters - ---------- - vertical_justification : str - Possible values are bottom, middle, top. - - """ - text_property = self.actor.GetTextProperty() - if vertical_justification == 'bottom': - text_property.SetVerticalJustificationToBottom() - elif vertical_justification == 'middle': - text_property.SetVerticalJustificationToCentered() - elif vertical_justification == 'top': - text_property.SetVerticalJustificationToTop() - else: - msg = "Vertical justification must be: bottom, middle or top." - raise ValueError(msg) - - @property - def bold(self): - """ Returns whether the text is bold. - - Returns - ------- - bool - Text is bold if True. - - """ - return self.actor.GetTextProperty().GetBold() - - @bold.setter - def bold(self, flag): - """ Bolds/un-bolds text. - - Parameters - ---------- - flag : bool - Sets text bold if True. - - """ - self.actor.GetTextProperty().SetBold(flag) - - @property - def italic(self): - """ Returns whether the text is italicised. - - Returns - ------- - bool - Text is italicised if True. - - """ - return self.actor.GetTextProperty().GetItalic() - - @italic.setter - def italic(self, flag): - """ Italicises/un-italicises text. - - Parameters - ---------- - flag : bool - Italicises text if True. - - """ - self.actor.GetTextProperty().SetItalic(flag) - - @property - def shadow(self): - """ Returns whether the text has shadow. - - Returns - ------- - bool - Text is shadowed if True. - - """ - return self.actor.GetTextProperty().GetShadow() - - @shadow.setter - def shadow(self, flag): - """ Adds/removes text shadow. - - Parameters - ---------- - flag : bool - Shadows text if True. - - """ - self.actor.GetTextProperty().SetShadow(flag) - - @property - def color(self): - """ Gets text color. - - Returns - ------- - (float, float, float) - Returns text color in RGB. - - """ - return self.actor.GetTextProperty().GetColor() - - @color.setter - def color(self, color=(1, 0, 0)): - """ Set text color. - - Parameters - ---------- - color : (float, float, float) - RGB: Values must be between 0-1. - - """ - self.actor.GetTextProperty().SetColor(*color) - - @property - def background_color(self): - """ Gets background color. - - Returns - ------- - (float, float, float) or None - If None, there no background color. - Otherwise, background color in RGB. - - """ - if major_version < 7: - if self._background is None: - return None - - return self._background.GetProperty().GetColor() - - if self.actor.GetTextProperty().GetBackgroundOpacity() == 0: - return None - - return self.actor.GetTextProperty().GetBackgroundColor() - - @background_color.setter - def background_color(self, color): - """ Set text color. - - Parameters - ---------- - color : (float, float, float) or None - If None, remove background. - Otherwise, RGB values (must be between 0-1). - - """ - - if color is None: - # Remove background. - if major_version < 7: - self._background = None - else: - self.actor.GetTextProperty().SetBackgroundOpacity(0.) 
- - else: - if major_version < 7: - self._background = vtk.vtkActor2D() - self._background.GetProperty().SetColor(*color) - self._background.GetProperty().SetOpacity(1) - self._background.SetMapper(self.actor.GetMapper()) - self._background.SetPosition(*self.actor.GetPosition()) - - else: - self.actor.GetTextProperty().SetBackgroundColor(*color) - self.actor.GetTextProperty().SetBackgroundOpacity(1.) - - @property - def position(self): - """ Gets text actor position. - - Returns - ------- - (float, float) - The current actor position. (x, y) in pixels. - - """ - return self.actor.GetPosition() - - @position.setter - def position(self, position): - """ Set text actor position. - - Parameters - ---------- - position : (float, float) - The new position. (x, y) in pixels. - - """ - self.actor.SetPosition(*position) - if self._background is not None: - self._background.SetPosition(*self.actor.GetPosition()) - - def set_center(self, position): - """ Sets the text center to position. - - Parameters - ---------- - position : (float, float) - - """ - self.position = position - - -class TextBox2D(UI): - """ An editable 2D text box that behaves as a UI component. - - Currently supports: - - Basic text editing. - - Cursor movements. - - Single and multi-line text boxes. - - Pre text formatting (text needs to be formatted beforehand). - - Attributes - ---------- - text : str - The current text state. - actor : :class:`vtkActor2d` - The text actor. - width : int - The number of characters in a single line of text. - height : int - The number of lines in the textbox. - window_left : int - Left limit of visible text in the textbox. - window_right : int - Right limit of visible text in the textbox. - caret_pos : int - Position of the caret in the text. - init : bool - Flag which says whether the textbox has just been initialized. - - """ - def __init__(self, width, height, text="Enter Text", position=(100, 10), - color=(0, 0, 0), font_size=18, font_family='Arial', - justification='left', bold=False, - italic=False, shadow=False): - """ - Parameters - ---------- - width : int - The number of characters in a single line of text. - height : int - The number of lines in the textbox. - text : str - The initial text while building the actor. - position : (float, float) - (x, y) in pixels. - color : (float, float, float) - RGB: Values must be between 0-1. - font_size : int - Size of the text font. - font_family : str - Currently only supports Arial. - justification : str - left, right or center. - bold : bool - Makes text bold. - italic : bool - Makes text italicised. - shadow : bool - Adds text shadow. - - """ - super(TextBox2D, self).__init__() - self.text = text - self.actor = self.build_actor(self.text, position, color, font_size, - font_family, justification, bold, italic, shadow) - self.width = width - self.height = height - self.window_left = 0 - self.window_right = 0 - self.caret_pos = 0 - self.init = True - - self.handle_events(self.actor.get_actor()) - - self.on_left_mouse_button_pressed = self.left_button_press - self.on_key_press = self.key_press - - def build_actor(self, text, position, color, font_size, - font_family, justification, bold, italic, shadow): - - """ Builds a text actor. - - Parameters - ---------- - text : str - The initial text while building the actor. - position : (float, float) - (x, y) in pixels. - color : (float, float, float) - RGB: Values must be between 0-1. - font_size : int - Size of the text font. - font_family : str - Currently only supports Arial. 
- justification : str - left, right or center. - bold : bool - Makes text bold. - italic : bool - Makes text italicised. - shadow : bool - Adds text shadow. - - Returns - ------- - :class:`TextBlock2D` - - """ - text_block = TextBlock2D() - text_block.position = position - text_block.message = text - text_block.font_size = font_size - text_block.font_family = font_family - text_block.justification = justification - text_block.bold = bold - text_block.italic = italic - text_block.shadow = shadow - - if major_version >= 7: - text_block.actor.GetTextProperty().SetBackgroundColor(1, 1, 1) - text_block.actor.GetTextProperty().SetBackgroundOpacity(1.0) - text_block.color = color - - return text_block - - def set_message(self, message): - """ Set custom text to textbox. - - Parameters - ---------- - message: str - The custom message to be set. - - """ - self.text = message - self.actor.message = message - self.init = False - self.window_right = len(self.text) - self.window_left = 0 - self.caret_pos = self.window_right - - def get_actors(self): - """ Returns the actors that compose this UI component. - - """ - return [self.actor.get_actor()] - - def width_set_text(self, text): - """ Adds newlines to text where necessary. - - This is needed for multi-line text boxes. - - Parameters - ---------- - text : str - The final text to be formatted. - - Returns - ------- - str - A multi line formatted text. - - """ - multi_line_text = "" - for i in range(len(text)): - multi_line_text += text[i] - if (i + 1) % self.width == 0: - multi_line_text += "\n" - return multi_line_text.rstrip("\n") - - def handle_character(self, character): - """ Main driving function that handles button events. - - # TODO: Need to handle all kinds of characters like !, +, etc. - - Parameters - ---------- - character : str - - """ - if character.lower() == "return": - self.render_text(False) - return True - if character.lower() == "backspace": - self.remove_character() - elif character.lower() == "left": - self.move_left() - elif character.lower() == "right": - self.move_right() - else: - self.add_character(character) - self.render_text() - return False - - def move_caret_right(self): - """ Moves the caret towards right. - - """ - self.caret_pos = min(self.caret_pos + 1, len(self.text)) - - def move_caret_left(self): - """ Moves the caret towards left. - - """ - self.caret_pos = max(self.caret_pos - 1, 0) - - def right_move_right(self): - """ Moves right boundary of the text window right-wards. - - """ - if self.window_right <= len(self.text): - self.window_right += 1 - - def right_move_left(self): - """ Moves right boundary of the text window left-wards. - - """ - if self.window_right > 0: - self.window_right -= 1 - - def left_move_right(self): - """ Moves left boundary of the text window right-wards. - - """ - if self.window_left <= len(self.text): - self.window_left += 1 - - def left_move_left(self): - """ Moves left boundary of the text window left-wards. - - """ - if self.window_left > 0: - self.window_left -= 1 - - def add_character(self, character): - """ Inserts a character into the text and moves window and caret accordingly. 
- - Parameters - ---------- - character : str - - """ - if len(character) > 1 and character.lower() != "space": - return - if character.lower() == "space": - character = " " - self.text = (self.text[:self.caret_pos] + - character + - self.text[self.caret_pos:]) - self.move_caret_right() - if (self.window_right - - self.window_left == self.height * self.width - 1): - self.left_move_right() - self.right_move_right() - - def remove_character(self): - """ Removes a character from the text and moves window and caret accordingly. - - """ - if self.caret_pos == 0: - return - self.text = self.text[:self.caret_pos - 1] + self.text[self.caret_pos:] - self.move_caret_left() - if len(self.text) < self.height * self.width - 1: - self.right_move_left() - if (self.window_right - - self.window_left == self.height * self.width - 1): - if self.window_left > 0: - self.left_move_left() - self.right_move_left() - - def move_left(self): - """ Handles left button press. - - """ - self.move_caret_left() - if self.caret_pos == self.window_left - 1: - if (self.window_right - - self.window_left == self.height * self.width - 1): - self.left_move_left() - self.right_move_left() - - def move_right(self): - """ Handles right button press. - - """ - self.move_caret_right() - if self.caret_pos == self.window_right + 1: - if (self.window_right - - self.window_left == self.height * self.width - 1): - self.left_move_right() - self.right_move_right() - - def showable_text(self, show_caret): - """ Chops out text to be shown on the screen. - - Parameters - ---------- - show_caret : bool - Whether or not to show the caret. - - """ - if show_caret: - ret_text = (self.text[:self.caret_pos] + - "_" + - self.text[self.caret_pos:]) - else: - ret_text = self.text - ret_text = ret_text[self.window_left:self.window_right + 1] - return ret_text - - def render_text(self, show_caret=True): - """ Renders text after processing. - - Parameters - ---------- - show_caret : bool - Whether or not to show the caret. - - """ - text = self.showable_text(show_caret) - if text == "": - text = "Enter Text" - self.actor.message = self.width_set_text(text) - - def edit_mode(self): - """ Turns on edit mode. - - """ - if self.init: - self.text = "" - self.init = False - self.caret_pos = 0 - self.render_text() - - def set_center(self, position): - """ Sets the text center to position. - - Parameters - ---------- - position : (float, float) - - """ - self.actor.position = position - - @staticmethod - def left_button_press(i_ren, obj, textbox_object): - """ Left button press handler for textbox - - Parameters - ---------- - i_ren: :class:`CustomInteractorStyle` - obj: :class:`vtkActor` - The picked actor - textbox_object: :class:`TextBox2D` - - """ - i_ren.add_active_prop(textbox_object.actor.get_actor()) - textbox_object.edit_mode() - i_ren.force_render() - - @staticmethod - def key_press(i_ren, obj, textbox_object): - """ Key press handler for textbox - - Parameters - ---------- - i_ren: :class:`CustomInteractorStyle` - obj: :class:`vtkActor` - The picked actor - textbox_object: :class:`TextBox2D` - - """ - key = i_ren.event.key - is_done = textbox_object.handle_character(key) - if is_done: - i_ren.remove_active_prop(textbox_object.actor.get_actor()) - - i_ren.force_render() - - -class LineSlider2D(UI): - """ A 2D Line Slider. - - A sliding ring on a line with a percentage indicator. - - Currently supports: - - A disk on a line (a thin rectangle). - - Setting disk position. 
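As a quick aside, the fixed-width wrapping rule implemented by TextBox2D.width_set_text above is easy to verify in isolation. A minimal standalone sketch (the helper name wrap_fixed_width is ours, not DIPY's):

def wrap_fixed_width(text, width):
    # Insert a newline after every `width` characters, mirroring
    # TextBox2D.width_set_text above.
    wrapped = ""
    for i, char in enumerate(text):
        wrapped += char
        if (i + 1) % width == 0:
            wrapped += "\n"
    return wrapped.rstrip("\n")

assert wrap_fixed_width("abcdefgh", 3) == "abc\ndef\ngh"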
- - Attributes - ---------- - line_width : int - Width of the line on which the disk will slide. - inner_radius : int - Inner radius of the disk (ring). - outer_radius : int - Outer radius of the disk. - center : (float, float) - Center of the slider. - length : int - Length of the slider. - slider_line : :class:`vtkActor` - The line on which the slider disk moves. - slider_disk : :class:`vtkActor` - The moving slider disk. - text : :class:`TextBlock2D` - The text that shows percentage. - - """ - def __init__(self, line_width=5, inner_radius=0, outer_radius=10, - center=(450, 300), length=200, initial_value=50, - min_value=0, max_value=100, text_size=16, - text_template="{value:.1f} ({ratio:.0%})"): - """ - Parameters - ---------- - line_width : int - Width of the line on which the disk will slide. - inner_radius : int - Inner radius of the disk (ring). - outer_radius : int - Outer radius of the disk. - center : (float, float) - Center of the slider. - length : int - Length of the slider. - initial_value : float - Initial value of the slider. - min_value : float - Minimum value of the slider. - max_value : float - Maximum value of the slider. - text_size : int - Size of the text to display alongside the slider (pt). - text_template : str, callable - If str, text template can contain one or multiple of the - replacement fields: `{value:}`, `{ratio:}`. - If callable, this instance of `:class:LineSlider2D` will be - passed as argument to the text template function. - - """ - super(LineSlider2D, self).__init__() - - self.length = length - self.min_value = min_value - self.max_value = max_value - - self.text_template = text_template - - self.line_width = line_width - self.center = center - self.current_state = center[0] - self.left_x_position = center[0] - length / 2 - self.right_x_position = center[0] + length / 2 - self._ratio = (self.current_state - self.left_x_position) / length - - self.slider_line = None - self.slider_disk = None - self.text = None - - self.build_actors(inner_radius=inner_radius, - outer_radius=outer_radius, text_size=text_size) - - # Setting the disk position will also update everything. - self.value = initial_value - # self.update() - - self.handle_events(None) - - def build_actors(self, inner_radius, outer_radius, text_size): - """ Builds required actors. - - Parameters - ---------- - inner_radius: int - The inner radius of the sliding disk. - outer_radius: int - The outer radius of the sliding disk. - text_size: int - Size of the text that displays percentage. - - """ - # Slider Line - self.slider_line = Rectangle2D(size=(self.length, self.line_width), - center=self.center).actor - self.slider_line.GetProperty().SetColor(1, 0, 0) - # /Slider Line - - # Slider Disk - # Create source - disk = vtk.vtkDiskSource() - disk.SetInnerRadius(inner_radius) - disk.SetOuterRadius(outer_radius) - disk.SetRadialResolution(10) - disk.SetCircumferentialResolution(50) - disk.Update() - - # Mapper - mapper = vtk.vtkPolyDataMapper2D() - mapper.SetInputConnection(disk.GetOutputPort()) - - # Actor - self.slider_disk = vtk.vtkActor2D() - self.slider_disk.SetMapper(mapper) - # /Slider Disk - - # Slider Text - self.text = TextBlock2D() - self.text.position = (self.left_x_position - 50, self.center[1] - 10) - self.text.font_size = text_size - # /Slider Text - - def get_actors(self): - """ Returns the actors that compose this UI component. - - """ - return [self.slider_line, self.slider_disk, self.text.get_actor()] - - def set_position(self, position): - """ Sets the disk's position. 
- - Parameters - ---------- - position : (float, float) - The absolute position of the disk (x, y). - - """ - x_position = position[0] - - if x_position < self.center[0] - self.length/2: - x_position = self.center[0] - self.length/2 - - if x_position > self.center[0] + self.length/2: - x_position = self.center[0] + self.length/2 - - self.current_state = x_position - self.update() - - @property - def value(self): - return self._value - - @value.setter - def value(self, value): - value_range = self.max_value - self.min_value - self.ratio = (value - self.min_value) / value_range - - @property - def ratio(self): - return self._ratio - - @ratio.setter - def ratio(self, ratio): - position_x = self.left_x_position + ratio*self.length - self.set_position((position_x, None)) - - def format_text(self): - """ Returns formatted text to display along the slider. """ - if callable(self.text_template): - return self.text_template(self) - - return self.text_template.format(ratio=self.ratio, value=self.value) - - def update(self): - """ Updates the slider. """ - - # Compute the ratio determined by the position of the slider disk. - length = float(self.right_x_position - self.left_x_position) - assert length == self.length - self._ratio = (self.current_state - self.left_x_position) / length - - # Compute the selected value considering min_value and max_value. - value_range = self.max_value - self.min_value - self._value = self.min_value + self.ratio*value_range - - # Update text disk actor. - self.slider_disk.SetPosition(self.current_state, self.center[1]) - - # Update text. - text = self.format_text() - self.text.message = text - offset_x = 8 * len(text) / 2. - offset_y = 30 - self.text.position = (self.current_state - offset_x, - self.center[1] - offset_y) - - def set_center(self, position): - """ Sets the center of the slider to position. - - Parameters - ---------- - position : (float, float) - The new center of the whole slider (x, y). - - """ - self.slider_line.SetPosition(position[0] - self.length / 2, - position[1] - self.line_width / 2) - - x_change = position[0] - self.center[0] - self.current_state += x_change - self.center = position - self.left_x_position = position[0] - self.length / 2 - self.right_x_position = position[0] + self.length / 2 - self.set_position((self.current_state, self.center[1])) - - @staticmethod - def line_click_callback(i_ren, obj, slider): - """ Update disk position and grab the focus. - - Parameters - ---------- - i_ren : :class:`CustomInteractorStyle` - obj : :class:`vtkActor` - The picked actor - slider : :class:`LineSlider2D` - - """ - position = i_ren.event.position - slider.set_position(position) - i_ren.force_render() - i_ren.event.abort() # Stop propagating the event. - - @staticmethod - def disk_press_callback(i_ren, obj, slider): - """ Only need to grab the focus. - - Parameters - ---------- - i_ren : :class:`CustomInteractorStyle` - obj : :class:`vtkActor` - The picked actor - slider : :class:`LineSlider2D` - - """ - i_ren.event.abort() # Stop propagating the event. - - @staticmethod - def disk_move_callback(i_ren, obj, slider): - """ Actual disk movement. - - Parameters - ---------- - i_ren : :class:`CustomInteractorStyle` - obj : :class:`vtkActor` - The picked actor - slider : :class:`LineSlider2D` - - """ - position = i_ren.event.position - slider.set_position(position) - i_ren.force_render() - i_ren.event.abort() # Stop propagating the event. - - def handle_events(self, actor): - """ Handle all events for the LineSlider. 
- Base method needs to be overridden due to multiple actors. - - """ - self.add_callback(self.slider_line, "LeftButtonPressEvent", - self.line_click_callback, 1) - self.add_callback(self.slider_disk, "LeftButtonPressEvent", - self.disk_press_callback) - self.add_callback(self.slider_disk, "MouseMoveEvent", - self.disk_move_callback) - self.add_callback(self.slider_line, "MouseMoveEvent", - self.disk_move_callback) - - -class DiskSlider2D(UI): - """ A disk slider. - - A disk moves along the boundary of a ring. - Goes from 0-360 degrees. - - Attributes - ---------- - base_disk_center: (float, float) - Position of the system. - slider_inner_radius: int - Inner radius of the base disk. - slider_outer_radius: int - Outer radius of the base disk. - slider_radius: float - Average radius of the base disk. - handle_outer_radius: int - Outer radius of the slider's handle. - handle_inner_radius: int - Inner radius of the slider's handle. - previous_value: float - Value of Rotation of the actor before the current value. - initial_value: float - Initial Value of Rotation of the actor assigned on creation of object. - - """ - def __init__(self, position=(0, 0), - initial_value=180, min_value=0, max_value=360, - slider_inner_radius=40, slider_outer_radius=44, - handle_inner_radius=10, handle_outer_radius=0, - text_size=16, - text_template="{ratio:.0%}"): - - """ - Parameters - ---------- - position : (float, float) - Position (x, y) of the slider's center. - initial_value : float - Initial value of the slider. - min_value : float - Minimum value of the slider. - max_value : float - Maximum value of the slider. - slider_inner_radius : int - Inner radius of the base disk. - slider_outer_radius : int - Outer radius of the base disk. - handle_outer_radius : int - Outer radius of the slider's handle. - handle_inner_radius : int - Inner radius of the slider's handle. - text_size : int - Size of the text to display alongside the slider (pt). - text_template : str, callable - If str, text template can contain one or multiple of the - replacement fields: `{value:}`, `{ratio:}`, `{angle:}`. - If callable, this instance of `:class:DiskSlider2D` will be - passed as argument to the text template function. - - """ - super(DiskSlider2D, self).__init__() - self.center = np.array(position) - self.min_value = min_value - self.max_value = max_value - self.initial_value = initial_value - self.slider_inner_radius = slider_inner_radius - self.slider_outer_radius = slider_outer_radius - self.handle_inner_radius = handle_inner_radius - self.handle_outer_radius = handle_outer_radius - self.slider_radius = (slider_inner_radius + slider_outer_radius) / 2. - - self.handle = None - self.base_disk = None - - self.text = None - self.text_size = text_size - self.text_template = text_template - - self.build_actors() - - # By setting the value, it also updates everything. - self.value = initial_value - self.previous_value = initial_value - self.handle_events(None) - - def build_actors(self): - """ Builds actors for the system. 
- - """ - base_disk = vtk.vtkDiskSource() - base_disk.SetInnerRadius(self.slider_inner_radius) - base_disk.SetOuterRadius(self.slider_outer_radius) - base_disk.SetRadialResolution(10) - base_disk.SetCircumferentialResolution(50) - base_disk.Update() - - base_disk_mapper = vtk.vtkPolyDataMapper2D() - base_disk_mapper.SetInputConnection(base_disk.GetOutputPort()) - - self.base_disk = vtk.vtkActor2D() - self.base_disk.SetMapper(base_disk_mapper) - self.base_disk.GetProperty().SetColor(1, 0, 0) - self.base_disk.SetPosition(self.center) - - handle = vtk.vtkDiskSource() - handle.SetInnerRadius(self.handle_inner_radius) - handle.SetOuterRadius(self.handle_outer_radius) - handle.SetRadialResolution(10) - handle.SetCircumferentialResolution(50) - handle.Update() - - handle_mapper = vtk.vtkPolyDataMapper2D() - handle_mapper.SetInputConnection(handle.GetOutputPort()) - - self.handle = vtk.vtkActor2D() - self.handle.SetMapper(handle_mapper) - - self.text = TextBlock2D() - offset = np.array((16., 8.)) - self.text.position = self.center - offset - self.text.font_size = self.text_size - - @property - def value(self): - return self._value - - @value.setter - def value(self, value): - value_range = self.max_value - self.min_value - self.ratio = (value - self.min_value) / value_range - - @property - def previous_value(self): - return self._previous_value - - @previous_value.setter - def previous_value(self, previous_value): - self._previous_value = previous_value - - @property - def ratio(self): - return self._ratio - - @ratio.setter - def ratio(self, ratio): - self.angle = ratio * TWO_PI - - @property - def angle(self): - """ Angle (in rad) the handle makes with x-axis """ - return self._angle - - @angle.setter - def angle(self, angle): - self._angle = angle % TWO_PI # Wraparound - self.update() - - def format_text(self): - """ Returns formatted text to display along the slider. """ - if callable(self.text_template): - return self.text_template(self) - - return self.text_template.format(ratio=self.ratio, value=self.value, - angle=np.rad2deg(self.angle)) - - def update(self): - """ Updates the slider. """ - - # Compute the ratio determined by the position of the slider disk. - self._ratio = self.angle / TWO_PI - - # Compute the selected value considering min_value and max_value. - value_range = self.max_value - self.min_value - try: - self._previous_value = self.value - except: - self._previous_value = self.initial_value - self._value = self.min_value + self.ratio*value_range - - # Update text disk actor. - x = self.slider_radius * np.cos(self.angle) + self.center[0] - y = self.slider_radius * np.sin(self.angle) + self.center[1] - self.handle.SetPosition(x, y) - - # Update text. - text = self.format_text() - self.text.message = text - - def get_actors(self): - """ Returns the actors that compose this UI component. - - """ - return [self.base_disk, self.handle, self.text.get_actor()] - - def move_handle(self, click_position): - """Moves the slider's handle. - - Parameters - ---------- - click_position: (float, float) - Position of the mouse click. - - """ - x, y = np.array(click_position) - self.center - angle = np.arctan2(y, x) - if angle < 0: - angle += TWO_PI - - self.angle = angle - - def set_center(self, position): - """ Changes the slider's center position. - - Parameters - ---------- - position : (float, float) - New position (x, y). 
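The property chain above (value -> ratio -> angle -> update) reduces to plain arithmetic. A minimal sketch of that mapping, with a hypothetical helper name and the class defaults assumed:

import numpy as np

TWO_PI = 2 * np.pi

def value_to_angle(value, min_value=0, max_value=360):
    # ratio in [0, 1], as in DiskSlider2D.value / DiskSlider2D.ratio
    ratio = (value - min_value) / (max_value - min_value)
    # the angle wraps around, as in the DiskSlider2D.angle setter
    return (ratio * TWO_PI) % TWO_PI

# initial_value=180 with the default range puts the handle at pi radians.
assert np.isclose(value_to_angle(180), np.pi)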
- - """ - position = np.array(position) - offset = position - self.center - self.base_disk.SetPosition(position) - self.handle.SetPosition(*(offset + self.handle.GetPosition())) - self.text.position += offset - self.center = position - - @staticmethod - def base_disk_click_callback(i_ren, obj, slider): - """ Update disk position and grab the focus. - - Parameters - ---------- - i_ren : :class:`CustomInteractorStyle` - obj : :class:`vtkActor` - The picked actor - slider : :class:`DiskSlider2D` - - """ - click_position = i_ren.event.position - slider.move_handle(click_position=click_position) - i_ren.force_render() - i_ren.event.abort() # Stop propagating the event. - - @staticmethod - def handle_move_callback(i_ren, obj, slider): - """ Move the slider's handle. - - Parameters - ---------- - i_ren : :class:`CustomInteractorStyle` - obj : :class:`vtkActor` - The picked actor - slider : :class:`DiskSlider2D` - - """ - click_position = i_ren.event.position - slider.move_handle(click_position=click_position) - i_ren.force_render() - i_ren.event.abort() # Stop propagating the event. - - @staticmethod - def handle_press_callback(i_ren, obj, slider): - """ This is only needed to grab the focus. - - Parameters - ---------- - i_ren : :class:`CustomInteractorStyle` - obj : :class:`vtkActor` - The picked actor - slider : :class:`DiskSlider2D` - - """ - i_ren.event.abort() # Stop propagating the event. - - def handle_events(self, actor): - """ Handle all default slider events. - - """ - self.add_callback(self.base_disk, "LeftButtonPressEvent", - self.base_disk_click_callback, 1) - self.add_callback(self.handle, "LeftButtonPressEvent", - self.handle_press_callback) - self.add_callback(self.base_disk, "MouseMoveEvent", - self.handle_move_callback) - self.add_callback(self.handle, "MouseMoveEvent", - self.handle_move_callback) - - -class FileSelectMenu2D(UI): - """ A menu to select files in the current folder. - - Can go to new folder, previous folder and select a file - and keep it in a variable. - - Attributes - ---------- - n_text_actors: int - The number of text actors. Calculated dynamically. - selected_file: string - Current selected file. - text_item_list: list(:class:`FileSelectMenuText2D`) - List of FileSelectMenuText2Ds - both visible and invisible. - window_offset: int - Used for scrolling. - Tells you the index of the first visible FileSelectMenuText2D - object. - size: (float, float) - The size of the system (x, y) in pixels. - font_size: int - The font size in pixels. - line_spacing: float - Distance between menu text items in pixels. - parent_ui: :class:`UI` - The UI component this object belongs to. - extensions: list(string) - List of extensions to be shown as files. - - """ - - def __init__(self, size, font_size, position, parent, extensions, - directory_path, reverse_scrolling=False, line_spacing=1.4): - """ - Parameters - ---------- - size: (float, float) - The size of the system (x, y) in pixels. - font_size: int - The font size in pixels. - parent: :class:`UI` - The UI component this object belongs to. - This will be useful when this UI element is used as a - part of other UI elements, like a file save dialog. - position: (float, float) - The initial position (x, y) in pixels. - reverse_scrolling: {True, False} - If True, scrolling up will move the list of files down. - line_spacing: float - Distance between menu text items in pixels. - extensions: list(string) - List of extensions to be shown as files. - directory_path: string - Path of the directory where this dialog should open. 
- Example: os.getcwd() - - """ - super(FileSelectMenu2D, self).__init__() - - self.size = size - self.font_size = font_size - self.parent_ui = parent - self.reverse_scrolling = reverse_scrolling - self.line_spacing = line_spacing - self.extensions = extensions - - self.n_text_actors = 0 # Initialisation Value - self.text_item_list = [] - self.selected_file = "" - self.window_offset = 0 - self.current_directory = directory_path - self.buttons = dict() - - self.menu = self.build_actors(position) - - self.fill_text_actors() - self.handle_events(None) - - def add_to_renderer(self, ren): - self.menu.add_to_renderer(ren) - super(FileSelectMenu2D, self).add_to_renderer(ren) - for menu_text in self.text_item_list: - menu_text.add_to_renderer(ren) - - def get_actors(self): - """ Returns the actors that compose this UI component. - - """ - return [self.buttons["up"], self.buttons["down"]] - - def build_actors(self, position): - """ Builds the number of text actors that will fit in the given size. - - Allots them positions in the panel, which is only there to allot positions, - otherwise the panel itself is invisible. - - Parameters - ---------- - position: (float, float) - Position of the panel (x, y) in pixels. - - """ - # Calculating the number of text actors. - self.n_text_actors = int(self.size[1]/(self.font_size*self.line_spacing)) - - # This panel is just to facilitate the addition of actors at the right positions - panel = Panel2D(center=position, size=self.size, color=(1, 1, 1)) - - # Initialisation of empty text actors - for i in range(self.n_text_actors): - - text = FileSelectMenuText2D(position=(0, 0), font_size=self.font_size, - file_select=self) - text.parent_UI = self.parent_ui - self.ui_list.append(text) - self.text_item_list.append(text) - - panel.add_element(text, 'relative', - (0.1, - float(self.n_text_actors-i - 1) / - float(self.n_text_actors))) - - up_button = Button2D({"up": read_viz_icons(fname="arrow-up.png")}) - panel.add_element(up_button, 'relative', (0.95, 0.95)) - self.buttons["up"] = up_button - - down_button = Button2D({"down": read_viz_icons(fname="arrow-down.png")}) - panel.add_element(down_button, 'relative', (0.95, 0.05)) - self.buttons["down"] = down_button - - return panel - - @staticmethod - def up_button_callback(i_ren, obj, file_select_menu): - """ Pressing up button scrolls up in the menu. - - Parameters - ---------- - i_ren: :class:`CustomInteractorStyle` - obj: :class:`vtkActor` - The picked actor - file_select_menu: :class:`FileSelectMenu2D` - - """ - all_file_names = file_select_menu.get_all_file_names() - - if (file_select_menu.n_text_actors + - file_select_menu.window_offset) <= len(all_file_names): - if file_select_menu.window_offset > 0: - file_select_menu.window_offset -= 1 - file_select_menu.fill_text_actors() - - i_ren.force_render() - i_ren.event.abort() # Stop propagating the event. - - @staticmethod - def down_button_callback(i_ren, obj, file_select_menu): - """ Pressing down button scrolls down in the menu. - - Parameters - ---------- - i_ren: :class:`CustomInteractorStyle` - obj: :class:`vtkActor` - The picked actor - file_select_menu: :class:`FileSelectMenu2D` - - """ - all_file_names = file_select_menu.get_all_file_names() - - if (file_select_menu.n_text_actors + - file_select_menu.window_offset) < len(all_file_names): - file_select_menu.window_offset += 1 - file_select_menu.fill_text_actors() - - i_ren.force_render() - i_ren.event.abort() # Stop propagating the event. 
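The two callbacks above implement a sliding window over the file list: window_offset only moves while the n_text_actors visible rows still fit inside the list. A simplified standalone sketch of that clamping (the function name is ours):

def scroll(window_offset, n_text_actors, n_files, direction):
    # Mimics up_button_callback / down_button_callback above: the window
    # of visible rows never scrolls past either end of the file list.
    if direction == "up" and window_offset > 0:
        window_offset -= 1
    elif direction == "down" and n_text_actors + window_offset < n_files:
        window_offset += 1
    return window_offset

# With 5 visible rows over 8 files, the offset is clamped to [0, 3].
assert scroll(3, 5, 8, "down") == 3
assert scroll(0, 5, 8, "up") == 0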
-
-    def fill_text_actors(self):
-        """ Fills file/folder names to text actors.
-
-        The list is truncated if the number of file/folder names is greater
-        than the available number of text actors.
-
-        """
-        # Flush all the text actors
-        for text_item in self.text_item_list:
-            text_item.text_actor.message = ""
-            text_item.text_actor.actor.SetVisibility(False)
-
-        all_file_names = self.get_all_file_names()
-
-        clipped_file_names = all_file_names[
-            self.window_offset:self.n_text_actors + self.window_offset]
-
-        # Allot file names as in the above list
-        i = 0
-        for file_name in clipped_file_names:
-            self.text_item_list[i].text_actor.actor.SetVisibility(True)
-            self.text_item_list[i].set_attributes(file_name[0], file_name[1])
-            if file_name[0] == self.selected_file:
-                self.text_item_list[i].mark_selected()
-            i += 1
-
-    def get_all_file_names(self):
-        """ Gets file and directory names.
-
-        Returns
-        -------
-        all_file_names: list(string)
-            List of all file and directory names as string.
-
-        """
-        all_file_names = []
-
-        directory_names = self.get_directory_names()
-        for directory_name in directory_names:
-            all_file_names.append((directory_name, "directory"))
-
-        file_names = self.get_file_names()
-        for file_name in file_names:
-            all_file_names.append((file_name, "file"))
-
-        return all_file_names
-
-    def get_directory_names(self):
-        """ Lists the names of the directories in the current directory.
-
-        Returns
-        -------
-        directory_names: list(string)
-            List of all directory names as string.
-
-        """
-        # A list of directory names in the current directory
-        directory_names = next(os.walk(self.current_directory))[1]
-        directory_names = [os.path.basename(os.path.abspath(dn))
-                           for dn in directory_names]
-        directory_names = ["../"] + directory_names
-
-        return directory_names
-
-    def get_file_names(self):
-        """ Lists the names of the files matching ``extensions`` in the
-        current directory.
-
-        Returns
-        -------
-        file_names: list(string)
-            List of all file names as string.
-
-        """
-        # A list of file names with extension in the current directory
-        file_names = []
-        for extension in self.extensions:
-            file_names += glob.glob(self.current_directory + "/*." + extension)
-        file_names = [os.path.basename(os.path.abspath(fn))
-                      for fn in file_names]
-        return file_names
-
-    def select_file(self, file_name):
-        """ Changes the selected file name.
-
-        Parameters
-        ----------
-        file_name: string
-            Name of the file.
-
-        """
-        self.selected_file = file_name
-
-    def set_center(self, position):
-        """ Sets the element's center.
-
-        Parameters
-        ----------
-        position: (float, float)
-            New position (x, y) in pixels.
- - """ - self.menu.set_center(position=position) - - def handle_events(self, actor): - self.add_callback(self.buttons["up"].actor, "LeftButtonPressEvent", - self.up_button_callback) - self.add_callback(self.buttons["down"].actor, "LeftButtonPressEvent", - self.down_button_callback) - - # Handle mouse wheel events - up_event = "MouseWheelForwardEvent" - down_event = "MouseWheelBackwardEvent" - if self.reverse_scrolling: - up_event, down_event = down_event, up_event # Swap events - - self.add_callback(self.menu.get_actors()[0], up_event, - self.up_button_callback) - self.add_callback(self.menu.get_actors()[0], down_event, - self.down_button_callback) - - for text_ui in self.text_item_list: - self.add_callback(text_ui.text_actor.get_actors()[0], up_event, - self.up_button_callback) - self.add_callback(text_ui.text_actor.get_actors()[0], down_event, - self.down_button_callback) - - -class FileSelectMenuText2D(UI): - """ The text to select folder in a file select menu. - - Provides a callback to change the directory. - - Attributes - ---------- - file_name: string - The name of the file the text is displaying. - file_type: string - Whether the file is a file or directory. - file_select: :class:`FileSelect2D` - The FileSelectMenu2D reference this text belongs to. - - """ - - def __init__(self, font_size, position, file_select): - """ - Parameters - ---------- - font_size: int - The font size of the text in pixels. - position: (float, float) - Absolute text position (x, y) in pixels. - file_select: :class:`FileSelect2D` - The FileSelectMenu2D reference this text belongs to. - - """ - super(FileSelectMenuText2D, self).__init__() - - self.file_name = "" - self.file_type = "" - self.file_select = file_select - - self.text_actor = self.build_actor(position=position, font_size=font_size) - - self.handle_events(self.text_actor.get_actor()) - - self.on_left_mouse_button_clicked = self.left_button_clicked - - def build_actor(self, position, text="Text", color=(1, 1, 1), font_family='Arial', - justification='left', bold=False, italic=False, - shadow=False, font_size='14'): - """ Builds a text actor. - - Parameters - ---------- - text: string - The initial text while building the actor. - position: (float, float) - The text position (x, y) in pixels. - color: (float, float, float) - Values must be between 0-1 (RGB). - font_family: string - Currently only supports Arial. - justification: string - Text justification - left, right or center. - bold: bool - Whether or not the text is bold. - italic: bool - Whether or not the text is italicized. - shadow: bool - Whether or not the text has shadow. - font_size: int - The font size of the text in pixels. - - Returns - ------- - text_actor: :class:`TextBlock2D` - The base text actor. - - """ - text_actor = TextBlock2D() - text_actor.position = position - text_actor.message = text - text_actor.font_size = font_size - text_actor.font_family = font_family - text_actor.justification = justification - text_actor.bold = bold - text_actor.italic = italic - text_actor.shadow = shadow - text_actor.color = color - - if vtk.vtkVersion.GetVTKVersion() <= "6.2.0": - pass - else: - text_actor.actor.GetTextProperty().SetBackgroundColor(1, 1, 1) - text_actor.actor.GetTextProperty().SetBackgroundOpacity(1.0) - - text_actor.actor.GetTextProperty().SetColor(0, 0, 0) - text_actor.actor.GetTextProperty().SetLineSpacing(1) - - return text_actor - - def get_actors(self): - """ Returns the actors that compose this UI component. 
- - """ - return [self.text_actor.get_actor()] - - def set_attributes(self, file_name, file_type): - """ Set attributes (file name and type) of this component. - - This function is for use by a FileSelectMenu2D to set the - current file_name and file_type for this FileSelectMenuText2D - component. - - Parameters - ---------- - file_name: string - The name of the file. - file_type: string - File type = directory or file. - - """ - self.file_name = file_name - self.file_type = file_type - self.text_actor.message = file_name - - if vtk.vtkVersion.GetVTKVersion() <= "6.2.0": - self.text_actor.get_actor().GetTextProperty().SetColor(1, 1, 1) - if file_type != "file": - self.text_actor.get_actor().GetTextProperty().SetBold(True) - - else: - if file_type == "file": - self.text_actor.get_actor().GetTextProperty().SetBackgroundColor(0, 0, 0) - self.text_actor.get_actor().GetTextProperty().SetColor(1, 1, 1) - else: - self.text_actor.get_actor().GetTextProperty().SetBackgroundColor(1, 1, 1) - self.text_actor.get_actor().GetTextProperty().SetColor(0, 0, 0) - - def mark_selected(self): - """ Changes the background color of the actor. - - """ - if vtk.vtkVersion.GetVTKVersion() <= "6.2.0": - self.text_actor.actor.GetTextProperty().SetColor(1, 0, 0) - else: - self.text_actor.actor.GetTextProperty().SetBackgroundColor(1, 0, 0) - self.text_actor.actor.GetTextProperty().SetBackgroundOpacity(1.0) - - @staticmethod - def left_button_clicked(i_ren, obj, file_select_text): - """ A callback to handle left click for this UI element. - - Parameters - ---------- - i_ren: :class:`CustomInteractorStyle` - obj: :class:`vtkActor` - The picked actor - file_select_text: :class:`FileSelectMenuText2D` - - """ - - if file_select_text.file_type == "directory": - file_select_text.file_select.select_file(file_name="") - file_select_text.file_select.window_offset = 0 - file_select_text.file_select.current_directory = os.path.abspath( - os.path.join(file_select_text.file_select.current_directory, - file_select_text.text_actor.message)) - file_select_text.file_select.window = 0 - file_select_text.file_select.fill_text_actors() - else: - file_select_text.file_select.select_file( - file_name=file_select_text.file_name) - file_select_text.file_select.fill_text_actors() - file_select_text.mark_selected() - - i_ren.force_render() - i_ren.event.abort() # Stop propagating the event. - - def set_center(self, position): - """ Sets the text center to position. - - Parameters - ---------- - position: (float, float) - The new position (x, y) in pixels. 
- """ - self.text_actor.position = position diff --git a/dipy/viz/utils.py b/dipy/viz/utils.py deleted file mode 100644 index 8faf2f2645..0000000000 --- a/dipy/viz/utils.py +++ /dev/null @@ -1,475 +0,0 @@ - -from __future__ import division, print_function, absolute_import - -import numpy as np -from scipy.ndimage import map_coordinates -from dipy.viz.colormap import line_colors - -# Conditional import machinery for vtk -from dipy.utils.optpkg import optional_package - -# import vtk -# Allow import, but disable doctests if we don't have vtk -vtk, have_vtk, setup_module = optional_package('vtk') -ns, have_numpy_support, _ = optional_package('vtk.util.numpy_support') - - -def set_input(vtk_object, inp): - """ Generic input function which takes into account VTK 5 or 6 - - Parameters - ---------- - vtk_object: vtk object - inp: vtkPolyData or vtkImageData or vtkAlgorithmOutput - - Returns - ------- - vtk_object - - Notes - ------- - This can be used in the following way:: - from dipy.viz.utils import set_input - poly_mapper = set_input(vtk.vtkPolyDataMapper(), poly_data) - """ - if isinstance(inp, vtk.vtkPolyData) \ - or isinstance(inp, vtk.vtkImageData): - if vtk.VTK_MAJOR_VERSION <= 5: - vtk_object.SetInput(inp) - else: - vtk_object.SetInputData(inp) - elif isinstance(inp, vtk.vtkAlgorithmOutput): - vtk_object.SetInputConnection(inp) - - vtk_object.Update() - return vtk_object - - -def numpy_to_vtk_points(points): - """ Numpy points array to a vtk points array - - Parameters - ---------- - points : ndarray - - Returns - ------- - vtk_points : vtkPoints() - """ - vtk_points = vtk.vtkPoints() - vtk_points.SetData(ns.numpy_to_vtk(np.asarray(points), deep=True)) - return vtk_points - - -def numpy_to_vtk_colors(colors): - """ Numpy color array to a vtk color array - - Parameters - ---------- - colors: ndarray - - Returns - ------- - vtk_colors : vtkDataArray - - Notes - ----- - If colors are not already in UNSIGNED_CHAR you may need to multiply by 255. - - Examples - -------- - >>> import numpy as np - >>> from dipy.viz.utils import numpy_to_vtk_colors - >>> rgb_array = np.random.rand(100, 3) - >>> vtk_colors = numpy_to_vtk_colors(255 * rgb_array) - """ - vtk_colors = ns.numpy_to_vtk(np.asarray(colors), deep=True, - array_type=vtk.VTK_UNSIGNED_CHAR) - return vtk_colors - - -def map_coordinates_3d_4d(input_array, indices): - """ Evaluate the input_array data at the given indices - using trilinear interpolation - - Parameters - ---------- - input_array : ndarray, - 3D or 4D array - indices : ndarray - - Returns - ------- - output : ndarray - 1D or 2D array - """ - - if input_array.ndim <= 2 or input_array.ndim >= 5: - raise ValueError("Input array can only be 3d or 4d") - - if input_array.ndim == 3: - return map_coordinates(input_array, indices.T, order=1) - - if input_array.ndim == 4: - values_4d = [] - for i in range(input_array.shape[-1]): - values_tmp = map_coordinates(input_array[..., i], - indices.T, order=1) - values_4d.append(values_tmp) - return np.ascontiguousarray(np.array(values_4d).T) - - -def lines_to_vtk_polydata(lines, colors=None): - """ Create a vtkPolyData with lines and colors - - Parameters - ---------- - lines : list - list of N curves represented as 2D ndarrays - colors : array (N, 3), list of arrays, tuple (3,), array (K,), None - If None then a standard orientation colormap is used for every line. - If one tuple of color is used. Then all streamlines will have the same - colour. - If an array (N, 3) is given, where N is equal to the number of lines. 
- Then every line is coloured with a different RGB color. - If a list of RGB arrays is given then every point of every line takes - a different color. - If an array (K, 3) is given, where K is the number of points of all - lines then every point is colored with a different RGB color. - If an array (K,) is given, where K is the number of points of all - lines then these are considered as the values to be used by the - colormap. - If an array (L,) is given, where L is the number of streamlines then - these are considered as the values to be used by the colormap per - streamline. - If an array (X, Y, Z) or (X, Y, Z, 3) is given then the values for the - colormap are interpolated automatically using trilinear interpolation. - - Returns - ------- - poly_data : vtkPolyData - is_colormap : bool, true if the input color array was a colormap - """ - - # Get the 3d points_array - points_array = np.vstack(lines) - - nb_lines = len(lines) - nb_points = len(points_array) - - lines_range = range(nb_lines) - - # Get lines_array in vtk input format - lines_array = [] - # Using np.intp (instead of int64), because of a bug in numpy: - # https://github.com/nipy/dipy/pull/789 - # https://github.com/numpy/numpy/issues/4384 - points_per_line = np.zeros([nb_lines], np.intp) - current_position = 0 - for i in lines_range: - current_len = len(lines[i]) - points_per_line[i] = current_len - - end_position = current_position + current_len - lines_array += [current_len] - lines_array += range(current_position, end_position) - current_position = end_position - - lines_array = np.array(lines_array) - - # Set Points to vtk array format - vtk_points = numpy_to_vtk_points(points_array) - - # Set Lines to vtk array format - vtk_lines = vtk.vtkCellArray() - vtk_lines.GetData().DeepCopy(ns.numpy_to_vtk(lines_array)) - vtk_lines.SetNumberOfCells(nb_lines) - - is_colormap = False - # Get colors_array (reformat to have colors for each points) - # - if/else tested and work in normal simple case - if colors is None: # set automatic rgb colors - cols_arr = line_colors(lines) - colors_mapper = np.repeat(lines_range, points_per_line, axis=0) - vtk_colors = numpy_to_vtk_colors(255 * cols_arr[colors_mapper]) - else: - cols_arr = np.asarray(colors) - if cols_arr.dtype == np.object: # colors is a list of colors - vtk_colors = numpy_to_vtk_colors(255 * np.vstack(colors)) - else: - if len(cols_arr) == nb_points: - if cols_arr.ndim == 1: # values for every point - vtk_colors = ns.numpy_to_vtk(cols_arr, deep=True) - is_colormap = True - elif cols_arr.ndim == 2: # map color to each point - vtk_colors = numpy_to_vtk_colors(255 * cols_arr) - - elif cols_arr.ndim == 1: - if len(cols_arr) == nb_lines: # values for every streamline - cols_arrx = [] - for (i, value) in enumerate(colors): - cols_arrx += lines[i].shape[0]*[value] - cols_arrx = np.array(cols_arrx) - vtk_colors = ns.numpy_to_vtk(cols_arrx, deep=True) - is_colormap = True - else: # the same colors for all points - vtk_colors = numpy_to_vtk_colors( - np.tile(255 * cols_arr, (nb_points, 1))) - - elif cols_arr.ndim == 2: # map color to each line - colors_mapper = np.repeat(lines_range, points_per_line, axis=0) - vtk_colors = numpy_to_vtk_colors(255 * cols_arr[colors_mapper]) - else: # colormap - # get colors for each vertex - cols_arr = map_coordinates_3d_4d(cols_arr, points_array) - vtk_colors = ns.numpy_to_vtk(cols_arr, deep=True) - is_colormap = True - - vtk_colors.SetName("Colors") - - # Create the poly_data - poly_data = vtk.vtkPolyData() - poly_data.SetPoints(vtk_points) - 
poly_data.SetLines(vtk_lines)
-    poly_data.GetPointData().SetScalars(vtk_colors)
-    return poly_data, is_colormap
-
-
-def get_polydata_lines(line_polydata):
-    """ Convert vtk polydata to a list of lines ndarrays
-
-    Parameters
-    ----------
-    line_polydata : vtkPolyData
-
-    Returns
-    -------
-    lines : list
-        List of N curves represented as 2D ndarrays
-    """
-    lines_vertices = ns.vtk_to_numpy(line_polydata.GetPoints().GetData())
-    lines_idx = ns.vtk_to_numpy(line_polydata.GetLines().GetData())
-
-    lines = []
-    current_idx = 0
-    while current_idx < len(lines_idx):
-        line_len = lines_idx[current_idx]
-
-        next_idx = current_idx + line_len + 1
-        line_range = lines_idx[current_idx + 1: next_idx]
-
-        lines += [lines_vertices[line_range]]
-        current_idx = next_idx
-    return lines
-
-
-def get_polydata_triangles(polydata):
-    """ get triangles (ndarray Nx3 int) from a vtk polydata
-
-    Parameters
-    ----------
-    polydata : vtkPolyData
-
-    Returns
-    -------
-    output : array (N, 3)
-        triangles
-    """
-    vtk_polys = ns.vtk_to_numpy(polydata.GetPolys().GetData())
-    assert((vtk_polys[::4] == 3).all())  # test if it's really triangles
-    return np.vstack([vtk_polys[1::4], vtk_polys[2::4], vtk_polys[3::4]]).T
-
-
-def get_polydata_vertices(polydata):
-    """ get vertices (ndarray Nx3 float) from a vtk polydata
-
-    Parameters
-    ----------
-    polydata : vtkPolyData
-
-    Returns
-    -------
-    output : array (N, 3)
-        points, represented as 2D ndarrays
-    """
-    return ns.vtk_to_numpy(polydata.GetPoints().GetData())
-
-
-def get_polydata_normals(polydata):
-    """ get vertex normals (ndarray Nx3 float) from a vtk polydata
-
-    Parameters
-    ----------
-    polydata : vtkPolyData
-
-    Returns
-    -------
-    output : array (N, 3)
-        Normals, represented as 2D ndarrays (Nx3). None if there are no
-        normals in the vtk polydata.
-    """
-    vtk_normals = polydata.GetPointData().GetNormals()
-    if vtk_normals is None:
-        return None
-    else:
-        return ns.vtk_to_numpy(vtk_normals)
-
-
-def get_polydata_colors(polydata):
-    """ get point colors (ndarray Nx3 uint8) from a vtk polydata
-
-    Parameters
-    ----------
-    polydata : vtkPolyData
-
-    Returns
-    -------
-    output : array (N, 3)
-        Colors. None if there are no colors in the vtk polydata.
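get_polydata_triangles above relies on the flat vtkCellArray layout, where each cell is stored as a vertex count followed by its point indices. A small numpy round trip showing the packing it inverts (the sample triangle indices are ours):

import numpy as np

# Pack two triangles the way vtkCellArray stores them: a leading vertex
# count (3) followed by the three point indices, repeated per cell.
triangles = np.array([[0, 1, 2], [2, 1, 3]])
packed = np.hstack(np.c_[np.full(len(triangles), 3), triangles])
assert list(packed) == [3, 0, 1, 2, 3, 2, 1, 3]

# get_polydata_triangles inverts this with strided slicing.
unpacked = np.vstack([packed[1::4], packed[2::4], packed[3::4]]).T
assert (unpacked == triangles).all()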
- """ - vtk_colors = polydata.GetPointData().GetScalars() - if vtk_colors is None: - return None - else: - return ns.vtk_to_numpy(vtk_colors) - - -def set_polydata_triangles(polydata, triangles): - """ set polydata triangles with a numpy array (ndarrays Nx3 int) - - Parameters - ---------- - polydata : vtkPolyData - triangles : array (N, 3) - triangles, represented as 2D ndarrays (Nx3) - """ - vtk_triangles = np.hstack(np.c_[np.ones(len(triangles)).astype(np.int) * 3, - triangles]) - vtk_triangles = ns.numpy_to_vtkIdTypeArray(vtk_triangles, deep=True) - vtk_cells = vtk.vtkCellArray() - vtk_cells.SetCells(len(triangles), vtk_triangles) - polydata.SetPolys(vtk_cells) - return polydata - - -def set_polydata_vertices(polydata, vertices): - """ set polydata vertices with a numpy array (ndarrays Nx3 int) - - Parameters - ---------- - polydata : vtkPolyData - vertices : vertices, represented as 2D ndarrays (Nx3) - """ - vtk_points = vtk.vtkPoints() - vtk_points.SetData(ns.numpy_to_vtk(vertices, deep=True)) - polydata.SetPoints(vtk_points) - return polydata - - -def set_polydata_normals(polydata, normals): - """ set polydata normals with a numpy array (ndarrays Nx3 int) - - Parameters - ---------- - polydata : vtkPolyData - normals : normals, represented as 2D ndarrays (Nx3) (one per vertex) - """ - vtk_normals = ns.numpy_to_vtk(normals, deep=True) - polydata.GetPointData().SetNormals(vtk_normals) - return polydata - - -def set_polydata_colors(polydata, colors): - """ set polydata colors with a numpy array (ndarrays Nx3 int) - - Parameters - ---------- - polydata : vtkPolyData - colors : colors, represented as 2D ndarrays (Nx3) - colors are uint8 [0,255] RGB for each points - """ - vtk_colors = ns.numpy_to_vtk(colors, deep=True, - array_type=vtk.VTK_UNSIGNED_CHAR) - vtk_colors.SetNumberOfComponents(3) - vtk_colors.SetName("RGB") - polydata.GetPointData().SetScalars(vtk_colors) - return polydata - - -def update_polydata_normals(polydata): - """ generate and update polydata normals - - Parameters - ---------- - polydata : vtkPolyData - """ - normals_gen = set_input(vtk.vtkPolyDataNormals(), polydata) - normals_gen.ComputePointNormalsOn() - normals_gen.ComputeCellNormalsOn() - normals_gen.SplittingOff() - # normals_gen.FlipNormalsOn() - # normals_gen.ConsistencyOn() - # normals_gen.AutoOrientNormalsOn() - normals_gen.Update() - - vtk_normals = normals_gen.GetOutput().GetPointData().GetNormals() - polydata.GetPointData().SetNormals(vtk_normals) - - -def get_polymapper_from_polydata(polydata): - """ get vtkPolyDataMapper from a vtkPolyData - - Parameters - ---------- - polydata : vtkPolyData - - Returns - ------- - poly_mapper : vtkPolyDataMapper - """ - poly_mapper = set_input(vtk.vtkPolyDataMapper(), polydata) - poly_mapper.ScalarVisibilityOn() - poly_mapper.InterpolateScalarsBeforeMappingOn() - poly_mapper.Update() - poly_mapper.StaticOn() - return poly_mapper - - -def get_actor_from_polymapper(poly_mapper): - """ get vtkActor from a vtkPolyDataMapper - - Parameters - ---------- - poly_mapper : vtkPolyDataMapper - - Returns - ------- - actor : vtkActor - """ - actor = vtk.vtkActor() - actor.SetMapper(poly_mapper) - actor.GetProperty().BackfaceCullingOn() - actor.GetProperty().SetInterpolationToPhong() - - # Use different defaults for OpenGL1 rendering backend - if vtk.VTK_MAJOR_VERSION <= 6: - actor.GetProperty().SetAmbient(0.1) - actor.GetProperty().SetDiffuse(0.15) - actor.GetProperty().SetSpecular(0.05) - - return actor - - -def get_actor_from_polydata(polydata): - """ get vtkActor from a 
vtkPolyData
-
-    Parameters
-    ----------
-    polydata : vtkPolyData
-
-    Returns
-    -------
-    actor : vtkActor
-    """
-    poly_mapper = get_polymapper_from_polydata(polydata)
-    return get_actor_from_polymapper(poly_mapper)
diff --git a/dipy/viz/widget.py b/dipy/viz/widget.py
deleted file mode 100644
index 45cffc101d..0000000000
--- a/dipy/viz/widget.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# Widgets are different from actors in that they can interact with events.
-# To do so they need as input a vtkRenderWindowInteractor, also known as iren.
-
-import numpy as np
-
-# Conditional import machinery for vtk
-from dipy.utils.optpkg import optional_package
-
-# Allow import, but disable doctests if we don't have vtk
-vtk, have_vtk, setup_module = optional_package('vtk')
-colors, have_vtk_colors, _ = optional_package('vtk.util.colors')
-numpy_support, have_ns, _ = optional_package('vtk.util.numpy_support')
-
-
-def slider(iren, ren, callback, min_value=0, max_value=255, value=125,
-           label="Slider",
-           right_normalized_pos=(0.9, 0.5),
-           size=(50, 0),
-           label_format="%0.0lf",
-           color=(0.5, 0.5, 0.5),
-           selected_color=(0.9, 0.2, 0.1)):
-    """ A 2D slider widget
-
-    Parameters
-    ----------
-    iren : vtkRenderWindowInteractor
-        Used to process events and hand them to the slider. Can also be
-        given by the attribute ``ShowManager.iren``.
-    ren : vtkRenderer or Renderer
-        Used to update the slider's position when the window changes. Can
-        also be given by the ``ShowManager.ren`` attribute.
-    callback : function
-        Function that has at least ``obj`` and ``event`` as parameters. It
-        will be called when the slider's bar has changed.
-    min_value : float
-        Minimum value of slider.
-    max_value : float
-        Maximum value of slider.
-    value : float
-        Default value of slider.
-    label : str
-        Slider's caption.
-    right_normalized_pos : tuple
-        2d tuple holding the normalized right (X, Y) position of the slider.
-    size : tuple
-        2d tuple holding the size of the slider in pixels.
-    label_format : str
-        Formatting in which the slider's value will appear; for example
-        "%0.2lf" allows for 2 decimal values.
-    color : tuple
-        RGB color (values in [0, 1]) of the slider's elements.
-    selected_color : tuple
-        RGB color (values in [0, 1]) of the slider's elements when selected.
-
-    Returns
-    -------
-    slider : SliderWidget
-        This object inherits from vtkSliderWidget and has an additional
-        method called ``place``, which allows updating the position of the
-        slider when, for example, the window is resized.
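Hypothetical wiring for the widget documented above, untested here; it assumes a working VTK build and the window/widget modules this diff removes (ShowManager usage follows the attributes the docstring mentions):

from dipy.viz import widget, window

ren = window.Renderer()
show_m = window.ShowManager(ren, size=(600, 600))
show_m.initialize()

def print_value(obj, event):
    # `obj` is the SliderWidget returned below; read its current value.
    print(obj.get_value())

slider_w = widget.slider(show_m.iren, show_m.ren, callback=print_value,
                         min_value=0.0, max_value=1.0, value=0.5,
                         label="opacity")
show_m.render()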
- """ - - slider_rep = vtk.vtkSliderRepresentation2D() - slider_rep.SetMinimumValue(min_value) - slider_rep.SetMaximumValue(max_value) - slider_rep.SetValue(value) - slider_rep.SetTitleText(label) - - slider_rep.GetPoint2Coordinate().SetCoordinateSystemToNormalizedDisplay() - slider_rep.GetPoint2Coordinate().SetValue(*right_normalized_pos) - - coord2 = slider_rep.GetPoint2Coordinate().GetComputedDisplayValue(ren) - slider_rep.GetPoint1Coordinate().SetCoordinateSystemToDisplay() - slider_rep.GetPoint1Coordinate().SetValue(coord2[0] - size[0], - coord2[1] - size[1]) - - initial_window_size = ren.GetSize() - length = 0.04 - width = 0.04 - cap_length = 0.01 - cap_width = 0.01 - tube_width = 0.005 - - slider_rep.SetSliderLength(length) - slider_rep.SetSliderWidth(width) - slider_rep.SetEndCapLength(cap_length) - slider_rep.SetEndCapWidth(cap_width) - slider_rep.SetTubeWidth(tube_width) - slider_rep.SetLabelFormat(label_format) - - slider_rep.GetLabelProperty().SetColor(*color) - slider_rep.GetTubeProperty().SetColor(*color) - slider_rep.GetCapProperty().SetColor(*color) - slider_rep.GetTitleProperty().SetColor(*color) - slider_rep.GetSelectedProperty().SetColor(*selected_color) - slider_rep.GetSliderProperty().SetColor(*color) - - slider_rep.GetLabelProperty().SetShadow(0) - slider_rep.GetTitleProperty().SetShadow(0) - - class SliderWidget(vtk.vtkSliderWidget): - - def place(self, ren): - - slider_rep = self.GetRepresentation() - coord2_norm = slider_rep.GetPoint2Coordinate() - coord2_norm.SetCoordinateSystemToNormalizedDisplay() - coord2_norm.SetValue(*right_normalized_pos) - - coord2 = coord2_norm.GetComputedDisplayValue(ren) - slider_rep.GetPoint1Coordinate().SetCoordinateSystemToDisplay() - slider_rep.GetPoint1Coordinate().SetValue(coord2[0] - size[0], - coord2[1] - size[1]) - - window_size = ren.GetSize() - length = initial_window_size[0] * 0.04 / window_size[0] - width = initial_window_size[1] * 0.04 / window_size[1] - - slider_rep.SetSliderLength(length) - slider_rep.SetSliderWidth(width) - - def set_value(self, value): - return self.GetSliderRepresentation().SetValue(value) - - def get_value(self): - return self.GetSliderRepresentation().GetValue() - - slider = SliderWidget() - slider.SetInteractor(iren) - slider.SetRepresentation(slider_rep) - slider.SetAnimationModeToAnimate() - slider.KeyPressActivationOff() - slider.AddObserver("InteractionEvent", callback) - slider.SetEnabled(True) - - # Place widget after window resizing. - def _place_widget(obj, event): - slider.place(ren) - - iren.GetRenderWindow().AddObserver( - vtk.vtkCommand.StartEvent, _place_widget) - iren.GetRenderWindow().AddObserver( - vtk.vtkCommand.ModifiedEvent, _place_widget) - - return slider - - -def button_display_coordinates(renderer, normalized_display_position, size): - upperRight = vtk.vtkCoordinate() - upperRight.SetCoordinateSystemToNormalizedDisplay() - upperRight.SetValue(normalized_display_position[0], - normalized_display_position[1]) - bds = [0.0] * 6 - bds[0] = upperRight.GetComputedDisplayValue(renderer)[0] - size[0] - bds[1] = bds[0] + size[0] - bds[2] = upperRight.GetComputedDisplayValue(renderer)[1] - size[1] - bds[3] = bds[2] + size[1] - - return bds - - -def button(iren, ren, callback, fname, right_normalized_pos=(.98, .9), - size=(50, 50)): - """ A textured two state button widget - - Parameters - ---------- - iren : vtkRenderWindowInteractor - Used to process events and handle them to the button. Can also be given - by the attribute ``ShowManager.iren``. 
- ren : vtkRenderer or Renderer - Used to update the slider's position when the window changes. Can also - be given by the ``ShowManager.ren`` attribute. - callback : function - Function that has at least ``obj`` and ``event`` as parameters. It will - be called when the button is pressed. - fname : str - PNG file path of the icon used for the button. - right_normalized_pos : tuple - 2d tuple holding the normalized right (X, Y) position of the slider. - size: tuple - 2d tuple holding the size of the slider in pixels. - - Returns - ------- - button : ButtonWidget - This object inherits from vtkButtonWidget and has an additional method - called ``place`` which allows to update the position of the slider - if necessary. For example when the renderer size changes. - - Notes - ------ - The button and slider widgets have similar positioning system. This enables - the developers to create a HUD-like collections of buttons and sliders on - the right side of the window that always stays in place when the dimensions - of the window change. - """ - - image1 = vtk.vtkPNGReader() - image1.SetFileName(fname) - image1.Update() - - button_rep = vtk.vtkTexturedButtonRepresentation2D() - button_rep.SetNumberOfStates(2) - button_rep.SetButtonTexture(0, image1.GetOutput()) - button_rep.SetButtonTexture(1, image1.GetOutput()) - - class ButtonWidget(vtk.vtkButtonWidget): - - def place(self, renderer): - - bds = button_display_coordinates(renderer, right_normalized_pos, - size) - self.GetRepresentation().SetPlaceFactor(1) - self.GetRepresentation().PlaceWidget(bds) - self.On() - - button = ButtonWidget() - button.SetInteractor(iren) - button.SetRepresentation(button_rep) - button.AddObserver(vtk.vtkCommand.StateChangedEvent, callback) - - # Place widget after window resizing. - def _place_widget(obj, event): - button.place(ren) - - iren.GetRenderWindow().AddObserver( - vtk.vtkCommand.StartEvent, _place_widget) - iren.GetRenderWindow().AddObserver( - vtk.vtkCommand.ModifiedEvent, _place_widget) - - return button - - -def text(iren, ren, callback, message="DIPY", - left_down_pos=(0.8, 0.5), right_top_pos=(0.9, 0.5), - color=(1., .5, .0), opacity=1., border=False): - """ 2D text that can be clicked and process events - - Parameters - ---------- - iren : vtkRenderWindowInteractor - Used to process events and handle them to the button. Can also be given - by the attribute ``ShowManager.iren``. - ren : vtkRenderer or Renderer - Used to update the slider's position when the window changes. Can also - be given by the ``ShowManager.ren`` attribute. - callback : function - Function that has at least ``obj`` and ``event`` as parameters. It will - be called when the button is pressed. - message : str - Message to be shown in the text widget - left_down_pos : tuple - Coordinates for left down corner of text. If float are provided, - the normalized coordinate system is used, otherwise the coordinates - represent pixel positions. Default is (0.8, 0.5). - right_top_pos : tuple - Coordinates for right top corner of text. If float are provided, - the normalized coordinate system is used, otherwise the coordinates - represent pixel positions. Default is (0.9, 0.5). - color : tuple - Foreground RGB color of text. Default is (1., .5, .0). - opacity : float - Takes values from 0 to 1. Default is 1. - border : bool - Show text border. Default is False. 
- - Returns - ------- - text : TextWidget - This object inherits from ``vtkTextWidget`` has an additional method - called ``place`` which allows to update the position of the text if - necessary. - """ - - # Create the TextActor - text_actor = vtk.vtkTextActor() - text_actor.SetInput(message) - text_actor.GetTextProperty().SetColor(color) - text_actor.GetTextProperty().SetOpacity(opacity) - - # Create the text representation. Used for positioning the text_actor - text_rep = vtk.vtkTextRepresentation() - text_rep.SetTextActor(text_actor) - - if border: - text_rep.SetShowBorderToOn() - else: - text_rep.SetShowBorderToOff() - - class TextWidget(vtk.vtkTextWidget): - - def place(self, renderer): - text_rep = self.GetRepresentation() - - position = text_rep.GetPositionCoordinate() - position2 = text_rep.GetPosition2Coordinate() - - # The dtype of `left_down_pos` determines coordinate system type. - if np.issubdtype(np.asarray(left_down_pos).dtype, np.integer): - position.SetCoordinateSystemToDisplay() - else: - position.SetCoordinateSystemToNormalizedDisplay() - - # The dtype of `right_top_pos` determines coordinate system type. - if np.issubdtype(np.asarray(right_top_pos).dtype, np.integer): - position2.SetCoordinateSystemToDisplay() - else: - position2.SetCoordinateSystemToNormalizedDisplay() - - position.SetValue(*left_down_pos) - position2.SetValue(*right_top_pos) - - text_widget = TextWidget() - text_widget.SetRepresentation(text_rep) - text_widget.SetInteractor(iren) - text_widget.SelectableOn() - text_widget.ResizableOff() - - text_widget.AddObserver(vtk.vtkCommand.WidgetActivateEvent, callback) - - # Place widget after window resizing. - def _place_widget(obj, event): - text_widget.place(ren) - - iren.GetRenderWindow().AddObserver( - vtk.vtkCommand.StartEvent, _place_widget) - iren.GetRenderWindow().AddObserver( - vtk.vtkCommand.ModifiedEvent, _place_widget) - - text_widget.On() - - return text_widget diff --git a/dipy/viz/window.py b/dipy/viz/window.py deleted file mode 100644 index 6b2b7c1855..0000000000 --- a/dipy/viz/window.py +++ /dev/null @@ -1,932 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import division, print_function, absolute_import - -import gzip -from warnings import warn - -import numpy as np -from scipy import ndimage -from copy import copy - -from nibabel.tmpdirs import InTemporaryDirectory -from nibabel.py3k import asbytes - -try: - import Tkinter as tkinter - has_tkinter = True -except ImportError: - try: - import tkinter - has_tkinter = True - except ImportError: - has_tkinter = False - -try: - import tkFileDialog as filedialog -except ImportError: - try: - from tkinter import filedialog - except ImportError: - has_tkinter = False - -# Conditional import machinery for vtk -from dipy.utils.optpkg import optional_package - -from dipy import __version__ as dipy_version -from dipy.utils.six import string_types - -from dipy.viz.interactor import CustomInteractorStyle - -# Allow import, but disable doctests if we don't have vtk -vtk, have_vtk, setup_module = optional_package('vtk') -colors, have_vtk_colors, _ = optional_package('vtk.util.colors') -numpy_support, have_ns, _ = optional_package('vtk.util.numpy_support') -_, have_imread, _ = optional_package('Image') -if not have_imread: - _, have_imread, _ = optional_package('PIL') - -if have_vtk: - version = vtk.vtkVersion.GetVTKSourceVersion().split(' ')[-1] - major_version = vtk.vtkVersion.GetVTKMajorVersion() - from vtk.util.numpy_support import vtk_to_numpy - vtkRenderer = vtk.vtkRenderer -else: - vtkRenderer = 
object - -if have_imread: - from scipy.misc import imread - - -class Renderer(vtkRenderer): - """ Your scene class - - This is an important object that is responsible for preparing objects - e.g. actors and volumes for rendering. This is a more pythonic version - of ``vtkRenderer`` proving simple methods for adding and removing actors - but also it provides access to all the functionality - available in ``vtkRenderer`` if necessary. - """ - - def background(self, color): - """ Set a background color - """ - self.SetBackground(color) - - def add(self, *actors): - """ Add an actor to the renderer - """ - for actor in actors: - if isinstance(actor, vtk.vtkVolume): - self.AddVolume(actor) - elif isinstance(actor, vtk.vtkActor2D): - self.AddActor2D(actor) - elif hasattr(actor, 'add_to_renderer'): - actor.add_to_renderer(self) - else: - self.AddActor(actor) - - def rm(self, actor): - """ Remove a specific actor - """ - self.RemoveActor(actor) - - def clear(self): - """ Remove all actors from the renderer - """ - self.RemoveAllViewProps() - - def rm_all(self): - """ Remove all actors from the renderer - """ - self.RemoveAllViewProps() - - def projection(self, proj_type='perspective'): - """ Deside between parallel or perspective projection - - Parameters - ---------- - proj_type : str - Can be 'parallel' or 'perspective' (default). - - """ - if proj_type == 'parallel': - self.GetActiveCamera().ParallelProjectionOn() - else: - self.GetActiveCamera().ParallelProjectionOff() - - def reset_camera(self): - """ Reset the camera to an automatic position given by the engine. - """ - self.ResetCamera() - - def reset_clipping_range(self): - self.ResetCameraClippingRange() - - def camera(self): - return self.GetActiveCamera() - - def get_camera(self): - cam = self.GetActiveCamera() - return cam.GetPosition(), cam.GetFocalPoint(), cam.GetViewUp() - - def camera_info(self): - cam = self.camera() - print('# Active Camera') - print(' Position (%.2f, %.2f, %.2f)' % cam.GetPosition()) - print(' Focal Point (%.2f, %.2f, %.2f)' % cam.GetFocalPoint()) - print(' View Up (%.2f, %.2f, %.2f)' % cam.GetViewUp()) - - def set_camera(self, position=None, focal_point=None, view_up=None): - if position is not None: - self.GetActiveCamera().SetPosition(*position) - if focal_point is not None: - self.GetActiveCamera().SetFocalPoint(*focal_point) - if view_up is not None: - self.GetActiveCamera().SetViewUp(*view_up) - self.ResetCameraClippingRange() - - def size(self): - """ Renderer size""" - return self.GetSize() - - def zoom(self, value): - """ In perspective mode, decrease the view angle by the specified - factor. In parallel mode, decrease the parallel scale by the specified - factor. A value greater than 1 is a zoom-in, a value less than 1 is a - zoom-out. - """ - self.GetActiveCamera().Zoom(value) - - def azimuth(self, angle): - """ Rotate the camera about the view up vector centered at the focal - point. Note that the view up vector is whatever was set via SetViewUp, - and is not necessarily perpendicular to the direction of projection. - The result is a horizontal rotation of the camera. - """ - self.GetActiveCamera().Azimuth(angle) - - def yaw(self, angle): - """ Rotate the focal point about the view up vector, using the camera's - position as the center of rotation. Note that the view up vector is - whatever was set via SetViewUp, and is not necessarily perpendicular - to the direction of projection. The result is a horizontal rotation of - the scene. 
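A brief, hedged sketch of how the camera helpers defined above compose; ``ren`` is assumed to be a ``Renderer`` already populated with actors.

    ren.reset_camera()    # frame all actors automatically
    ren.azimuth(30)       # horizontal orbit about the focal point
    ren.zoom(1.5)         # > 1 zooms in, < 1 zooms out
    position, focal_point, view_up = ren.get_camera()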
- """ - self.GetActiveCamera().Yaw(angle) - - def elevation(self, angle): - """ Rotate the camera about the cross product of the negative of the - direction of projection and the view up vector, using the focal point - as the center of rotation. The result is a vertical rotation of the - scene. - """ - self.GetActiveCamera().Elevation(angle) - - def pitch(self, angle): - """ Rotate the focal point about the cross product of the view up - vector and the direction of projection, using the camera's position as - the center of rotation. The result is a vertical rotation of the - camera. - """ - self.GetActiveCamera().Pitch(angle) - - def roll(self, angle): - """ Rotate the camera about the direction of projection. This will - spin the camera about its axis. - """ - self.GetActiveCamera().Roll(angle) - - def dolly(self, value): - """ Divide the camera's distance from the focal point by the given - dolly value. Use a value greater than one to dolly-in toward the focal - point, and use a value less than one to dolly-out away from the focal - point. - """ - self.GetActiveCamera().Dolly(value) - - def camera_direction(self): - """ Get the vector in the direction from the camera position to the - focal point. This is usually the opposite of the ViewPlaneNormal, the - vector perpendicular to the screen, unless the view is oblique. - """ - return self.GetActiveCamera().GetDirectionOfProjection() - - -def renderer(background=None): - """ Create a renderer. - - Parameters - ---------- - background : tuple - Initial background color of renderer - - Returns - ------- - v : Renderer - - Examples - -------- - >>> from dipy.viz import window, actor - >>> import numpy as np - >>> r = window.Renderer() - >>> lines=[np.random.rand(10,3)] - >>> c=actor.line(lines, window.colors.red) - >>> r.add(c) - >>> #window.show(r) - """ - - deprecation_msg = ("Method 'dipy.viz.window.renderer' is deprecated, instead" - " use class 'dipy.viz.window.Renderer'.") - warn(DeprecationWarning(deprecation_msg)) - - ren = Renderer() - if background is not None: - ren.SetBackground(background) - - return ren - -if have_vtk: - ren = renderer - - -def add(ren, a): - """ Add a specific actor - """ - ren.add(a) - - -def rm(ren, a): - """ Remove a specific actor - """ - ren.rm(a) - - -def clear(ren): - """ Remove all actors from the renderer - """ - ren.clear() - - -def rm_all(ren): - """ Remove all actors from the renderer - """ - ren.rm_all() - - -def open_file_dialog(file_types=[("All files", "*")]): - """ Simple Tk file dialog for opening files - - Parameters - ---------- - file_types : tuples of tuples - Accepted file types. - - Returns - ------- - file_paths : sequence of str - Returns the full paths of all selected files - """ - - root = tkinter.Tk() - root.withdraw() - file_paths = filedialog.askopenfilenames(filetypes=file_types) - return file_paths - - -def save_file_dialog(initial_file='dipy.png', default_ext='.png', - file_types=(("PNG file", "*.png"), ("All Files", "*.*"))): - """ Simple Tk file dialog for saving a file - - Parameters - ---------- - initial_file : str - For example ``dipy.png``. - default_ext : str - Default extension to appear in the save dialog. - file_types : tuples of tuples - Accepted file types. 
- - Returns - ------- - filepath : str - Complete filename of saved file - """ - - root = tkinter.Tk() - root.withdraw() - file_path = filedialog.asksaveasfilename(initialfile=initial_file, - defaultextension=default_ext, - filetypes=file_types) - return file_path - - -class ShowManager(object): - """ This class is the interface between the renderer, the window and the - interactor. - """ - - def __init__(self, ren=None, title='DIPY', size=(300, 300), - png_magnify=1, reset_camera=True, order_transparent=False, - interactor_style='custom'): - - """ Manages the visualization pipeline - - Parameters - ---------- - ren : Renderer() or vtkRenderer() - The scene that holds all the actors. - title : string - A string for the window title bar. - size : (int, int) - ``(width, height)`` of the window. Default is (300, 300). - png_magnify : int - Number of times to magnify the screenshot. This can be used to save - high resolution screenshots when pressing 's' inside the window. - reset_camera : bool - Default is True. You can change this option to False if you want to - keep the camera as set before calling this function. - order_transparent : bool - True is useful when you want to order transparent - actors according to their relative position to the camera. The - default option which is False will order the actors according to - the order of their addition to the Renderer(). - interactor_style : str or vtkInteractorStyle - If str then if 'trackball' then vtkInteractorStyleTrackballCamera() - is used, if 'image' then vtkInteractorStyleImage() is used (no - rotation) or if 'custom' then CustomInteractorStyle is used. - Otherwise you can input your own interactor style. - - Attributes - ---------- - ren : vtkRenderer() - iren : vtkRenderWindowInteractor() - style : vtkInteractorStyle() - window : vtkRenderWindow() - - Methods - ------- - initialize() - render() - start() - add_window_callback() - - Notes - ----- - Default interaction keys for - - * 3d navigation are with left, middle and right mouse dragging - * resetting the camera press 'r' - * saving a screenshot press 's' - * for quiting press 'q' - - Examples - -------- - >>> from dipy.viz import actor, window - >>> renderer = window.Renderer() - >>> renderer.add(actor.axes()) - >>> showm = window.ShowManager(renderer) - >>> # showm.initialize() - >>> # showm.render() - >>> # showm.start() - """ - if ren is None: - ren = Renderer() - self.ren = ren - self.title = title - self.size = size - self.png_magnify = png_magnify - self.reset_camera = reset_camera - self.order_transparent = order_transparent - self.interactor_style = interactor_style - - if self.reset_camera: - self.ren.ResetCamera() - - self.window = vtk.vtkRenderWindow() - self.window.AddRenderer(ren) - - if self.title == 'DIPY': - self.window.SetWindowName(title + ' ' + dipy_version) - else: - self.window.SetWindowName(title) - self.window.SetSize(size[0], size[1]) - - if self.order_transparent: - - # Use a render window with alpha bits - # as default is 0 (false)) - self.window.SetAlphaBitPlanes(True) - - # Force to not pick a framebuffer with a multisample buffer - # (default is 8) - self.window.SetMultiSamples(0) - - # Choose to use depth peeling (if supported) - # (default is 0 (false)): - self.ren.UseDepthPeelingOn() - - # Set depth peeling parameters - # Set the maximum number of rendering passes (default is 4) - ren.SetMaximumNumberOfPeels(4) - - # Set the occlusion ratio (initial value is 0.0, exact image): - ren.SetOcclusionRatio(0.0) - - if self.interactor_style == 'image': 
- self.style = vtk.vtkInteractorStyleImage() - elif self.interactor_style == 'trackball': - self.style = vtk.vtkInteractorStyleTrackballCamera() - elif self.interactor_style == 'custom': - self.style = CustomInteractorStyle() - else: - self.style = interactor_style - - self.iren = vtk.vtkRenderWindowInteractor() - self.style.SetCurrentRenderer(self.ren) - # Hack: below, we explicitly call the Python version of SetInteractor. - self.style.SetInteractor(self.iren) - self.iren.SetInteractorStyle(self.style) - self.iren.SetRenderWindow(self.window) - - def initialize(self): - """ Initialize interaction - """ - self.iren.Initialize() - - def render(self): - """ Renders only once - """ - self.window.Render() - - def start(self): - """ Starts interaction - """ - try: - self.iren.Start() - except AttributeError: - self.__init__(self.ren, self.title, size=self.size, - png_magnify=self.png_magnify, - reset_camera=self.reset_camera, - order_transparent=self.order_transparent, - interactor_style=self.interactor_style) - self.initialize() - self.render() - self.iren.Start() - - self.window.RemoveRenderer(self.ren) - self.ren.SetRenderWindow(None) - del self.iren - del self.window - - def record_events(self): - """ Records events during the interaction. - - The recording is represented as a list of VTK events that happened - during the interaction. The recorded events are then returned. - - Returns - ------- - events : str - Recorded events (one per line). - - Notes - ----- - Since VTK only allows recording events to a file, we use a - temporary file from which we then read the events. - """ - with InTemporaryDirectory(): - filename = "recorded_events.log" - recorder = vtk.vtkInteractorEventRecorder() - recorder.SetInteractor(self.iren) - recorder.SetFileName(filename) - - def _stop_recording_and_close(obj, evt): - if recorder: - recorder.Stop() - self.iren.TerminateApp() - - self.iren.AddObserver("ExitEvent", _stop_recording_and_close) - - recorder.EnabledOn() - recorder.Record() - - self.initialize() - self.render() - self.iren.Start() - # Deleting this object is the unique way - # to close the file. - recorder = None - # Retrieved recorded events. - with open(filename, 'r') as f: - events = f.read() - return events - - def record_events_to_file(self, filename="record.log"): - """ Records events during the interaction. - - The recording is represented as a list of VTK events - that happened during the interaction. The recording is - going to be saved into `filename`. - - Parameters - ---------- - filename : str - Name of the file that will contain the recording (.log|.log.gz). - """ - events = self.record_events() - - # Compress file if needed - if filename.endswith(".gz"): - with gzip.open(filename, 'wb') as fgz: - fgz.write(asbytes(events)) - else: - with open(filename, 'w') as f: - f.write(events) - - def play_events(self, events): - """ Plays recorded events of a past interaction. - - The VTK events that happened during the recorded interaction will be - played back. - - Parameters - ---------- - events : str - Recorded events (one per line). - """ - recorder = vtk.vtkInteractorEventRecorder() - recorder.SetInteractor(self.iren) - - recorder.SetInputString(events) - recorder.ReadFromInputStringOn() - - self.initialize() - self.render() - recorder.Play() - - def play_events_from_file(self, filename): - """ Plays recorded events of a past interaction. - - The VTK events that happened during the recorded interaction will be - played back from `filename`. 
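A minimal round-trip sketch for the two methods above, assuming an initialized ``ShowManager`` named ``showm``; the file name is a placeholder.

    # Record an interactive session to a compressed log, then replay it.
    showm.record_events_to_file('session.log.gz')
    showm.play_events_from_file('session.log.gz')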
- - Parameters - ---------- - filename : str - Name of the file containing the recorded events (.log|.log.gz). - """ - # Uncompress file if needed. - if filename.endswith(".gz"): - with gzip.open(filename, 'r') as f: - events = f.read() - else: - with open(filename) as f: - events = f.read() - - self.play_events(events) - - def add_window_callback(self, win_callback): - """ Add window callbacks - """ - self.window.AddObserver(vtk.vtkCommand.ModifiedEvent, win_callback) - self.window.Render() - - -def show(ren, title='DIPY', size=(300, 300), - png_magnify=1, reset_camera=True, order_transparent=False): - """ Show window with current renderer - - Parameters - ------------ - ren : Renderer() or vtkRenderer() - The scene that holds all the actors. - title : string - A string for the window title bar. Default is DIPY and current version. - size : (int, int) - ``(width, height)`` of the window. Default is (300, 300). - png_magnify : int - Number of times to magnify the screenshot. Default is 1. This can be - used to save high resolution screenshots when pressing 's' inside the - window. - reset_camera : bool - Default is True. You can change this option to False if you want to - keep the camera as set before calling this function. - order_transparent : bool - True is useful when you want to order transparent - actors according to their relative position to the camera. The default - option which is False will order the actors according to the order of - their addition to the Renderer(). - - Notes - ----- - Default interaction keys for - - * 3d navigation are with left, middle and right mouse dragging - * resetting the camera press 'r' - * saving a screenshot press 's' - * for quiting press 'q' - - Examples - ---------- - >>> import numpy as np - >>> from dipy.viz import window, actor - >>> r = window.Renderer() - >>> lines=[np.random.rand(10,3),np.random.rand(20,3)] - >>> colors=np.array([[0.2,0.2,0.2],[0.8,0.8,0.8]]) - >>> c=actor.line(lines,colors) - >>> r.add(c) - >>> l=actor.label(text="Hello") - >>> r.add(l) - >>> #window.show(r) - - See also - --------- - dipy.viz.window.record - dipy.viz.window.snapshot - """ - - show_manager = ShowManager(ren, title, size, - png_magnify, reset_camera, order_transparent) - show_manager.initialize() - show_manager.render() - show_manager.start() - - -def record(ren=None, cam_pos=None, cam_focal=None, cam_view=None, - out_path=None, path_numbering=False, n_frames=1, az_ang=10, - magnification=1, size=(300, 300), reset_camera=True, verbose=False): - """ This will record a video of your scene - - Records a video as a series of ``.png`` files of your scene by rotating the - azimuth angle az_angle in every frame. - - Parameters - ----------- - ren : vtkRenderer() object - as returned from function ren() - cam_pos : None or sequence (3,), optional - Camera's position. If None then default camera's position is used. - cam_focal : None or sequence (3,), optional - Camera's focal point. If None then default camera's focal point is - used. - cam_view : None or sequence (3,), optional - Camera's view up direction. If None then default camera's view up - vector is used. - out_path : str, optional - Output path for the frames. If None a default dipy.png is created. - path_numbering : bool - When recording it changes out_path to out_path + str(frame number) - n_frames : int, optional - Number of frames to save, default 1 - az_ang : float, optional - Azimuthal angle of camera rotation. - magnification : int, optional - How much to magnify the saved frame. Default is 1. 
- size : (int, int) - ``(width, height)`` of the window. Default is (300, 300). - reset_camera : bool - If True Call ``ren.reset_camera()``. Otherwise you need to set the - camera before calling this function. - verbose : bool - print information about the camera. Default is False. - - - Examples - --------- - >>> from dipy.viz import window, actor - >>> ren = window.Renderer() - >>> a = actor.axes() - >>> ren.add(a) - >>> # uncomment below to record - >>> # window.record(ren) - >>> #check for new images in current directory - """ - - if ren is None: - ren = vtk.vtkRenderer() - - renWin = vtk.vtkRenderWindow() - renWin.AddRenderer(ren) - renWin.SetSize(size[0], size[1]) - iren = vtk.vtkRenderWindowInteractor() - iren.SetRenderWindow(renWin) - - # ren.GetActiveCamera().Azimuth(180) - - if reset_camera: - ren.ResetCamera() - - renderLarge = vtk.vtkRenderLargeImage() - if major_version <= 5: - renderLarge.SetInput(ren) - else: - renderLarge.SetInput(ren) - renderLarge.SetMagnification(magnification) - renderLarge.Update() - - writer = vtk.vtkPNGWriter() - ang = 0 - - if cam_pos is not None: - cx, cy, cz = cam_pos - ren.GetActiveCamera().SetPosition(cx, cy, cz) - if cam_focal is not None: - fx, fy, fz = cam_focal - ren.GetActiveCamera().SetFocalPoint(fx, fy, fz) - if cam_view is not None: - ux, uy, uz = cam_view - ren.GetActiveCamera().SetViewUp(ux, uy, uz) - - cam = ren.GetActiveCamera() - if verbose: - print('Camera Position (%.2f, %.2f, %.2f)' % cam.GetPosition()) - print('Camera Focal Point (%.2f, %.2f, %.2f)' % cam.GetFocalPoint()) - print('Camera View Up (%.2f, %.2f, %.2f)' % cam.GetViewUp()) - - for i in range(n_frames): - ren.GetActiveCamera().Azimuth(ang) - renderLarge = vtk.vtkRenderLargeImage() - renderLarge.SetInput(ren) - renderLarge.SetMagnification(magnification) - renderLarge.Update() - writer.SetInputConnection(renderLarge.GetOutputPort()) - - if path_numbering: - if out_path is None: - filename = str(i).zfill(6) + '.png' - else: - filename = out_path + str(i).zfill(6) + '.png' - else: - if out_path is None: - filename = 'dipy.png' - else: - filename = out_path - writer.SetFileName(filename) - writer.Write() - - ang = +az_ang - - -def snapshot(ren, fname=None, size=(300, 300), offscreen=True, - order_transparent=False): - """ Saves a snapshot of the renderer in a file or in memory - - Parameters - ----------- - ren : vtkRenderer - as returned from function renderer() - fname : str or None - Save PNG file. If None return only an array without saving PNG. - size : (int, int) - ``(width, height)`` of the window. Default is (300, 300). - offscreen : bool - Default True. Go stealthmode no window should appear. - order_transparent : bool - Default False. Use depth peeling to sort transparent objects. - - Returns - ------- - arr : ndarray - Color array of size (width, height, 3) where the last dimension - holds the RGB values. 
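For illustration, a hedged sketch of the off-screen round trip described above; ``ren`` is an assumed, populated renderer.

    # Render off-screen and inspect the returned RGB array.
    arr = snapshot(ren, fname=None, size=(600, 600), offscreen=True)
    print(arr.shape)  # expected: (600, 600, 3)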
- """ - - width, height = size - - if offscreen: - graphics_factory = vtk.vtkGraphicsFactory() - graphics_factory.SetOffScreenOnlyMode(1) - # TODO check if the line below helps in something - # graphics_factory.SetUseMesaClasses(1) - - render_window = vtk.vtkRenderWindow() - if offscreen: - render_window.SetOffScreenRendering(1) - render_window.AddRenderer(ren) - render_window.SetSize(width, height) - - if order_transparent: - - # Use a render window with alpha bits - # as default is 0 (false)) - render_window.SetAlphaBitPlanes(True) - - # Force to not pick a framebuffer with a multisample buffer - # (default is 8) - render_window.SetMultiSamples(0) - - # Choose to use depth peeling (if supported) - # (default is 0 (false)): - ren.UseDepthPeelingOn() - - # Set depth peeling parameters - # Set the maximum number of rendering passes (default is 4) - ren.SetMaximumNumberOfPeels(4) - - # Set the occlusion ratio (initial value is 0.0, exact image): - ren.SetOcclusionRatio(0.0) - - render_window.Render() - - window_to_image_filter = vtk.vtkWindowToImageFilter() - window_to_image_filter.SetInput(render_window) - window_to_image_filter.Update() - - vtk_image = window_to_image_filter.GetOutput() - h, w, _ = vtk_image.GetDimensions() - vtk_array = vtk_image.GetPointData().GetScalars() - components = vtk_array.GetNumberOfComponents() - arr = vtk_to_numpy(vtk_array).reshape(h, w, components) - - if fname is None: - return arr - - writer = vtk.vtkPNGWriter() - writer.SetFileName(fname) - writer.SetInputConnection(window_to_image_filter.GetOutputPort()) - writer.Write() - return arr - - -def analyze_renderer(ren): - - class ReportRenderer(object): - bg_color = None - - report = ReportRenderer() - - report.bg_color = ren.GetBackground() - report.collection = ren.GetActors() - report.actors = report.collection.GetNumberOfItems() - - report.collection.InitTraversal() - report.actors_classnames = [] - for i in range(report.actors): - class_name = report.collection.GetNextActor().GetClassName() - report.actors_classnames.append(class_name) - - return report - - -def analyze_snapshot(im, bg_color=(0, 0, 0), colors=None, - find_objects=True, - strel=None): - """ Analyze snapshot from memory or file - - Parameters - ---------- - im: str or array - If string then the image is read from a file otherwise the image is - read from a numpy array. The array is expected to be of shape (X, Y, 3) - where the last dimensions are the RGB values. - colors: tuple (3,) or list of tuples (3,) - List of colors to search in the image - find_objects: bool - If True it will calculate the number of objects that are different - from the background and return their position in a new image. - strel: 2d array - Structure element to use for finding the objects. - - Returns - ------- - report : ReportSnapshot - This is an object with attibutes like ``colors_found`` that give - information about what was found in the current snapshot array ``im``. 
- - """ - if isinstance(im, string_types): - im = imread(im) - - class ReportSnapshot(object): - objects = None - labels = None - colors_found = False - - report = ReportSnapshot() - - if colors is not None: - if isinstance(colors, tuple): - colors = [colors] - flags = [False] * len(colors) - for (i, col) in enumerate(colors): - # find if the current color exist in the array - flags[i] = np.any(np.all(im == col, axis=-1)) - - report.colors_found = flags - - if find_objects is True: - weights = [0.299, 0.587, 0.144] - gray = np.dot(im[..., :3], weights) - bg_color = im[0, 0] - background = np.dot(bg_color, weights) - - if strel is None: - strel = np.array([[0, 1, 0], - [1, 1, 1], - [0, 1, 0]]) - - labels, objects = ndimage.label(gray != background, strel) - report.labels = labels - report.objects = objects - - return report diff --git a/dipy/workflows/align.py b/dipy/workflows/align.py index b230c68538..eaeb39d9c4 100644 --- a/dipy/workflows/align.py +++ b/dipy/workflows/align.py @@ -1,12 +1,16 @@ from __future__ import division, print_function, absolute_import import logging +import numpy as np from dipy.align.reslice import reslice from dipy.io.image import load_nifti, save_nifti from dipy.workflows.workflow import Workflow +from dipy.align.streamlinear import slr_with_qbx +from dipy.io.streamline import load_trk, save_trk +from dipy.tracking.streamline import transform_streamlines class ResliceFlow(Workflow): - + @classmethod def get_short_name(cls): return 'reslice' @@ -14,7 +18,7 @@ def get_short_name(cls): def run(self, input_files, new_vox_size, order=1, mode='constant', cval=0, num_processes=1, out_dir='', out_resliced='resliced.nii.gz'): """Reslice data with new voxel resolution defined by ``new_vox_sz`` - + Parameters ---------- input_files : string @@ -24,7 +28,7 @@ def run(self, input_files, new_vox_size, order=1, mode='constant', cval=0, new voxel size order : int, optional order of interpolation, from 0 to 5, for resampling/reslicing, - 0 nearest interpolation, 1 trilinear etc.. if you don't want any + 0 nearest interpolation, 1 trilinear etc.. if you don't want any smoothing 0 is the option you need (default 1) mode : string, optional Points outside the boundaries of the input are filled according @@ -45,11 +49,11 @@ def run(self, input_files, new_vox_size, order=1, mode='constant', cval=0, Name of the resliced dataset to be saved (default 'resliced.nii.gz') """ - + io_it = self.get_io_iterator() for inputfile, outpfile in io_it: - + data, affine, vox_sz = load_nifti(inputfile, return_voxsize=True) logging.info('Processing {0}'.format(inputfile)) new_data, new_affine = reslice(data, affine, vox_sz, new_vox_size, @@ -57,4 +61,132 @@ def run(self, input_files, new_vox_size, order=1, mode='constant', cval=0, num_processes=num_processes) save_nifti(outpfile, new_data, new_affine) logging.info('Resliced file save in {0}'.format(outpfile)) - \ No newline at end of file + + +class SlrWithQbxFlow(Workflow): + + @classmethod + def get_short_name(cls): + return 'slrwithqbx' + + def run(self, static_files, moving_files, + x0='affine', + rm_small_clusters=50, + qbx_thr=[40, 30, 20, 15], + num_threads=None, + greater_than=50, + less_than=250, + nb_pts=20, + progressive=True, + out_dir='', + out_moved='moved.trk', + out_affine='affine.txt', + out_stat_centroids='static_centroids.trk', + out_moving_centroids='moving_centroids.trk', + out_moved_centroids='moved_centroids.trk'): + """ Streamline-based linear registration. 
+
+        For efficiency we apply the registration on cluster centroids and
+        remove small clusters.
+
+        Parameters
+        ----------
+        static_files : string
+            Path(s) to the static tractogram file(s).
+        moving_files : string
+            Path(s) to the moving tractogram file(s).
+        x0 : string, optional
+            rigid, similarity or affine transformation model (default affine)
+        rm_small_clusters : int, optional
+            Remove clusters that have fewer than `rm_small_clusters`
+            streamlines (default 50)
+        qbx_thr : variable int, optional
+            Thresholds for QuickBundlesX (default [40, 30, 20, 15])
+        num_threads : int, optional
+            Number of threads. If None (default) then all available threads
+            will be used. Only metrics using OpenMP will use this variable.
+        greater_than : int, optional
+            Keep streamlines that have length greater than
+            this value (default 50)
+        less_than : int, optional
+            Keep streamlines that have length less than this value
+            (default 250)
+        nb_pts : int, optional
+            Number of points for discretizing each streamline (default 20)
+        progressive : boolean, optional
+            If True, use a progressive registration strategy (default True)
+        out_dir : string, optional
+            Output directory (default input file directory)
+        out_moved : string, optional
+            Filename of moved tractogram (default 'moved.trk')
+        out_affine : string, optional
+            Filename of affine for SLR transformation (default 'affine.txt')
+        out_stat_centroids : string, optional
+            Filename of static centroids (default 'static_centroids.trk')
+        out_moving_centroids : string, optional
+            Filename of moving centroids (default 'moving_centroids.trk')
+        out_moved_centroids : string, optional
+            Filename of moved centroids (default 'moved_centroids.trk')
+
+        Notes
+        -----
+        The order of operations is the following. First short or long
+        streamlines are removed. Second the tractogram or a random selection
+        of the tractogram is clustered with QuickBundlesX. Then SLR
+        [Garyfallidis15]_ is applied.
+
+        References
+        ----------
+        .. [Garyfallidis15] Garyfallidis et al. "Robust and efficient linear
+            registration of white-matter fascicles in the space of
+            streamlines", NeuroImage, 117, 124--140, 2015
+
+        .. [Garyfallidis14] Garyfallidis et al., "Direct native-space fiber
+            bundle alignment for group comparisons", ISMRM, 2014.
+
+        .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter
+            bundles using local and global streamline-based registration
+            and clustering, Neuroimage, 2017.
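As a hedged sketch (the file names below are placeholders, not files shipped with this PR), the workflow can be driven from Python as follows.

    # Register a moving tractogram to a static one with SLR + QuickBundlesX.
    flow = SlrWithQbxFlow()
    flow.run('static.trk', 'moving.trk', x0='affine',
             qbx_thr=[40, 30, 20, 15], out_dir='out')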
+ """ + io_it = self.get_io_iterator() + + logging.info("QuickBundlesX clustering is in use") + logging.info('QBX thresholds {0}'.format(qbx_thr)) + + for static_file, moving_file, out_moved_file, out_affine_file, \ + static_centroids_file, moving_centroids_file, \ + moved_centroids_file in io_it: + + logging.info('Loading static file {0}'.format(static_file)) + logging.info('Loading moving file {0}'.format(moving_file)) + + static, static_header = load_trk(static_file) + moving, moving_header = load_trk(moving_file) + + moved, affine, centroids_static, centroids_moving = \ + slr_with_qbx( + static, moving, x0, rm_small_clusters=rm_small_clusters, + greater_than=greater_than, less_than=less_than, + qbx_thr=qbx_thr) + + logging.info('Saving output file {0}'.format(out_moved_file)) + save_trk(out_moved_file, moved, affine=np.eye(4), + header=static_header) + + logging.info('Saving output file {0}'.format(out_affine_file)) + np.savetxt(out_affine_file, affine) + + logging.info('Saving output file {0}' + .format(static_centroids_file)) + save_trk(static_centroids_file, centroids_static, affine=np.eye(4), + header=static_header) + + logging.info('Saving output file {0}' + .format(moving_centroids_file)) + save_trk(moving_centroids_file, centroids_moving, + affine=np.eye(4), + header=static_header) + + centroids_moved = transform_streamlines(centroids_moving, affine) + + logging.info('Saving output file {0}' + .format(moved_centroids_file)) + save_trk(moved_centroids_file, centroids_moved, affine=np.eye(4), + header=static_header) diff --git a/dipy/workflows/base.py b/dipy/workflows/base.py index 3708aa129a..c3d58afe6e 100644 --- a/dipy/workflows/base.py +++ b/dipy/workflows/base.py @@ -1,7 +1,7 @@ import sys import inspect -from dipy.fixes import argparse as arg +import argparse from dipy.workflows.docstring_parser import NumpyDocString @@ -20,11 +20,10 @@ def get_args_default(func): return names, defaults -class IntrospectiveArgumentParser(arg.ArgumentParser): +class IntrospectiveArgumentParser(argparse.ArgumentParser): def __init__(self, prog=None, usage=None, description=None, epilog=None, - version=None, parents=[], - formatter_class=arg.RawTextHelpFormatter, + parents=[], formatter_class=argparse.RawTextHelpFormatter, prefix_chars='-', fromfile_prefix_chars=None, argument_default=None, conflict_handler='resolve', add_help=True): @@ -41,8 +40,6 @@ def __init__(self, prog=None, usage=None, description=None, epilog=None, A description of what the program does epilog : str Text following the argument descriptions - version : None - Add a -v/--version option with the given version string parents : list Parsers whose arguments should be copied into this one formatter_class : obj @@ -68,10 +65,15 @@ def __init__(self, prog=None, usage=None, description=None, epilog=None, " library for the analysis of diffusion MRI data. 
Frontiers" " in Neuroinformatics, 1-18, 2014.") - super(iap, self).__init__(prog, usage, description, epilog, version, - parents, formatter_class, prefix_chars, - fromfile_prefix_chars, argument_default, - conflict_handler, add_help) + super(iap, self).__init__(prog=prog, usage=usage, + description=description, + epilog=epilog, parents=parents, + formatter_class=formatter_class, + prefix_chars=prefix_chars, + fromfile_prefix_chars=fromfile_prefix_chars, + argument_default=argument_default, + conflict_handler=conflict_handler, + add_help=add_help) self.doc = None @@ -101,24 +103,36 @@ def add_workflow(self, workflow): ref_text = [text if text else "\n" for text in npds['References']] ref_idx = self.epilog.find('References: \n') + len('References: \n') self.epilog = "{0}{1}\n{2}".format(self.epilog[:ref_idx], - ''.join([text for text in ref_text]), + ''.join(ref_text), self.epilog[ref_idx:]) - self.outputs = [param for param in npds['Parameters'] if - 'out_' in param[0]] + self._output_params = [param for param in npds['Parameters'] + if 'out_' in param[0]] + self._positional_params = [param for param in npds['Parameters'] + if 'optional' not in param[1] and + 'out_' not in param[0]] + self._optional_params = [param for param in npds['Parameters'] + if 'optional' in param[1]] args, defaults = get_args_default(workflow.run) + output_args = self.add_argument_group('output arguments(optional)') + len_args = len(args) len_defaults = len(defaults) + nb_positional_variable = 0 - output_args = \ - self.add_argument_group('output arguments(optional)') + if len_args != len(self.doc): + raise ValueError( + self.prog + ": Number of parameters in the " + "doc string and run method does not match. " + "Please ensure that the number of parameters " + "in the run method is same as the doc string.") for i, arg in enumerate(args): prefix = '' - is_optionnal = i >= len_args - len_defaults - if is_optionnal: + is_optional = i >= len_args - len_defaults + if is_optional: prefix = '--' typestr = self.doc[i][1] @@ -130,7 +144,7 @@ def add_workflow(self, workflow): 'type': dtype, 'action': 'store'} - if is_optionnal: + if is_optional: _kwargs['metavar'] = dtype.__name__ if dtype is bool: _kwargs['action'] = 'store_true' @@ -147,20 +161,33 @@ def add_workflow(self, workflow): _kwargs['type'] = str if isnarg: - _kwargs['nargs'] = '*' + if is_optional: + _kwargs['nargs'] = '*' + else: + _kwargs['nargs'] = '+' + nb_positional_variable += 1 if 'out_' in arg: output_args.add_argument(*_args, **_kwargs) else: self.add_argument(*_args, **_kwargs) + if nb_positional_variable > 1: + raise ValueError(self.prog + " : All positional arguments present" + " are gathered into a list. It does not make" + "much sense to have more than one positional" + " argument with 'variable string' as dtype." + " Please, ensure that 'variable (type)'" + " appears only once as a positional argument." + ) + return self.add_sub_flow_args(workflow.get_sub_runs()) def add_sub_flow_args(self, sub_flows): """ Take an array of workflow objects and use introspection to extract the parameters, types and docstrings of their run method. Only the - optional input parameters are extracted for these as they are treated as - sub workflows. + optional input parameters are extracted for these as they are treated + as sub workflows. Parameters ----------- @@ -265,9 +292,9 @@ def get_flow_args(self, args=None, namespace=None): """ Returns the parsed arguments as a dictionary that will be used as a workflow's run method arguments. 
""" + ns_args = self.parse_args(args, namespace) dct = vars(ns_args) - return dict((k, v) for k, v in dct.items() if v is not None) def update_argument(self, *args, **kargs): @@ -284,5 +311,14 @@ def add_epilogue(self): def add_description(self): pass - def get_outputs(self): - return self.outputs + @property + def output_parameters(self): + return self._output_params + + @property + def positional_parameters(self): + return self._positional_params + + @property + def optional_parameters(self): + return self._optional_params diff --git a/dipy/workflows/denoise.py b/dipy/workflows/denoise.py index 97cc50821a..c6af05cf25 100644 --- a/dipy/workflows/denoise.py +++ b/dipy/workflows/denoise.py @@ -3,8 +3,7 @@ import logging import shutil -import nibabel as nib - +from dipy.io.image import load_nifti, save_nifti from dipy.denoise.nlmeans import nlmeans from dipy.denoise.noise_estimate import estimate_sigma from dipy.workflows.workflow import Workflow @@ -43,8 +42,7 @@ def run(self, input_files, sigma=0, out_dir='', logging.warning('Denoising skipped for now.') else: logging.info('Denoising {0}'.format(fpath)) - image = nib.load(fpath) - data = image.get_data() + data, affine, image = load_nifti(fpath, return_img=True) if sigma == 0: logging.info('Estimating sigma') @@ -52,8 +50,6 @@ def run(self, input_files, sigma=0, out_dir='', logging.debug('Found sigma {0}'.format(sigma)) denoised_data = nlmeans(data, sigma) - denoised_image = nib.Nifti1Image( - denoised_data, image.affine, image.header) + save_nifti(odenoised, denoised_data, affine, image.header) - denoised_image.to_filename(odenoised) logging.info('Denoised volume saved as {0}'.format(odenoised)) diff --git a/dipy/workflows/flow_runner.py b/dipy/workflows/flow_runner.py index 2a7ed3fc20..eae691e96e 100644 --- a/dipy/workflows/flow_runner.py +++ b/dipy/workflows/flow_runner.py @@ -1,7 +1,13 @@ from __future__ import division, print_function, absolute_import +# Disabling the FutureWarning from h5py below. +# This disables the FutureWarning warning for all the workflows. 
+import warnings +warnings.simplefilter(action='ignore', category=FutureWarning) + import logging +from dipy import __version__ as dipy_version from dipy.utils.six import iteritems from dipy.workflows.base import IntrospectiveArgumentParser @@ -30,8 +36,11 @@ def run_flow(flow): action='store_true', default=False, help='Force overwriting output files.') + parser.add_argument('--version', action='version', + version='DIPY {}'.format(dipy_version)) + parser.add_argument('--out_strat', action='store', dest='out_strat', - metavar='string', required=False, default='append', + metavar='string', required=False, default='absolute', help='Strategy to manage output creation.') parser.add_argument('--mix_names', dest='mix_names', @@ -80,4 +89,3 @@ def run_flow(flow): flow.set_sub_flows_optionals(sub_flows_dicts) return flow.run(**args) - diff --git a/dipy/workflows/io.py b/dipy/workflows/io.py index 0a2d97869d..4b6cbed5ca 100644 --- a/dipy/workflows/io.py +++ b/dipy/workflows/io.py @@ -36,7 +36,9 @@ def run(self, input_files, np.set_printoptions(3, suppress=True) - for input_path in input_files: + io_it = self.get_io_iterator() + + for input_path in io_it: logging.info('------------------------------------------') logging.info('Looking at {0}'.format(input_path)) logging.info('------------------------------------------') diff --git a/dipy/workflows/mask.py b/dipy/workflows/mask.py index 690010d53e..9e8770dd72 100644 --- a/dipy/workflows/mask.py +++ b/dipy/workflows/mask.py @@ -1,7 +1,6 @@ #!/usr/bin/env python from __future__ import division -import inspect import logging import numpy as np @@ -25,7 +24,7 @@ def run(self, input_files, lb, ub=np.inf, out_dir='', Path to image to be masked. lb : float Lower bound value. - ub : float + ub : float, optional Upper bound value (default Inf) out_dir : string, optional Output directory (default input file directory) diff --git a/dipy/workflows/multi_io.py b/dipy/workflows/multi_io.py index 47c2c1af37..4c1f31814e 100644 --- a/dipy/workflows/multi_io.py +++ b/dipy/workflows/multi_io.py @@ -1,14 +1,15 @@ import inspect +import itertools import numpy as np import os -import os.path as path from glob import glob from dipy.utils.six import string_types from dipy.workflows.base import get_args_default + def common_start(sa, sb): - """ Returns the longest common substring from the beginning of sa and sb """ + """Return the longest common substring from the beginning of sa and sb.""" def _iter(): for a, b in zip(sa, sb): if a == b: @@ -23,8 +24,8 @@ def slash_to_under(dir_str): return ''.join(dir_str.replace('/', '_')) -def connect_output_paths(inputs, out_dir, out_files, output_strategy='append', - mix_names=True): +def connect_output_paths(inputs, out_dir, out_files, + output_strategy='absolute', mix_names=True): """ Generates a list of output files paths based on input files and output strategies. @@ -42,8 +43,8 @@ def connect_output_paths(inputs, out_dir, out_files, output_strategy='append', 'prepend': Add the input path directory tree to out_dir. 'absolute': Put directly in out_dir. mix_names : bool - Whether or not prepend a string composed of a mix of the input names - to the final output name. + Whether or not prepend a string composed of a mix of the input + names to the final output name. 
Returns ------- @@ -73,24 +74,23 @@ def connect_output_paths(inputs, out_dir, out_files, output_strategy='append', mixing_prefixes = [''] * len(inputs[0]) for (mix_pref, inp) in zip(mixing_prefixes, inputs[0]): - inp_dirname = path.dirname(inp) + inp_dirname = os.path.dirname(inp) if output_strategy == 'prepend': - if path.isabs(out_dir): + if os.path.isabs(out_dir): dname = out_dir + inp_dirname - if not path.isabs(out_dir): - dname = path.join( + if not os.path.isabs(out_dir): + dname = os.path.join( os.getcwd(), out_dir + inp_dirname) elif output_strategy == 'append': - dname = path.join(inp_dirname, out_dir) - + dname = os.path.join(inp_dirname, out_dir) else: dname = out_dir updated_out_files = [] for out_file in out_files: - updated_out_files.append(path.join(dname, mix_pref + out_file)) + updated_out_files.append(os.path.join(dname, mix_pref + out_file)) outputs.append(updated_out_files) @@ -111,7 +111,7 @@ def concatenate_inputs(multi_inputs): def basename_without_extension(fname): - base = path.basename(fname) + base = os.path.basename(fname) result = base.split('.')[0] if result[-4:] == '.nii': result = result.split('.')[0] @@ -119,7 +119,7 @@ def basename_without_extension(fname): return result -def io_iterator(inputs, out_dir, fnames, output_strategy='append', +def io_iterator(inputs, out_dir, fnames, output_strategy='absolute', mix_names=False, out_keys=None): """ Creates an IOIterator from the parameters. @@ -150,7 +150,7 @@ def io_iterator(inputs, out_dir, fnames, output_strategy='append', return io_it -def io_iterator_(frame, fnc, output_strategy='append', mix_names=False): +def io_iterator_(frame, fnc, output_strategy='absolute', mix_names=False): """ Creates an IOIterator using introspection. Parameters @@ -183,7 +183,7 @@ def io_iterator_(frame, fnc, output_strategy='append', mix_names=False): # inputs for arv in args[:split_at]: - inputs.append(values[arv]) + inputs.append(values[arv]) # defaults out_keys = [] @@ -202,20 +202,26 @@ class IOIterator(object): """ Create output filenames that work nicely with multiple input files from multiple directories (processing multiple subjects with one command) - Use information from input files, out_dir and out_fnames to generate correct - outputs which can come from long lists of multiple or single inputs. + Use information from input files, out_dir and out_fnames to generate + correct outputs which can come from long lists of multiple or single + inputs. 
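A minimal, hedged sketch of the ``io_iterator`` helper defined above; the glob patterns and output name are placeholders.

    # Pair globbed inputs with generated output paths, one tuple per subject.
    io_it = io_iterator(['subj_*/dwi.nii.gz', 'subj_*/mask.nii.gz'],
                        out_dir='out', fnames=['fa.nii.gz'])
    for dwi_path, mask_path, out_fa in io_it:
        print(dwi_path, mask_path, out_fa)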
""" - def __init__(self, output_strategy='append', mix_names=False): + def __init__(self, output_strategy='absolute', mix_names=False): self.output_strategy = output_strategy self.mix_names = mix_names self.inputs = [] self.out_keys = None - def set_inputs(self, *args): + self.file_existence_check(args) self.input_args = list(args) - self.inputs = [sorted(glob(inp)) for inp in self.input_args if type(inp) == str] + for inp in self.input_args: + if type(inp) == str: + self.inputs.append(sorted(glob(inp))) + if type(inp) == list and all(isinstance(s, str) for s in inp): + nested = [sorted(glob(i)) for i in inp if isinstance(i, str)] + self.inputs.append(list(itertools.chain.from_iterable(nested))) def set_out_dir(self, out_dir): self.out_dir = out_dir @@ -243,13 +249,22 @@ def create_outputs(self): def create_directories(self): for outputs in self.outputs: for output in outputs: - directory = path.dirname(output) + directory = os.path.dirname(output) if not (directory == '' or os.path.exists(directory)): os.makedirs(directory) def __iter__(self): - I = np.array(self.inputs).T - O = np.array(self.outputs) - IO = np.concatenate([I, O], axis=1) + ins = np.array(self.inputs).T + out = np.array(self.outputs) + IO = np.concatenate([ins, out], axis=1) for i_o in IO: - yield i_o + if len(i_o) == 1: + yield str(*i_o) + else: + yield i_o + + def file_existence_check(self, args): + input_args = [fname for fname in list(args) if isinstance(fname, str)] + for path in input_args: + if len(glob(path)) == 0: + raise IOError('File not found: ' + path) diff --git a/dipy/workflows/reconst.py b/dipy/workflows/reconst.py index 20de935012..54888a0615 100644 --- a/dipy/workflows/reconst.py +++ b/dipy/workflows/reconst.py @@ -4,6 +4,7 @@ import numpy as np import os.path from ast import literal_eval +from warnings import warn import nibabel as nib @@ -11,6 +12,7 @@ from dipy.data import get_sphere from dipy.io.gradients import read_bvals_bvecs from dipy.io.peaks import save_peaks, peaks_to_niftis +from dipy.io.image import load_nifti, save_nifti from dipy.reconst.csdeconv import (ConstrainedSphericalDeconvModel, auto_response) from dipy.reconst.dti import (TensorModel, color_fa, fractional_anisotropy, @@ -30,8 +32,8 @@ class ReconstMAPMRIFlow(Workflow): def get_short_name(cls): return 'mapmri' - def run(self, data_file, data_bvals, data_bvecs, small_delta, big_delta, - b0_threshold=0.0, laplacian=True, positivity=True, + def run(self, data_files, bvals_files, bvecs_files, small_delta, big_delta, + b0_threshold=50.0, laplacian=True, positivity=True, bval_threshold=2000, save_metrics=[], laplacian_weighting=0.05, radial_order=6, out_dir='', out_rtop='rtop.nii.gz', out_lapnorm='lapnorm.nii.gz', @@ -43,8 +45,8 @@ def run(self, data_file, data_bvals, data_bvecs, small_delta, big_delta, """ Workflow for fitting the MAPMRI model (with optional Laplacian regularization). Generates rtop, lapnorm, msd, qiv, rtap, rtpp, non-gaussian (ng), parallel ng, perpendicular ng saved in a nifti - format in input files provided by `data_file` and saves the nifti files - to an output directory specified by `out_dir`. + format in input files provided by `data_files` and saves the nifti + files to an output directory specified by `out_dir`. In order for the MAPMRI workflow to work in the way intended either the laplacian or positivity or both must @@ -52,11 +54,11 @@ def run(self, data_file, data_bvals, data_bvecs, small_delta, big_delta, Parameters ---------- - data_file : string + data_files : string Path to the input volume. 
-        data_bvals : string
+        bvals_files : string
             Path to the bval files.
-        data_bvecs : string
+        bvecs_files : string
             Path to the bvec files.
         small_delta : float
             Small delta value used in generation of gradient table of provided
             bval and bvec.
         big_delta : float
             Big delta value used in generation of gradient table of provided
             bval and bvec.
         b0_threshold : float, optional
             Threshold used to find b=0 directions (default 50.0)
-        laplacian : bool
+        laplacian : bool, optional
             Regularize using the Laplacian of the MAP-MRI basis (default True)
-        positivity : bool
+        positivity : bool, optional
             Constrain the propagator to be positive. (default True)
-        bval_threshold : float
+        bval_threshold : float, optional
             Sets the b-value threshold to be used in the scale factor
             estimation. In order for the estimated non-Gaussianity to have
             meaning this value should be set to a lower value (b<2000 s/mm^2)
             such that the scale factors are estimated on signal points that
             reasonably represent the spins at Gaussian diffusion. (default:
             2000)
-        save_metrics : list of strings
+        save_metrics : variable string, optional
             List of metrics to save.
             Possible values: rtop, laplacian_signal, msd, qiv, rtap, rtpp,
             ng, perng, parng
             (default: [] (all))
-        laplacian_weighting : float
+        laplacian_weighting : float, optional
             Weighting value used in fitting the MAPMRI model in the laplacian
             and both model types. (default: 0.05)
-        radial_order : unsigned int
+        radial_order : unsigned int, optional
             Even value used to set the order of the basis
             (default: 6)
         out_dir : string, optional
@@ -114,11 +116,13 @@
                 out_rtap, out_rtpp, out_ng, out_perng, out_parng) in io_it:

             logging.info('Computing MAPMRI metrics for {0}'.format(dwi))
-            img = nib.load(dwi)
-            data = img.get_data()
-            affine = img.affine
-            bvals, bvecs = read_bvals_bvecs(bval, bvec)
+            data, affine = load_nifti(dwi)

+            bvals, bvecs = read_bvals_bvecs(bval, bvec)
+            if b0_threshold < bvals.min():
+                warn("b0_threshold (value: {0}) is too low, increase your "
+                     "b0_threshold. It should be higher than the first b0 "
+                     "value ({1}).".format(b0_threshold, bvals.min()))
             gtab = gradient_table(bvals=bvals, bvecs=bvecs,
                                   small_delta=small_delta,
                                   big_delta=big_delta,
@@ -169,48 +173,39 @@ def run(self, data_file, data_bvals, data_bvecs, small_delta, big_delta,

             if 'rtop' in save_metrics:
                 r = mapfit_aniso.rtop()
-                rtop = nib.nifti1.Nifti1Image(r.astype(np.float32), affine)
-                nib.save(rtop, out_rtop)
+                save_nifti(out_rtop, r.astype(np.float32), affine)

             if 'laplacian_signal' in save_metrics:
                 ll = mapfit_aniso.norm_of_laplacian_signal()
-                lap = nib.nifti1.Nifti1Image(ll.astype(np.float32), affine)
-                nib.save(lap, out_lapnorm)
+                save_nifti(out_lapnorm, ll.astype(np.float32), affine)

             if 'msd' in save_metrics:
                 m = mapfit_aniso.msd()
-                msd = nib.nifti1.Nifti1Image(m.astype(np.float32), affine)
-                nib.save(msd, out_msd)
+                save_nifti(out_msd, m.astype(np.float32), affine)

             if 'qiv' in save_metrics:
                 q = mapfit_aniso.qiv()
-                qiv = nib.nifti1.Nifti1Image(q.astype(np.float32), affine)
-                nib.save(qiv, out_qiv)
+                save_nifti(out_qiv, q.astype(np.float32), affine)

             if 'rtap' in save_metrics:
                 r = mapfit_aniso.rtap()
-                rtap = nib.nifti1.Nifti1Image(r.astype(np.float32), affine)
-                nib.save(rtap, out_rtap)
+                save_nifti(out_rtap, r.astype(np.float32), affine)

             if 'rtpp' in save_metrics:
                 r = mapfit_aniso.rtpp()
-                rtpp = nib.nifti1.Nifti1Image(r.astype(np.float32), affine)
-                nib.save(rtpp, out_rtpp)
+                save_nifti(out_rtpp, r.astype(np.float32), affine)

             if 'ng' in save_metrics:
                 n = mapfit_aniso.ng()
-                ng = nib.nifti1.Nifti1Image(n.astype(np.float32), affine)
-                nib.save(ng, out_ng)
+                save_nifti(out_ng, n.astype(np.float32), affine)

             if 'perng' in save_metrics:
                 n = mapfit_aniso.ng_perpendicular()
-                ng = nib.nifti1.Nifti1Image(n.astype(np.float32), affine)
-                nib.save(ng, out_perng)
+                save_nifti(out_perng, n.astype(np.float32), affine)

             if 'parng' in save_metrics:
                 n = mapfit_aniso.ng_parallel()
-                ng = nib.nifti1.Nifti1Image(n.astype(np.float32), affine)
-                nib.save(ng, out_parng)
+                save_nifti(out_parng, n.astype(np.float32), affine)

             logging.info('MAPMRI saved in {0}'.
                          format(os.path.dirname(out_dir)))
@@ -221,9 +216,8 @@ class ReconstDtiFlow(Workflow):
     def get_short_name(cls):
         return 'dti'

-    def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0,
-            bvecs_tol=0.01,
-            save_metrics=[],
+    def run(self, input_files, bvalues_files, bvectors_files, mask_files,
+            b0_threshold=50, bvecs_tol=0.01, save_metrics=[],
             out_dir='', out_tensor='tensors.nii.gz', out_fa='fa.nii.gz',
             out_ga='ga.nii.gz', out_rgb='rgb.nii.gz', out_md='md.nii.gz',
             out_ad='ad.nii.gz', out_rd='rd.nii.gz', out_mode='mode.nii.gz',
@@ -239,10 +233,10 @@ def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0,
         input_files : string
             Path to the input volumes. This path may contain wildcards to
             process multiple inputs at once.
-        bvalues : string
+        bvalues_files : string
             Path to the bvalues files. This path may contain wildcards to use
             multiple bvalues files at once.
-        bvectors : string
+        bvectors_files : string
             Path to the bvectors files. This path may contain wildcards to
             use multiple bvectors files at once.
mask_files : string @@ -310,9 +304,7 @@ def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0, omode, oevecs, oevals in io_it: logging.info('Computing DTI metrics for {0}'.format(dwi)) - img = nib.load(dwi) - data = img.get_data() - affine = img.affine + data, affine = load_nifti(dwi) if mask is not None: mask = nib.load(mask).get_data().astype(np.bool) @@ -332,60 +324,42 @@ def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0, tensor_vals = lower_triangular(tenfit.quadratic_form) correct_order = [0, 1, 3, 2, 4, 5] tensor_vals_reordered = tensor_vals[..., correct_order] - fiber_tensors = nib.Nifti1Image(tensor_vals_reordered.astype( - np.float32), affine) - nib.save(fiber_tensors, otensor) + + save_nifti(otensor, tensor_vals_reordered.astype(np.float32), + affine) if 'fa' in save_metrics: - fa_img = nib.Nifti1Image(FA.astype(np.float32), - affine) - nib.save(fa_img, ofa) + save_nifti(ofa, FA.astype(np.float32), affine) if 'ga' in save_metrics: GA = geodesic_anisotropy(tenfit.evals) - ga_img = nib.Nifti1Image(GA.astype(np.float32), - affine) - nib.save(ga_img, oga) + save_nifti(oga, GA.astype(np.float32), affine) if 'rgb' in save_metrics: RGB = color_fa(FA, tenfit.evecs) - rgb_img = nib.Nifti1Image(np.array(255 * RGB, 'uint8'), - affine) - nib.save(rgb_img, orgb) + save_nifti(orgb, np.array(255 * RGB, 'uint8'), affine) if 'md' in save_metrics: MD = mean_diffusivity(tenfit.evals) - md_img = nib.Nifti1Image(MD.astype(np.float32), - affine) - nib.save(md_img, omd) + save_nifti(omd, MD.astype(np.float32), affine) if 'ad' in save_metrics: AD = axial_diffusivity(tenfit.evals) - ad_img = nib.Nifti1Image(AD.astype(np.float32), - affine) - nib.save(ad_img, oad) + save_nifti(oad, AD.astype(np.float32), affine) if 'rd' in save_metrics: RD = radial_diffusivity(tenfit.evals) - rd_img = nib.Nifti1Image(RD.astype(np.float32), - affine) - nib.save(rd_img, orad) + save_nifti(orad, RD.astype(np.float32), affine) if 'mode' in save_metrics: MODE = get_mode(tenfit.quadratic_form) - mode_img = nib.Nifti1Image(MODE.astype(np.float32), - affine) - nib.save(mode_img, omode) + save_nifti(omode, MODE.astype(np.float32), affine) if 'evec' in save_metrics: - evecs_img = nib.Nifti1Image(tenfit.evecs.astype(np.float32), - affine) - nib.save(evecs_img, oevecs) + save_nifti(oevecs, tenfit.evecs.astype(np.float32), affine) if 'eval' in save_metrics: - evals_img = nib.Nifti1Image(tenfit.evals.astype(np.float32), - affine) - nib.save(evals_img, oevals) + save_nifti(oevals, tenfit.evals.astype(np.float32), affine) dname_ = os.path.dirname(oevals) if dname_ == '': @@ -398,7 +372,7 @@ def get_tensor_model(self, gtab): return TensorModel(gtab, fit_method="WLS") def get_fitted_tensor(self, data, mask, bval, bvec, - b0_threshold=0, bvecs_tol=0.01): + b0_threshold=50, bvecs_tol=0.01): logging.info('Tensor estimation...') bvals, bvecs = read_bvals_bvecs(bval, bvec) @@ -416,14 +390,9 @@ class ReconstCSDFlow(Workflow): def get_short_name(cls): return 'csd' - def run(self, input_files, bvalues, bvectors, mask_files, - b0_threshold=0.0, - bvecs_tol=0.01, - roi_center=None, - roi_radius=10, - fa_thr=0.7, - frf=None, extract_pam_values=False, - sh_order=8, + def run(self, input_files, bvalues_files, bvectors_files, mask_files, + b0_threshold=50.0, bvecs_tol=0.01, roi_center=None, roi_radius=10, + fa_thr=0.7, frf=None, extract_pam_values=False, sh_order=8, odf_to_sh_order=8, out_dir='', out_pam='peaks.pam5', out_shm='shm.nii.gz', @@ -437,10 +406,10 @@ def run(self, input_files, bvalues, bvectors, 
mask_files,
         input_files : string
             Path to the input volumes. This path may contain wildcards to
             process multiple inputs at once.
-        bvalues : string
+        bvalues_files : string
             Path to the bvalues files. This path may contain wildcards to use
             multiple bvalues files at once.
-        bvectors : string
+        bvectors_files : string
             Path to the bvectors files. This path may contain wildcards to
             use multiple bvectors files at once.
         mask_files : string
@@ -502,23 +471,24 @@
                 opeaks_indices, ogfa) in io_it:

             logging.info('Loading {0}'.format(dwi))
-            img = nib.load(dwi)
-            data = img.get_data()
-            affine = img.affine
+            data, affine = load_nifti(dwi)

             bvals, bvecs = read_bvals_bvecs(bval, bvec)
+            if b0_threshold < bvals.min():
+                warn("b0_threshold (value: {0}) is too low, increase your "
+                     "b0_threshold. It should be higher than the first b0 "
+                     "value ({1}).".format(b0_threshold, bvals.min()))
             gtab = gradient_table(bvals, bvecs,
                                   b0_threshold=b0_threshold, atol=bvecs_tol)
             mask_vol = nib.load(maskfile).get_data().astype(np.bool)

-            sh_order = 8
-            if data.shape[-1] < 15:
+            n_params = ((sh_order + 1) * (sh_order + 2)) // 2
+            if data.shape[-1] < n_params:
                 raise ValueError(
-                    'You need at least 15 unique DWI volumes to '
-                    'compute fiber odfs. You currently have: {0}'
-                    ' DWI volumes.'.format(data.shape[-1]))
-            elif data.shape[-1] < 30:
-                sh_order = 6
+                    'You need at least {0} unique DWI volumes to '
+                    'compute fiber odfs. You currently have: {1}'
+                    ' DWI volumes.'.format(n_params, data.shape[-1]))

             if frf is None:
                 logging.info('Computing response function')
@@ -547,9 +517,8 @@
                     ratio = l01[1] / l01[0]
                     response = (response, ratio)

-                logging.info(
-                    'Eigenvalues for the frf of the input data are :{0}'
-                    .format(response[0]))
+                logging.info("Eigenvalues for the frf of the input"
+                             " data are: {0}".format(response[0]))
                 logging.info('Ratio for smallest to largest eigen value is {0}'
                              .format(ratio))

@@ -594,8 +563,8 @@ class ReconstCSAFlow(Workflow):
     def get_short_name(cls):
         return 'csa'

-    def run(self, input_files, bvalues, bvectors, mask_files, sh_order=6,
-            odf_to_sh_order=8, b0_threshold=0.0, bvecs_tol=0.01,
+    def run(self, input_files, bvalues_files, bvectors_files, mask_files,
+            sh_order=6, odf_to_sh_order=8, b0_threshold=50.0, bvecs_tol=0.01,
             extract_pam_values=False,
             out_dir='',
             out_pam='peaks.pam5', out_shm='shm.nii.gz',
@@ -610,10 +579,10 @@ def run(self, input_files, bvalues, bvectors, mask_files, sh_order=6,
         input_files : string
             Path to the input volumes. This path may contain wildcards to
             process multiple inputs at once.
-        bvalues : string
+        bvalues_files : string
             Path to the bvalues files. This path may contain wildcards to use
             multiple bvalues files at once.
-        bvectors : string
+        bvectors_files : string
             Path to the bvectors files. This path may contain wildcards to
             use multiple bvectors files at once.
         mask_files : string
@@ -649,7 +618,6 @@ def run(self, input_files, bvalues, bvectors, mask_files, sh_order=6,
         out_gfa : string, optional
             Name of the generalised FA volume to be saved
             (default 'gfa.nii.gz')
-
         References
         ----------
         .. [1] Aganj, I., et al. 2009.
ODF Reconstruction in Q-Ball Imaging @@ -661,11 +629,13 @@ def run(self, input_files, bvalues, bvectors, mask_files, sh_order=6, opeaks_values, opeaks_indices, ogfa) in io_it: logging.info('Loading {0}'.format(dwi)) - vol = nib.load(dwi) - data = vol.get_data() - affine = vol.affine + data, affine = load_nifti(dwi) bvals, bvecs = read_bvals_bvecs(bval, bvec) + if b0_threshold < bvals.min(): + warn("b0_threshold (value: {0}) is too low, increase your " + "b0_threshold. It should be higher than the first b0 value " + "({1}).".format(b0_threshold, bvals.min())) gtab = gradient_table(bvals, bvecs, b0_threshold=b0_threshold, atol=bvecs_tol) mask_vol = nib.load(maskfile).get_data().astype(np.bool) @@ -712,8 +682,8 @@ class ReconstDkiFlow(Workflow): def get_short_name(cls): return 'dki' - def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0, - save_metrics=[], + def run(self, input_files, bvalues_files, bvectors_files, mask_files, + b0_threshold=50.0, save_metrics=[], out_dir='', out_dt_tensor='dti_tensors.nii.gz', out_fa='fa.nii.gz', out_ga='ga.nii.gz', out_rgb='rgb.nii.gz', out_md='md.nii.gz', out_ad='ad.nii.gz', out_rd='rd.nii.gz', out_mode='mode.nii.gz', @@ -730,10 +700,10 @@ def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0, input_files : string Path to the input volumes. This path may contain wildcards to process multiple inputs at once. - bvalues : string + bvalues_files : string Path to the bvalues files. This path may contain wildcards to use multiple bvalues files at once. - bvectors : string + bvectors_files : string Path to the bvectors files. This path may contain wildcards to use multiple bvectors files at once. mask_files : string @@ -802,9 +772,7 @@ def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0, omode, oevecs, oevals, odk_tensor, omk, oak, ork) in io_it: logging.info('Computing DKI metrics for {0}'.format(dwi)) - img = nib.load(dwi) - data = img.get_data() - affine = img.affine + data, affine = load_nifti(dwi) if mask is not None: mask = nib.load(mask).get_data().astype(np.bool) @@ -826,72 +794,53 @@ def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0, tensor_vals = lower_triangular(dkfit.quadratic_form) correct_order = [0, 1, 3, 2, 4, 5] tensor_vals_reordered = tensor_vals[..., correct_order] - fiber_tensors = nib.Nifti1Image(tensor_vals_reordered.astype( - np.float32), affine) - nib.save(fiber_tensors, otensor) + save_nifti(otensor, tensor_vals_reordered.astype(np.float32), affine) if 'dk_tensor' in save_metrics: - kt_img = nib.Nifti1Image(dkfit.kt.astype(np.float32), affine) - nib.save(kt_img, odk_tensor) + save_nifti(odk_tensor, dkfit.kt.astype(np.float32), affine) if 'fa' in save_metrics: - fa_img = nib.Nifti1Image(FA.astype(np.float32), affine) - nib.save(fa_img, ofa) + save_nifti(ofa, FA.astype(np.float32), affine) if 'ga' in save_metrics: GA = geodesic_anisotropy(dkfit.evals) - ga_img = nib.Nifti1Image(GA.astype(np.float32), affine) - nib.save(ga_img, oga) + save_nifti(oga, GA.astype(np.float32), affine) if 'rgb' in save_metrics: RGB = color_fa(FA, dkfit.evecs) - rgb_img = nib.Nifti1Image(np.array(255 * RGB, 'uint8'), affine) - nib.save(rgb_img, orgb) + save_nifti(orgb, np.array(255 * RGB, 'uint8'), affine) if 'md' in save_metrics: MD = mean_diffusivity(dkfit.evals) - md_img = nib.Nifti1Image(MD.astype(np.float32), affine) - nib.save(md_img, omd) + save_nifti(omd, MD.astype(np.float32), affine) if 'ad' in save_metrics: AD = axial_diffusivity(dkfit.evals) - ad_img =
nib.Nifti1Image(AD.astype(np.float32), affine) - nib.save(ad_img, oad) + save_nifti(oad, AD.astype(np.float32), affine) if 'rd' in save_metrics: RD = radial_diffusivity(dkfit.evals) - rd_img = nib.Nifti1Image(RD.astype(np.float32), affine) - nib.save(rd_img, orad) + save_nifti(orad, RD.astype(np.float32), affine) if 'mode' in save_metrics: MODE = get_mode(dkfit.quadratic_form) - mode_img = nib.Nifti1Image(MODE.astype(np.float32), affine) - nib.save(mode_img, omode) + save_nifti(omode, MODE.astype(np.float32), affine) if 'evec' in save_metrics: - evecs_img = nib.Nifti1Image(dkfit.evecs.astype(np.float32), - affine) - nib.save(evecs_img, oevecs) + save_nifti(oevecs, dkfit.evecs.astype(np.float32), affine) if 'eval' in save_metrics: - evals_img = nib.Nifti1Image(dkfit.evals.astype(np.float32), - affine) - nib.save(evals_img, oevals) + save_nifti(oevals, dkfit.evals.astype(np.float32), affine) if 'mk' in save_metrics: - mk_img = nib.Nifti1Image(dkfit.mk().astype(np.float32), - affine) - nib.save(mk_img, omk) + save_nifti(omk, dkfit.mk().astype(np.float32), affine) if 'ak' in save_metrics: - ak_img = nib.Nifti1Image(dkfit.ak().astype(np.float32), - affine) - nib.save(ak_img, oak) + save_nifti(oak, dkfit.ak().astype(np.float32), affine) if 'rk' in save_metrics: - rk_img = nib.Nifti1Image(dkfit.rk().astype(np.float32), - affine) - nib.save(rk_img, ork) + save_nifti(ork, dkfit.rk().astype(np.float32), affine) logging.info('DKI metrics saved in {0}'. format(os.path.dirname(oevals))) @@ -899,11 +848,15 @@ def run(self, input_files, bvalues, bvectors, mask_files, b0_threshold=0.0, def get_dki_model(self, gtab): return DiffusionKurtosisModel(gtab) - def get_fitted_tensor(self, data, mask, bval, bvec, b0_threshold=0): + def get_fitted_tensor(self, data, mask, bval, bvec, b0_threshold=50): logging.info('Diffusion kurtosis estimation...') bvals, bvecs = read_bvals_bvecs(bval, bvec) - gtab = gradient_table(bvals, bvecs, b0_threshold=b0_threshold) + if b0_threshold < bvals.min(): + warn("b0_threshold (value: {0}) is too low, increase your " + "b0_threshold. It should be higher than the first b0 value " + "({1}).".format(b0_threshold, bvals.min())) + gtab = gradient_table(bvals, bvecs, b0_threshold=b0_threshold) dkmodel = self.get_dki_model(gtab) dkfit = dkmodel.fit(data, mask) diff --git a/dipy/workflows/segment.py b/dipy/workflows/segment.py index 5a594452bc..9ed2303382 100644 --- a/dipy/workflows/segment.py +++ b/dipy/workflows/segment.py @@ -1,12 +1,13 @@ from __future__ import division, print_function, absolute_import import logging - -import numpy as np - -from dipy.segment.mask import median_otsu from dipy.workflows.workflow import Workflow from dipy.io.image import save_nifti, load_nifti +import numpy as np +from time import time +from dipy.segment.mask import median_otsu +from dipy.io.streamline import load_trk, save_trk +from dipy.segment.bundles import RecoBundles class MedianOtsuFlow(Workflow): @@ -28,7 +29,7 @@ def run(self, input_files, save_masked=False, median_radius=2, numpass=5, input_files : string Path to the input volumes. This path may contain wildcards to process multiple inputs at once.
- save_masked : bool + save_masked : bool, optional Save mask median_radius : int, optional Radius (in voxels) of the applied median filter (default 2) @@ -79,3 +80,239 @@ def run(self, input_files, save_masked=False, median_radius=2, numpass=5, format(masked_out_path)) return io_it + + +class RecoBundlesFlow(Workflow): + @classmethod + def get_short_name(cls): + return 'recobundles' + + def run(self, streamline_files, model_bundle_files, + greater_than=50, less_than=1000000, + no_slr=False, clust_thr=15., + reduction_thr=15., + reduction_distance='mdf', + model_clust_thr=2.5, + pruning_thr=8., + pruning_distance='mdf', + slr_metric='symmetric', + slr_transform='similarity', + slr_matrix='small', + refine=False, r_reduction_thr=12., + r_pruning_thr=6., no_r_slr=False, + out_dir='', + out_recognized_transf='recognized.trk', + out_recognized_labels='labels.npy'): + """ Recognize bundles + + Parameters + ---------- + streamline_files : string + The path of streamline files where you want to recognize bundles + model_bundle_files : string + The path of model bundle files + greater_than : int, optional + Keep streamlines that have length greater than + this value (default 50) in mm. + less_than : int, optional + Keep streamlines that have length less than this value + (default 1000000) in mm. + no_slr : bool, optional + Don't enable local Streamline-based Linear + Registration (default False). + clust_thr : float, optional + MDF distance threshold for all streamlines (default 15) + reduction_thr : float, optional + Reduce search space by (mm) (default 15) + reduction_distance : string, optional + Reduction distance type can be mdf or mam (default mdf) + model_clust_thr : float, optional + MDF distance threshold for the model bundles (default 2.5) + pruning_thr : float, optional + Pruning after matching (default 8). + pruning_distance : string, optional + Pruning distance type can be mdf or mam (default mdf) + slr_metric : string, optional + Options are None, symmetric, asymmetric or diagonal + (default symmetric). + slr_transform : string, optional + Transformation allowed. translation, rigid, similarity or scaling + (Default 'similarity'). + slr_matrix : string, optional + Options are 'nano', 'tiny', 'small', 'medium', 'large', 'huge' + (default 'small') + refine : bool, optional + Enable refinement of the recognized bundle (default False) + r_reduction_thr : float, optional + Refine reduce search space by (mm) (default 12) + r_pruning_thr : float, optional + Refine pruning after matching (default 6). + no_r_slr : bool, optional + Don't enable Refine local Streamline-based Linear + Registration (default False). + out_dir : string, optional + Output directory (default input file directory) + out_recognized_transf : string, optional + Recognized bundle in the space of the model bundle + (default 'recognized.trk') + out_recognized_labels : string, optional + Indices of recognized bundle in the original tractogram + (default 'labels.npy') + + References + ---------- + .. [Garyfallidis17] Garyfallidis et al. Recognition of white matter + bundles using local and global streamline-based registration and + clustering, Neuroimage, 2017.
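+
+        Examples
+        --------
+        A minimal sketch of driving this flow from Python; the ``.trk``
+        file names below are placeholders, not files shipped with DIPY::
+
+            flow = RecoBundlesFlow()
+            flow.run('whole_brain.trk', 'model_af_l.trk',
+                     refine=True, out_dir='rb_output')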
+ """ + slr = not no_slr + r_slr = not no_r_slr + + bounds = [(-30, 30), (-30, 30), (-30, 30), + (-45, 45), (-45, 45), (-45, 45), + (0.8, 1.2), (0.8, 1.2), (0.8, 1.2)] + + slr_matrix = slr_matrix.lower() + if slr_matrix == 'nano': + slr_select = (100, 100) + if slr_matrix == 'tiny': + slr_select = (250, 250) + if slr_matrix == 'small': + slr_select = (400, 400) + if slr_matrix == 'medium': + slr_select = (600, 600) + if slr_matrix == 'large': + slr_select = (800, 800) + if slr_matrix == 'huge': + slr_select = (1200, 1200) + + slr_transform = slr_transform.lower() + if slr_transform == 'translation': + bounds = bounds[:3] + if slr_transform == 'rigid': + bounds = bounds[:6] + if slr_transform == 'similarity': + bounds = bounds[:7] + if slr_transform == 'scaling': + bounds = bounds[:9] + + logging.info('### RecoBundles ###') + + io_it = self.get_io_iterator() + + t = time() + logging.info(streamline_files) + streamlines, header = load_trk(streamline_files) + + logging.info(' Loading time %0.3f sec' % (time() - t,)) + + rb = RecoBundles(streamlines, greater_than=greater_than, + less_than=less_than) + + for _, mb, out_rec, out_labels in io_it: + t = time() + logging.info(mb) + model_bundle, _ = load_trk(mb) + logging.info(' Loading time %0.3f sec' % (time() - t,)) + logging.info("model file = ") + logging.info(mb) + + recognized_bundle, labels = \ + rb.recognize( + model_bundle, + model_clust_thr=model_clust_thr, + reduction_thr=reduction_thr, + reduction_distance=reduction_distance, + pruning_thr=pruning_thr, + pruning_distance=pruning_distance, + slr=slr, + slr_metric=slr_metric, + slr_x0=slr_transform, + slr_bounds=bounds, + slr_select=slr_select, + slr_method='L-BFGS-B') + + if refine: + + if len(recognized_bundle) > 1: + + # affine + x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0]) + affine_bounds = [(-30, 30), (-30, 30), (-30, 30), + (-45, 45), (-45, 45), (-45, 45), + (0.8, 1.2), (0.8, 1.2), (0.8, 1.2), + (-10, 10), (-10, 10), (-10, 10)] + + recognized_bundle, labels = \ + rb.refine( + model_bundle, + recognized_bundle, + model_clust_thr=model_clust_thr, + reduction_thr=r_reduction_thr, + reduction_distance=reduction_distance, + pruning_thr=r_pruning_thr, + pruning_distance=pruning_distance, + slr=r_slr, + slr_metric=slr_metric, + slr_x0=x0, + slr_bounds=affine_bounds, + slr_select=slr_select, + slr_method='L-BFGS-B') + + if len(labels) > 0: + ba, bmd = rb.evaluate_results( + model_bundle, recognized_bundle, + slr_select) + + logging.info("Bundle adjacency Metric {0}".format(ba)) + logging.info("Bundle Min Distance Metric {0}".format(bmd)) + + save_trk(out_rec, recognized_bundle, np.eye(4)) + + logging.info('Saving output files ...') + np.save(out_labels, np.array(labels)) + logging.info(out_rec) + logging.info(out_labels) + + +class LabelsBundlesFlow(Workflow): + @classmethod + def get_short_name(cls): + return 'labelsbundles' + + def run(self, streamline_files, labels_files, + out_dir='', + out_bundle='recognized_orig.trk'): + """ Extract bundles using existing indices (labels) + + Parameters + ---------- + streamline_files : string + The path of streamline files where you want to recognize bundles + labels_files : string + The path of model bundle files + out_dir : string, optional + Output directory (default input file directory) + out_bundle : string, optional + Recognized bundle in the space of the model bundle + (default 'recognized_orig.trk') + + References + ---------- + .. [Garyfallidis17] Garyfallidis et al. 
Recognition of white matter + bundles using local and global streamline-based registration and + clustering, Neuroimage, 2017. + + """ + logging.info('### Labels to Bundles ###') + + io_it = self.get_io_iterator() + for sf, lb, out_bundle in io_it: + + logging.info(sf) + streamlines, header = load_trk(sf) + logging.info(lb) + location = np.load(lb) + logging.info('Saving output files ...') + save_trk(out_bundle, streamlines[location], np.eye(4)) + logging.info(out_bundle) diff --git a/dipy/workflows/stats.py b/dipy/workflows/stats.py new file mode 100755 index 0000000000..2dfe3e5828 --- /dev/null +++ b/dipy/workflows/stats.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python + +import logging +import numpy as np +import os +import json +from scipy.ndimage.morphology import binary_dilation + +from dipy.io import read_bvals_bvecs +from dipy.io.image import load_nifti, save_nifti +from dipy.core.gradients import gradient_table +from dipy.segment.mask import median_otsu +from dipy.reconst.dti import TensorModel + +from dipy.segment.mask import segment_from_cfa +from dipy.segment.mask import bounding_box + +from dipy.workflows.workflow import Workflow + + +class SNRinCCFlow(Workflow): + + @classmethod + def get_short_name(cls): + return 'snrincc' + + def run(self, data_files, bvals_files, bvecs_files, mask_files, + bbox_threshold=[0.6, 1, 0, 0.1, 0, 0.1], out_dir='', + out_file='product.json', out_mask_cc='cc.nii.gz', + out_mask_noise='mask_noise.nii.gz'): + """Compute the signal-to-noise ratio in the corpus callosum. + + Parameters + ---------- + data_files : string + Path to the dwi.nii.gz file. This path may contain wildcards to + process multiple inputs at once. + bvals_files : string + Path of bvals. + bvecs_files : string + Path of bvecs. + mask_files : string + Path of brain mask + bbox_threshold : variable float, optional + Threshold for bounding box, values separated with commas for ex. + [0.6,1,0,0.1,0,0.1]. (default (0.6, 1, 0, 0.1, 0, 0.1)) + out_dir : string, optional + Where the resulting file will be saved. (default '') + out_file : string, optional + Name of the result file to be saved. 
(default 'product.json') + out_mask_cc : string, optional + Name of the CC mask volume to be saved (default 'cc.nii.gz') + out_mask_noise : string, optional + Name of the mask noise volume to be saved + (default 'mask_noise.nii.gz') + + """ + io_it = self.get_io_iterator() + + for dwi_path, bvals_path, bvecs_path, mask_path, out_path, \ + cc_mask_path, mask_noise_path in io_it: + data, affine = load_nifti(dwi_path) + bvals, bvecs = read_bvals_bvecs(bvals_path, bvecs_path) + gtab = gradient_table(bvals=bvals, bvecs=bvecs) + + logging.info('Computing brain mask...') + _, calc_mask = median_otsu(data) + + mask, affine = load_nifti(mask_path) + mask = np.array(calc_mask == mask.astype(bool)).astype(int) + + logging.info('Computing tensors...') + tenmodel = TensorModel(gtab) + tensorfit = tenmodel.fit(data, mask=mask) + + logging.info( + 'Computing worst-case/best-case SNR using the CC...') + + if np.ndim(data) == 4: + CC_box = np.zeros_like(data[..., 0]) + elif np.ndim(data) == 3: + CC_box = np.zeros_like(data) + else: + raise IOError('DWI data has invalid dimensions') + + mins, maxs = bounding_box(mask) + mins = np.array(mins) + maxs = np.array(maxs) + diff = (maxs - mins) // 4 + bounds_min = mins + diff + bounds_max = maxs - diff + + CC_box[bounds_min[0]:bounds_max[0], + bounds_min[1]:bounds_max[1], + bounds_min[2]:bounds_max[2]] = 1 + + if len(bbox_threshold) != 6: + raise IOError('bbox_threshold should have 6 float values') + + mask_cc_part, cfa = segment_from_cfa(tensorfit, CC_box, + bbox_threshold, + return_cfa=True) + + save_nifti(cc_mask_path, mask_cc_part.astype(np.uint8), affine) + logging.info('CC mask saved as {0}'.format(cc_mask_path)) + + mean_signal = np.mean(data[mask_cc_part], axis=0) + mask_noise = binary_dilation(mask, iterations=10) + mask_noise[..., :mask_noise.shape[-1]//2] = 1 + mask_noise = ~mask_noise + + save_nifti(mask_noise_path, mask_noise.astype(np.uint8), affine) + logging.info('Mask noise saved as {0}'.format(mask_noise_path)) + + noise_std = np.std(data[mask_noise, :]) + logging.info('Noise standard deviation sigma= ' + str(noise_std)) + + idx = np.sum(gtab.bvecs, axis=-1) == 0 + gtab.bvecs[idx] = np.inf + axis_X = np.argmin( + np.sum((gtab.bvecs-np.array([1, 0, 0])) ** 2, axis=-1)) + axis_Y = np.argmin( + np.sum((gtab.bvecs-np.array([0, 1, 0])) ** 2, axis=-1)) + axis_Z = np.argmin( + np.sum((gtab.bvecs-np.array([0, 0, 1])) ** 2, axis=-1)) + + SNR_output = [] + SNR_directions = [] + for direction in ['b0', axis_X, axis_Y, axis_Z]: + if direction == 'b0': + SNR = mean_signal[0]/noise_std + logging.info("SNR for the b=0 image is :" + str(SNR)) + else: + SNR_directions.append(direction) + SNR = mean_signal[direction]/noise_std + logging.info("SNR for direction " + str(direction) + + " " + str(gtab.bvecs[direction]) + " is :" + + str(SNR)) + SNR_output.append(SNR) + + data = [] + data.append({ + 'data': str(SNR_output[0]) + ' ' + str(SNR_output[1]) + + ' ' + str(SNR_output[2]) + ' ' + str(SNR_output[3]), + 'directions': 'b0' + ' ' + str(SNR_directions[0]) + + ' ' + str(SNR_directions[1]) + ' ' + + str(SNR_directions[2]) + }) + + with open(os.path.join(out_dir, out_path), 'w') as myfile: + json.dump(data, myfile) diff --git a/dipy/workflows/tests/test_align.py b/dipy/workflows/tests/test_align.py index 1e67d2a2b7..6bff424f7c 100644 --- a/dipy/workflows/tests/test_align.py +++ b/dipy/workflows/tests/test_align.py @@ -3,14 +3,17 @@ import nibabel as nib from nibabel.tmpdirs import TemporaryDirectory - -from dipy.data import get_data -from dipy.workflows.align import
ResliceFlow +from dipy.tracking.streamline import Streamlines +from dipy.data import get_fnames +from dipy.workflows.align import ResliceFlow, SlrWithQbxFlow +from os.path import join as pjoin +from dipy.io.streamline import save_trk +import os.path def test_reslice(): with TemporaryDirectory() as out_dir: - data_path, _, _ = get_data('small_25') + data_path, _, _ = get_fnames('small_25') vol_img = nib.load(data_path) volume = vol_img.get_data() @@ -20,12 +23,39 @@ def test_reslice(): out_path = reslice_flow.last_generated_outputs['out_resliced'] out_img = nib.load(out_path) resliced = out_img.get_data() - + npt.assert_equal(resliced.shape[0] > volume.shape[0], True) npt.assert_equal(resliced.shape[1] > volume.shape[1], True) npt.assert_equal(resliced.shape[2] > volume.shape[2], True) npt.assert_equal(resliced.shape[-1], volume.shape[-1]) - + + +def test_slr_flow(): + with TemporaryDirectory() as out_dir: + data_path = get_fnames('fornix') + + streams, hdr = nib.trackvis.read(data_path) + fornix = [s[0] for s in streams] + + f = Streamlines(fornix) + f1 = f.copy() + + f1_path = pjoin(out_dir, "f1.trk") + save_trk(f1_path, Streamlines(f1), affine=np.eye(4)) + + f2 = f1.copy() + f2._data += np.array([50, 0, 0]) + + f2_path = pjoin(out_dir, "f2.trk") + save_trk(f2_path, Streamlines(f2), affine=np.eye(4)) + + slr_flow = SlrWithQbxFlow(force=True) + slr_flow.run(f1_path, f2_path) + + out_path = slr_flow.last_generated_outputs['out_moved'] + + npt.assert_equal(os.path.isfile(out_path), True) + if __name__ == '__main__': npt.run_module_suite() diff --git a/dipy/workflows/tests/test_iap.py b/dipy/workflows/tests/test_iap.py index 5211a95095..81834d8603 100644 --- a/dipy/workflows/tests/test_iap.py +++ b/dipy/workflows/tests/test_iap.py @@ -1,10 +1,35 @@ import numpy.testing as npt import sys +from os.path import join as pjoin +from nibabel.tmpdirs import TemporaryDirectory from dipy.workflows.base import IntrospectiveArgumentParser from dipy.workflows.flow_runner import run_flow from dipy.workflows.tests.workflow_tests_utils import TestFlow, \ - DummyCombinedWorkflow, DummyWorkflow1 + DummyCombinedWorkflow, DummyWorkflow1, TestVariableTypeWorkflow, \ + TestVariableTypeErrorWorkflow + + +def test_variable_type(): + with TemporaryDirectory() as out_dir: + open(pjoin(out_dir, 'test'), 'w').close() + open(pjoin(out_dir, 'test1'), 'w').close() + open(pjoin(out_dir, 'test2'), 'w').close() + + sys.argv = [sys.argv[0]] + pos_results = [pjoin(out_dir, 'test'), pjoin(out_dir, 'test1'), + pjoin(out_dir, 'test2'), 12] + inputs = inputs_from_results(pos_results) + sys.argv.extend(inputs) + dcwf = TestVariableTypeWorkflow() + _, positional_res, positional_res2 = run_flow(dcwf) + npt.assert_equal(positional_res2, 12) + + for k, v in zip(positional_res, pos_results[:-1]): + npt.assert_equal(k, v) + + dcwf = TestVariableTypeErrorWorkflow() + npt.assert_raises(ValueError, run_flow, dcwf) def test_iap(): @@ -49,8 +74,8 @@ def test_flow_runner(): old_argv = sys.argv sys.argv = [sys.argv[0]] - opt_keys = ['param_combined', 'dwf1.param1', 'dwf2.param2', 'force', 'out_strat', - 'mix_names'] + opt_keys = ['param_combined', 'dwf1.param1', 'dwf2.param2', 'force', + 'out_strat', 'mix_names'] pos_results = ['dipy.txt'] opt_results = [30, 10, 20, True, 'absolute', True] @@ -90,6 +115,8 @@ def inputs_from_results(results, keys=None, optional=False): return inputs + + if __name__ == '__main__': test_iap() test_flow_runner() + test_variable_type() diff --git
a/dipy/workflows/tests/test_io.py b/dipy/workflows/tests/test_io.py index f59da43964..671a559ef1 100644 --- a/dipy/workflows/tests/test_io.py +++ b/dipy/workflows/tests/test_io.py @@ -1,4 +1,4 @@ -from dipy.data import get_data +from dipy.data import get_fnames from dipy.workflows.io import IoInfoFlow import logging @@ -13,14 +13,14 @@ def test_io_info(): - fimg, fbvals, fbvecs=get_data('small_101D') + fimg, fbvals, fbvecs = get_fnames('small_101D') io_info_flow = IoInfoFlow() io_info_flow.run([fimg, fbvals, fbvecs]) - - fimg, fbvals, fvecs = get_data('small_25') + + fimg, fbvals, fvecs = get_fnames('small_25') io_info_flow = IoInfoFlow() io_info_flow.run([fimg, fbvals, fvecs]) - + io_info_flow = IoInfoFlow() io_info_flow.run([fimg, fbvals, fvecs], b0_threshold=20, bvecs_tol=0.001) @@ -30,9 +30,10 @@ def test_io_info(): np.testing.assert_equal( lines[-3], 'INFO Total number of unit bvectors 25\n') - except IndexError: # logging maybe disabled in IDE setting + except IndexError: # logging maybe disabled in IDE setting pass file.close() - + + if __name__ == '__main__': - test_io_info() \ No newline at end of file + test_io_info() diff --git a/dipy/workflows/tests/test_masking.py b/dipy/workflows/tests/test_masking.py index e41f5077f9..a9b64df57e 100644 --- a/dipy/workflows/tests/test_masking.py +++ b/dipy/workflows/tests/test_masking.py @@ -5,13 +5,13 @@ import nibabel as nib from nibabel.tmpdirs import TemporaryDirectory -from dipy.data import get_data +from dipy.data import get_fnames from dipy.workflows.mask import MaskFlow def test_mask(): with TemporaryDirectory() as out_dir: - data_path, _, _ = get_data('small_25') + data_path, _, _ = get_fnames('small_25') vol_img = nib.load(data_path) volume = vol_img.get_data() diff --git a/dipy/workflows/tests/test_reconst_csa_csd.py b/dipy/workflows/tests/test_reconst_csa_csd.py index 98a341f82b..d8f674cfd2 100644 --- a/dipy/workflows/tests/test_reconst_csa_csd.py +++ b/dipy/workflows/tests/test_reconst_csa_csd.py @@ -1,15 +1,19 @@ + import logging import numpy as np from nose.tools import assert_equal -from os.path import join +from os.path import join as pjoin import numpy.testing as npt import nibabel as nib from dipy.io.peaks import load_peaks +from dipy.io.gradients import read_bvals_bvecs +from dipy.core.gradients import generate_bvecs from nibabel.tmpdirs import TemporaryDirectory -from dipy.data import get_data +from dipy.data import get_fnames from dipy.workflows.reconst import ReconstCSDFlow, ReconstCSAFlow +from dipy.reconst.shm import sph_harm_ind_list logging.getLogger().setLevel(logging.INFO) @@ -23,72 +27,100 @@ def test_reconst_csd(): def reconst_flow_core(flow): with TemporaryDirectory() as out_dir: - data_path, bval_path, bvec_path = get_data('small_64D') + data_path, bval_path, bvec_path = get_fnames('small_64D') vol_img = nib.load(data_path) volume = vol_img.get_data() mask = np.ones_like(volume[:, :, :, 0]) mask_img = nib.Nifti1Image(mask.astype(np.uint8), vol_img.affine) - mask_path = join(out_dir, 'tmp_mask.nii.gz') + mask_path = pjoin(out_dir, 'tmp_mask.nii.gz') nib.save(mask_img, mask_path) reconst_flow = flow() - - reconst_flow.run(data_path, bval_path, bvec_path, mask_path, - out_dir=out_dir, extract_pam_values=True) - - gfa_path = reconst_flow.last_generated_outputs['out_gfa'] - gfa_data = nib.load(gfa_path).get_data() - assert_equal(gfa_data.shape, volume.shape[:-1]) - - peaks_dir_path = reconst_flow.last_generated_outputs['out_peaks_dir'] - peaks_dir_data = nib.load(peaks_dir_path).get_data() - 
assert_equal(peaks_dir_data.shape[-1], 15) - assert_equal(peaks_dir_data.shape[:-1], volume.shape[:-1]) - - peaks_idx_path = \ - reconst_flow.last_generated_outputs['out_peaks_indices'] - peaks_idx_data = nib.load(peaks_idx_path).get_data() - assert_equal(peaks_idx_data.shape[-1], 5) - assert_equal(peaks_idx_data.shape[:-1], volume.shape[:-1]) - - peaks_vals_path = \ - reconst_flow.last_generated_outputs['out_peaks_values'] - peaks_vals_data = nib.load(peaks_vals_path).get_data() - assert_equal(peaks_vals_data.shape[-1], 5) - assert_equal(peaks_vals_data.shape[:-1], volume.shape[:-1]) - - shm_path = reconst_flow.last_generated_outputs['out_shm'] - shm_data = nib.load(shm_path).get_data() - assert_equal(shm_data.shape[-1], 45) - assert_equal(shm_data.shape[:-1], volume.shape[:-1]) - - pam = load_peaks(reconst_flow.last_generated_outputs['out_pam']) - npt.assert_allclose(pam.peak_dirs.reshape(peaks_dir_data.shape), - peaks_dir_data) - npt.assert_allclose(pam.peak_values, peaks_vals_data) - npt.assert_allclose(pam.peak_indices, peaks_idx_data) - npt.assert_allclose(pam.shm_coeff, shm_data) - npt.assert_allclose(pam.gfa, gfa_data) - - if flow.get_short_name() == 'csd': - - reconst_flow = flow() - reconst_flow._force_overwrite = True - reconst_flow.run(data_path, bval_path, bvec_path, mask_path, - out_dir=out_dir, frf=[15, 5, 5]) - reconst_flow = flow() - reconst_flow._force_overwrite = True - reconst_flow.run(data_path, bval_path, bvec_path, mask_path, - out_dir=out_dir, frf='15, 5, 5') - reconst_flow = flow() + for sh_order in [4, 6, 8]: + if flow.get_short_name() == 'csd': + + reconst_flow.run(data_path, bval_path, bvec_path, mask_path, + sh_order=sh_order, + out_dir=out_dir, extract_pam_values=True) + + elif flow.get_short_name() == 'csa': + + reconst_flow.run(data_path, bval_path, bvec_path, mask_path, + sh_order=sh_order, + odf_to_sh_order=sh_order, + out_dir=out_dir, extract_pam_values=True) + + gfa_path = reconst_flow.last_generated_outputs['out_gfa'] + gfa_data = nib.load(gfa_path).get_data() + assert_equal(gfa_data.shape, volume.shape[:-1]) + + peaks_dir_path =\ + reconst_flow.last_generated_outputs['out_peaks_dir'] + peaks_dir_data = nib.load(peaks_dir_path).get_data() + assert_equal(peaks_dir_data.shape[-1], 15) + assert_equal(peaks_dir_data.shape[:-1], volume.shape[:-1]) + + peaks_idx_path = \ + reconst_flow.last_generated_outputs['out_peaks_indices'] + peaks_idx_data = nib.load(peaks_idx_path).get_data() + assert_equal(peaks_idx_data.shape[-1], 5) + assert_equal(peaks_idx_data.shape[:-1], volume.shape[:-1]) + + peaks_vals_path = \ + reconst_flow.last_generated_outputs['out_peaks_values'] + peaks_vals_data = nib.load(peaks_vals_path).get_data() + assert_equal(peaks_vals_data.shape[-1], 5) + assert_equal(peaks_vals_data.shape[:-1], volume.shape[:-1]) + + shm_path = reconst_flow.last_generated_outputs['out_shm'] + shm_data = nib.load(shm_path).get_data() + # Test that the number of coefficients is what you would expect + # given the order of the sh basis: + assert_equal(shm_data.shape[-1], + sph_harm_ind_list(sh_order)[0].shape[0]) + assert_equal(shm_data.shape[:-1], volume.shape[:-1]) + + pam = load_peaks(reconst_flow.last_generated_outputs['out_pam']) + npt.assert_allclose(pam.peak_dirs.reshape(peaks_dir_data.shape), + peaks_dir_data) + npt.assert_allclose(pam.peak_values, peaks_vals_data) + npt.assert_allclose(pam.peak_indices, peaks_idx_data) + npt.assert_allclose(pam.shm_coeff, shm_data) + npt.assert_allclose(pam.gfa, gfa_data) + + bvals, bvecs = read_bvals_bvecs(bval_path, 
bvec_path) + bvals[0] = 5. + bvecs = generate_bvecs(len(bvals)) + + tmp_bval_path = pjoin(out_dir, "tmp.bval") + tmp_bvec_path = pjoin(out_dir, "tmp.bvec") + np.savetxt(tmp_bval_path, bvals) + np.savetxt(tmp_bvec_path, bvecs.T) reconst_flow._force_overwrite = True - reconst_flow.run(data_path, bval_path, bvec_path, mask_path, - out_dir=out_dir, frf=None) - reconst_flow2 = flow() - reconst_flow2._force_overwrite = True - reconst_flow2.run(data_path, bval_path, bvec_path, mask_path, - out_dir=out_dir, frf=None, - roi_center=[10, 10, 10]) + with npt.assert_raises(BaseException): + npt.assert_warns(UserWarning, reconst_flow.run, data_path, + tmp_bval_path, tmp_bvec_path, mask_path, + out_dir=out_dir, extract_pam_values=True) + + if flow.get_short_name() == 'csd': + + reconst_flow = flow() + reconst_flow._force_overwrite = True + reconst_flow.run(data_path, bval_path, bvec_path, mask_path, + out_dir=out_dir, frf=[15, 5, 5]) + reconst_flow = flow() + reconst_flow._force_overwrite = True + reconst_flow.run(data_path, bval_path, bvec_path, mask_path, + out_dir=out_dir, frf='15, 5, 5') + reconst_flow = flow() + reconst_flow._force_overwrite = True + reconst_flow.run(data_path, bval_path, bvec_path, mask_path, + out_dir=out_dir, frf=None) + reconst_flow2 = flow() + reconst_flow2._force_overwrite = True + reconst_flow2.run(data_path, bval_path, bvec_path, mask_path, + out_dir=out_dir, frf=None, + roi_center=[10, 10, 10]) if __name__ == '__main__': diff --git a/dipy/workflows/tests/test_reconst_dki.py b/dipy/workflows/tests/test_reconst_dki.py index 9f44ebe60c..21997cfe45 100644 --- a/dipy/workflows/tests/test_reconst_dki.py +++ b/dipy/workflows/tests/test_reconst_dki.py @@ -1,4 +1,4 @@ -from os.path import join +from os.path import join as pjoin import nibabel as nib from nibabel.tmpdirs import TemporaryDirectory @@ -6,19 +6,22 @@ import numpy as np from nose.tools import assert_true, assert_equal +import numpy.testing as npt -from dipy.data import get_data +from dipy.data import get_fnames +from dipy.io.gradients import read_bvals_bvecs +from dipy.core.gradients import generate_bvecs from dipy.workflows.reconst import ReconstDkiFlow def test_reconst_dki(): with TemporaryDirectory() as out_dir: - data_path, bval_path, bvec_path = get_data('small_101D') + data_path, bval_path, bvec_path = get_fnames('small_101D') vol_img = nib.load(data_path) volume = vol_img.get_data() mask = np.ones_like(volume[:, :, :, 0]) mask_img = nib.Nifti1Image(mask.astype(np.uint8), vol_img.affine) - mask_path = join(out_dir, 'tmp_mask.nii.gz') + mask_path = pjoin(out_dir, 'tmp_mask.nii.gz') nib.save(mask_img, mask_path) dki_flow = ReconstDkiFlow() @@ -88,6 +91,19 @@ def test_reconst_dki(): assert_equal(evals_data.shape[-1], 3) assert_equal(evals_data.shape[:-1], volume.shape[:-1]) + bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path) + bvals[0] = 5. 
+ bvecs = generate_bvecs(len(bvals)) + + tmp_bval_path = pjoin(out_dir, "tmp.bval") + tmp_bvec_path = pjoin(out_dir, "tmp.bvec") + np.savetxt(tmp_bval_path, bvals) + np.savetxt(tmp_bvec_path, bvecs.T) + dki_flow._force_overwrite = True + npt.assert_warns(UserWarning, dki_flow.run, data_path, + tmp_bval_path, tmp_bvec_path, mask_path, + out_dir=out_dir, b0_threshold=0) + if __name__ == '__main__': test_reconst_dki() diff --git a/dipy/workflows/tests/test_reconst_dti.py b/dipy/workflows/tests/test_reconst_dti.py index 8d69572331..c2cdb95a6c 100644 --- a/dipy/workflows/tests/test_reconst_dti.py +++ b/dipy/workflows/tests/test_reconst_dti.py @@ -7,7 +7,7 @@ from nose.tools import assert_equal -from dipy.data import get_data +from dipy.data import get_fnames from dipy.workflows.reconst import ReconstDtiFlow @@ -16,7 +16,7 @@ def test_reconst_dti_wls(): def reconst_mmri_core(flow, extra_args=[]): with TemporaryDirectory() as out_dir: - data_path, bval_path, bvec_path = get_data('small_25') + data_path, bval_path, bvec_path = get_fnames('small_25') vol_img = nib.load(data_path) vol_img.get_data() # mask = np.ones_like(volume[:, :, :, 0]) @@ -36,7 +36,7 @@ def test_reconst_dti_nlls(): def reconst_flow_core(flow, extra_args=[]): with TemporaryDirectory() as out_dir: - data_path, bval_path, bvec_path = get_data('small_25') + data_path, bval_path, bvec_path = get_fnames('small_25') vol_img = nib.load(data_path) volume = vol_img.get_data() mask = np.ones_like(volume[:, :, :, 0]) diff --git a/dipy/workflows/tests/test_reconst_mapmri.py b/dipy/workflows/tests/test_reconst_mapmri.py index 078702ad05..65d0e06b2f 100644 --- a/dipy/workflows/tests/test_reconst_mapmri.py +++ b/dipy/workflows/tests/test_reconst_mapmri.py @@ -1,4 +1,4 @@ -from os.path import join +from os.path import join as pjoin import nibabel as nib from nibabel.tmpdirs import TemporaryDirectory @@ -6,9 +6,12 @@ import numpy as np from nose.tools import eq_ +import numpy.testing as npt from dipy.reconst import mapmri -from dipy.data import get_data +from dipy.data import get_fnames +from dipy.io.gradients import read_bvals_bvecs +from dipy.core.gradients import generate_bvecs from dipy.workflows.reconst import ReconstMAPMRIFlow @@ -32,13 +35,13 @@ def test_reconst_mmri_positivity(): def reconst_mmri_core(flow, lap, pos): with TemporaryDirectory() as out_dir: - data_path, bval_path, bvec_path = get_data('small_25') + data_path, bval_path, bvec_path = get_fnames('small_25') vol_img = nib.load(data_path) volume = vol_img.get_data() mmri_flow = flow() - mmri_flow.run(data_file=data_path, data_bvals=bval_path, - data_bvecs=bvec_path, small_delta=0.0129, + mmri_flow.run(data_files=data_path, bvals_files=bval_path, + bvecs_files=bvec_path, small_delta=0.0129, big_delta=0.0218, laplacian=lap, positivity=pos, out_dir=out_dir) @@ -78,6 +81,20 @@ def reconst_mmri_core(flow, lap, pos): perng_data = nib.load(perng).get_data() eq_(perng_data.shape, volume.shape[:-1]) + bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path) + bvals[0] = 5. 
+ bvecs = generate_bvecs(len(bvals)) + tmp_bval_path = pjoin(out_dir, "tmp.bval") + tmp_bvec_path = pjoin(out_dir, "tmp.bvec") + np.savetxt(tmp_bval_path, bvals) + np.savetxt(tmp_bvec_path, bvecs.T) + mmri_flow._force_overwrite = True + with npt.assert_raises(BaseException): + npt.assert_warns(UserWarning, mmri_flow.run, data_path, + tmp_bval_path, tmp_bvec_path, small_delta=0.0129, + big_delta=0.0218, laplacian=lap, + positivity=pos, out_dir=out_dir) + if __name__ == '__main__': test_reconst_mmri_laplacian() diff --git a/dipy/workflows/tests/test_segment.py b/dipy/workflows/tests/test_segment.py index a99a0a4ba8..0b7793a7dd 100644 --- a/dipy/workflows/tests/test_segment.py +++ b/dipy/workflows/tests/test_segment.py @@ -1,17 +1,23 @@ import numpy.testing as npt from os.path import join - import nibabel as nib +import numpy as np from nibabel.tmpdirs import TemporaryDirectory - -from dipy.data import get_data +from dipy.data import get_fnames from dipy.segment.mask import median_otsu +from dipy.tracking.streamline import Streamlines from dipy.workflows.segment import MedianOtsuFlow +from dipy.workflows.segment import RecoBundlesFlow, LabelsBundlesFlow +from dipy.io.streamline import load_trk, save_trk +from os.path import join as pjoin +from dipy.tracking.streamline import (set_number_of_points, + select_random_set_of_streamlines) +from dipy.align.streamlinear import BundleMinDistanceMetric def test_median_otsu_flow(): with TemporaryDirectory() as out_dir: - data_path, _, _ = get_data('small_25') + data_path, _, _ = get_fnames('small_25') volume = nib.load(data_path).get_data() save_masked = True median_radius = 3 @@ -22,8 +28,8 @@ def test_median_otsu_flow(): mo_flow = MedianOtsuFlow() mo_flow.run(data_path, out_dir=out_dir, save_masked=save_masked, - median_radius=median_radius, numpass=numpass, - autocrop=autocrop, vol_idx=vol_idx, dilate=dilate) + median_radius=median_radius, numpass=numpass, + autocrop=autocrop, vol_idx=vol_idx, dilate=dilate) mask_name = mo_flow.last_generated_outputs['out_mask'] masked_name = mo_flow.last_generated_outputs['out_masked'] @@ -38,5 +44,54 @@ def test_median_otsu_flow(): result_masked_data = nib.load(join(out_dir, masked_name)).get_data() npt.assert_array_equal(result_masked_data, masked) + +def test_recobundles_flow(): + with TemporaryDirectory() as out_dir: + data_path = get_fnames('fornix') + streams, hdr = nib.trackvis.read(data_path) + fornix = [s[0] for s in streams] + + f = Streamlines(fornix) + f1 = f.copy() + + f2 = f1[:15].copy() + f2._data += np.array([40, 0, 0]) + + f.extend(f2) + + f2_path = pjoin(out_dir, "f2.trk") + save_trk(f2_path, f2, affine=np.eye(4)) + + f1_path = pjoin(out_dir, "f1.trk") + save_trk(f1_path, f, affine=np.eye(4)) + + rb_flow = RecoBundlesFlow(force=True) + rb_flow.run(f1_path, f2_path, greater_than=0, clust_thr=10, + model_clust_thr=5., reduction_thr=10, out_dir=out_dir) + + labels = rb_flow.last_generated_outputs['out_recognized_labels'] + recog_trk = rb_flow.last_generated_outputs['out_recognized_transf'] + + rec_bundle, _ = load_trk(recog_trk) + npt.assert_equal(len(rec_bundle) == len(f2), True) + + label_flow = LabelsBundlesFlow(force=True) + label_flow.run(f1_path, labels) + + recog_bundle = label_flow.last_generated_outputs['out_bundle'] + rec_bundle_org, _ = load_trk(recog_bundle) + + BMD = BundleMinDistanceMetric() + nb_pts = 20 + static = set_number_of_points(f2, nb_pts) + moving = set_number_of_points(rec_bundle_org, nb_pts) + + BMD.setup(static, moving) + x0 = np.array([0, 0, 0, 0, 0, 0, 1., 1., 1, 0, 0, 0]) 
# affine + bmd_value = BMD.distance(x0.tolist()) + + npt.assert_equal(bmd_value < 1, True) + + if __name__ == '__main__': - test_median_otsu_flow() + npt.run_module_suite() diff --git a/dipy/workflows/tests/test_stats.py b/dipy/workflows/tests/test_stats.py new file mode 100755 index 0000000000..6f24fe1ad1 --- /dev/null +++ b/dipy/workflows/tests/test_stats.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +import os +from os.path import join + +import nibabel as nib +from nibabel.tmpdirs import TemporaryDirectory + +import numpy as np + +from nose.tools import assert_true + +from dipy.data import get_data +from dipy.workflows.stats import SNRinCCFlow + + +def test_stats(): + with TemporaryDirectory() as out_dir: + data_path, bval_path, bvec_path = get_data('small_101D') + vol_img = nib.load(data_path) + volume = vol_img.get_data() + mask = np.ones_like(volume[:, :, :, 0]) + mask_img = nib.Nifti1Image(mask.astype(np.uint8), vol_img.affine) + mask_path = join(out_dir, 'tmp_mask.nii.gz') + nib.save(mask_img, mask_path) + + snr_flow = SNRinCCFlow(force=True) + args = [data_path, bval_path, bvec_path, mask_path] + + snr_flow.run(*args, out_dir=out_dir) + assert_true(os.path.exists(os.path.join(out_dir, 'product.json'))) + assert_true(os.stat(os.path.join( + out_dir, 'product.json')).st_size != 0) + assert_true(os.path.exists(os.path.join(out_dir, 'cc.nii.gz'))) + assert_true(os.stat(os.path.join(out_dir, 'cc.nii.gz')).st_size != 0) + assert_true(os.path.exists(os.path.join(out_dir, 'mask_noise.nii.gz'))) + assert_true(os.stat(os.path.join( + out_dir, 'mask_noise.nii.gz')).st_size != 0) + + snr_flow._force_overwrite = True + snr_flow.run(*args, out_dir=out_dir) + assert_true(os.path.exists(os.path.join(out_dir, 'product.json'))) + assert_true(os.stat(os.path.join( + out_dir, 'product.json')).st_size != 0) + assert_true(os.path.exists(os.path.join(out_dir, 'cc.nii.gz'))) + assert_true(os.stat(os.path.join(out_dir, 'cc.nii.gz')).st_size != 0) + assert_true(os.path.exists(os.path.join(out_dir, 'mask_noise.nii.gz'))) + assert_true(os.stat(os.path.join( + out_dir, 'mask_noise.nii.gz')).st_size != 0) + + snr_flow._force_overwrite = True + snr_flow.run(*args, bbox_threshold=(0.5, 1, 0, + 0.15, 0, 0.2), out_dir=out_dir) + assert_true(os.path.exists(os.path.join(out_dir, 'product.json'))) + assert_true(os.stat(os.path.join( + out_dir, 'product.json')).st_size != 0) + assert_true(os.path.exists(os.path.join(out_dir, 'cc.nii.gz'))) + assert_true(os.stat(os.path.join(out_dir, 'cc.nii.gz')).st_size != 0) + assert_true(os.path.exists(os.path.join(out_dir, 'mask_noise.nii.gz'))) + assert_true(os.stat(os.path.join( + out_dir, 'mask_noise.nii.gz')).st_size != 0) + + +if __name__ == '__main__': + test_stats() diff --git a/dipy/workflows/tests/test_tracking.py b/dipy/workflows/tests/test_tracking.py index 9fcd2fcd32..bebcbd3456 100644 --- a/dipy/workflows/tests/test_tracking.py +++ b/dipy/workflows/tests/test_tracking.py @@ -5,7 +5,7 @@ import nibabel as nib from nibabel.tmpdirs import TemporaryDirectory -from dipy.data import get_data +from dipy.data import get_fnames from dipy.io.image import save_nifti from dipy.workflows.mask import MaskFlow from dipy.workflows.reconst import ReconstCSDFlow @@ -14,7 +14,7 @@ def test_det_track(): with TemporaryDirectory() as out_dir: - data_path, bval_path, bvec_path = get_data('small_64D') + data_path, bval_path, bvec_path = get_fnames('small_64D') vol_img = nib.load(data_path) volume = vol_img.get_data() mask = np.ones_like(volume[:, :, :, 0]) diff --git 
a/dipy/workflows/tests/test_workflow.py b/dipy/workflows/tests/test_workflow.py index 69dcc623af..a843d8e46a 100644 --- a/dipy/workflows/tests/test_workflow.py +++ b/dipy/workflows/tests/test_workflow.py @@ -2,17 +2,19 @@ import os import time +from os.path import join as pjoin from nibabel.tmpdirs import TemporaryDirectory -from dipy.data import get_data +from dipy.data import get_fnames from dipy.workflows.segment import MedianOtsuFlow from dipy.workflows.workflow import Workflow +import numpy.testing as npt def test_force_overwrite(): with TemporaryDirectory() as out_dir: - data_path, _, _ = get_data('small_25') + data_path, _, _ = get_fnames('small_25') mo_flow = MedianOtsuFlow(output_strategy='absolute') # Generate the first results @@ -46,7 +48,34 @@ def test_run(): wf = Workflow() assert_raises(Exception, wf.run, None) + +def test_missing_file(): + # The function is invoking a dummy workflow with a non-existent file. + # So, an IOError will be raised. + + class TestMissingFile(Workflow): + + def run(self, input, out_dir=''): + """Dummy Workflow used to test if input file is absent. + + Parameters + ---------- + + input : string, positional + path of the first input file. + out_dir: string, optional + folder path to save the results. + """ + io = self.get_io_iterator() + + dummyflow = TestMissingFile() + with TemporaryDirectory() as tempdir: + npt.assert_raises(IOError, dummyflow.run, + pjoin(tempdir, 'dummy_file.txt')) + + if __name__ == '__main__': test_force_overwrite() test_get_sub_runs() test_run() + test_missing_file() diff --git a/dipy/workflows/tests/workflow_tests_utils.py b/dipy/workflows/tests/workflow_tests_utils.py index 9d0350c6b9..1e4eb48a79 100644 --- a/dipy/workflows/tests/workflow_tests_utils.py +++ b/dipy/workflows/tests/workflow_tests_utils.py @@ -111,6 +111,61 @@ def run(self, positional_str, positional_bool, positional_int, out_dir : string output directory (default '') """ - return positional_str, positional_bool, positional_int,\ - positional_float, optional_str, optional_bool,\ - optional_int, optional_float, optional_float_2 + return (positional_str, positional_bool, positional_int, + positional_float, optional_str, optional_bool, + optional_int, optional_float, optional_float_2) + + +class TestVariableTypeWorkflow(Workflow): + + @classmethod + def get_short_name(cls): + return 'tvtwf' + + def run(self, positional_variable_str, positional_int, + out_dir=''): + """ Workflow used to test variable string in general. + + Parameters + ---------- + positional_variable_str : variable string + fake input string param + positional_int : int + fake positional param (default 2) + out_dir : string + fake output directory (default '') + """ + result = [] + io_it = self.get_io_iterator() + + for variable1 in io_it: + result.append(variable1) + return result, positional_variable_str, positional_int + + +class TestVariableTypeErrorWorkflow(Workflow): + + @classmethod + def get_short_name(cls): + return 'tvtwfe' + + def run(self, positional_variable_str, positional_variable_int, + out_dir=''): + """ Workflow used to test variable string error.
+ + Parameters + ---------- + positional_variable_str : variable string + fake input string param + positional_variable_int : variable int + fake positional param (default 2) + out_dir : string + fake output directory (default '') + """ + result = [] + io_it = self.get_io_iterator() + + for variable1, variable2 in io_it: + result.append((variable1, variable2)) + + return result diff --git a/dipy/workflows/workflow.py b/dipy/workflows/workflow.py index 6505058f90..6afbcffbad 100644 --- a/dipy/workflows/workflow.py +++ b/dipy/workflows/workflow.py @@ -8,9 +8,9 @@ class Workflow(object): - def __init__(self, output_strategy='append', mix_names=False, + def __init__(self, output_strategy='absolute', mix_names=False, force=False, skip=False): - """ The basic workflow object. + """Initialize the basic workflow object. This object takes care of any workflow operation that is common to all the workflows. Every new workflow should extend this class. @@ -22,13 +22,12 @@ def __init__(self, output_strategy='append', mix_names=False, self._skip = skip def get_io_iterator(self): - """ Create an iterator for IO. + """Create an iterator for IO. Use a couple of inspection tricks to build an IOIterator using the previous frame (values of local variables and other contextuals) and the run method's docstring. """ - # To manage different python versions. frame = inspect.stack()[1] if isinstance(frame, tuple): @@ -56,7 +55,7 @@ def get_io_iterator(self): return [] def manage_output_overwrite(self): - """ Check if a file will be overwritten upon processing the inputs. + """Check if a file will be overwritten upon processing the inputs. If it is bound to happen, an action is taken depending on self._force_overwrite (or --force via command line). A log message is @@ -86,8 +85,10 @@ def manage_output_overwrite(self): return True - def run(self): - """ Since this is an abstract class, raise exception if this code is + def run(self, *args, **kwargs): + """Execute the workflow. + + Since this is an abstract class, raise exception if this code is reached (not implemented in child class or literally called on this class) """ @@ -95,14 +96,12 @@ def run(self): format(self.__class__)) def get_sub_runs(self): - """No sub runs since this is a simple workflow. - """ + """Return no sub runs since this is a simple workflow.""" return [] - @classmethod def get_short_name(cls): - """A short name for the workflow used to subdivide + """Return a short name for the workflow used to subdivide. The short name is used by CombinedWorkflows and the argparser to subdivide the commandline parameters avoiding the trouble of having @@ -114,5 +113,6 @@ def get_short_name(cls): Returns class name by default but it is strongly advised to set it to something shorter and easier to write on commandline. + """ return cls.__name__ diff --git a/doc/api_changes.rst b/doc/api_changes.rst index 85e90fe135..42bffd160e 100644 --- a/doc/api_changes.rst +++ b/doc/api_changes.rst @@ -5,6 +5,29 @@ API changes Here we provide information about functions or classes that have been removed, renamed or are deprecated (not recommended) during different release cycles. +DIPY 0.15 Changes +----------------- + +**IO** + +``load_tck`` and ``save_tck`` from ``dipy.io.streamline`` have been added. They are highly recommended for managing streamlines. + +**Gradient Table** + +The default value of ``b0_threshold`` has been changed (from 0 to 50). This change can impact your algorithm.
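+For illustration, here is a minimal sketch (with made-up b-values, not part of
+this changeset) of how the new default classifies a small but non-zero b-value
+as a b0::
+
+    import numpy as np
+    from dipy.core.gradients import gradient_table
+
+    bvals = np.array([5., 1000., 1000., 1000.])  # illustrative values only
+    bvecs = np.array([[0., 0., 0.],
+                      [1., 0., 0.],
+                      [0., 1., 0.],
+                      [0., 0., 1.]])
+
+    gtab = gradient_table(bvals, bvecs)  # b0_threshold now defaults to 50
+    print(gtab.b0s_mask)                 # expected: [ True False False False]
+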
+If you want to ensure that your code runs exactly as before, please initialize your gradient table with the keyword argument ``b0_threshold`` set to 0. + +**Visualization** + +The ``dipy.viz.fvtk`` module has been removed. Use ``dipy.viz.*`` instead. This implies the following important changes: +- Use ``from dipy.viz import window, actor`` instead of ``from dipy.viz import fvtk``. +- Use ``window.Renderer()`` instead of ``fvtk.ren()``. +- All available actors are in ``dipy.viz.actor`` instead of ``dipy.fvtk.actor``. +- UI elements are available in ``dipy.viz.ui``. + +``dipy.viz`` depends on the FURY package. For more information about FURY, see https://fury.gl + + DIPY 0.14 Changes ----------------- diff --git a/doc/conf.py b/doc/conf.py index 7292a7595b..104e679ef3 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -66,7 +66,7 @@ # General information about the project. project = u'dipy' -copyright = u'2008-2016, %(AUTHOR)s <%(AUTHOR_EMAIL)s>' % rel +copyright = u'2008-2019, %(AUTHOR)s <%(AUTHOR_EMAIL)s>' % rel # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/doc/dependencies.rst b/doc/dependencies.rst index 63ccbe178c..e18fb56e2e 100644 --- a/doc/dependencies.rst +++ b/doc/dependencies.rst @@ -6,7 +6,7 @@ Dependencies Depends on a few standard libraries: python_ (the core language), numpy_ (for numerical computation), scipy_ (for more specific mathematical operations), -cython_ (for extra speed), nibabel_ (for file formats; we require version 2.1 +cython_ (for extra speed), nibabel_ (for file formats; we require version 2.3 or higher) and h5py_ (for handling large datasets). Optionally, it can use python-vtk_ (for visualisation), matplotlib_ (for scientific plotting), and ipython_ (for interaction with the code and its results). cvxpy_ is required for diff --git a/doc/devel/make_release.rst b/doc/devel/make_release.rst index 7bbac8f626..293c608be8 100644 --- a/doc/devel/make_release.rst +++ b/doc/devel/make_release.rst @@ -44,6 +44,10 @@ Release checklist outstanding issues that can be closed, and whether there are any issues that should delay the release. Label them ! +* Check that there are no failing builds on `Travis`. The ``PRE`` build is + allowed to fail and does not block a PR merge, but it should block a release! + So make sure that the ``PRE`` build is not failing. + * Review and update the release notes. Review and update the :file:`Changelog` file.
Get a partial list of contributors with something like:: diff --git a/doc/developers.rst b/doc/developers.rst index 75fd5deb34..45db4fffae 100644 --- a/doc/developers.rst +++ b/doc/developers.rst @@ -5,26 +5,27 @@ Developers The core development team consists of the following individuals: -- **Eleftherios Garyfallidis**, University of Indiana, IN, USA +- **Eleftherios Garyfallidis**, Indiana University, IN, USA - **Ariel Rokem**, University of Washington, WA, USA - **Matthew Brett**, Birmingham University, Birmingham, UK - **Bago Amirbekian**, Databricks, San Francisco, CA, USA -- **Omar Ocegueda**, Center for Research in Mathematics, Guanajuato, MX -- **Stefan Van der Walt**, University of California, Berkeley, CA, USA -- **Marc-Alexandre Côté**, Maluuba, Sherbrooke, QC, CA -- **Ian Nimmo-Smith**, retired, formerly at MRC Cognition and Brain Sciences Unit, Cambridge, UK -- **Maxime Descoteaux**, University of Sherbrooke, QC, CA -- **Serge Koudoro**, University of Indiana, IN, USA - +- **Omar Ocegueda**, Google, San Francisco, CA +- **Marc-Alexandre Côté**, Microsoft Research, Montreal, QC, CA +- **Serge Koudoro**, Indiana University, IN, USA +- **Gabriel Girard**, Swiss Federal Institute of Technology (EPFL), Lausanne, CH +- **Mauro Zucchelli**, INRIA, Sophia-Antipolis, France +- **Rafael Neto Henriques**, Cambridge University, UK +- **Matthieu Dumont**, Imeka, Sherbrooke, QC, CA +- **Ranveer Aggarwal**, Microsoft, Hyderabad, Telangana, India And here is the rest of the wonderful contributors: -- **Mauro Zucchelli**, University of Verona, IT -- **Matthieu Dumont**, PAVI, Sherbrooke, QC, CA +- **Ian Nimmo-Smith**, retired, formerly at MRC Cognition and Brain Sciences Unit, Cambridge, UK +- **Maxime Descoteaux**, University of Sherbrooke, QC, CA +- **Stefan Van der Walt**, University of California, Berkeley, CA, USA - **Samuel St-Jean**, University Medical Center (UMC) Utrecht, Utrecht, NL -- **Gabriel Girard**, Swiss Federal Institute of Technology, Lausanne, CH -- **Michael Paquette**, University of Sherbrooke, QC, CA -- **Jean-Christophe Houde**, University of Sherbrooke, QC, CA +- **Michael Paquette**, Max Planck Institute for Human Cognitive and Brain Sciences, Leipzig, DE +- **Jean-Christophe Houde**, University of Sherbrooke, QC, CA and Imeka, Sherbrooke, QC, CA - **Christopher Nguyen**, Massachusetts General Hospital, MA, USA - **Emanuele Olivetti**, NeuroInformatics Laboratory (NILab), Trento, IT - **Yaroslav Halchenco**, PBS Department, Dartmouth, NH, USA @@ -34,6 +35,7 @@ And here is the rest of the wonderful contributors: - **Kimberly Chan**, Stanford University, CA, USA - **Chantal Tax**, Cardiff University, Cardiff, UK - **Demian Wassermann**, INRIA, Sophia Antipolis, FR +- **Rutger Fick**, INRIA, Sophia Antipolis, FR - **Gregory R. 
Lee**, Cincinnati Children's Hospital Medical Center, Cincinnati, OH, USA - **Endolith**, New-York, NY, USA - **Matthias Ekman**, Donders Institute for Brain, Cognition and Behaviour, Nijmegen, NL @@ -42,7 +44,7 @@ And here is the rest of the wonderful contributors: - **Maria Luisa Mandelli**, University of California, San Francisco, CA, USA - **Adam Rybinski**, Jagiellonian University, Krakow, PL - **Qiyuan Tian**, Stanford University, Stanford, CA, USA -- **Rafael Neto Henriques**, Cambridge University, UK +- **Rafael Neto Henriques**, Champalimaud Neuroscience Programme, Champalimaud Centre for the Unknown, Lisbon, PT - **Stephan Meesters**, Eindhoven University of Technology, NL - **Himanshu Mishra**, Indian Institute of Technology, Karaghpur, IN - **Alexander Gauvin**, University of Sherbrooke, QC, CA @@ -53,6 +55,8 @@ And here is the rest of the wonderful contributors: - **Sagun Pai**, Indian Institute of Technology, Bombay, IN - **Vatsala Swaroop**, Mombai, IN - **Shahnawaz Ahmed**, Birla Institute of Technology and Science, Pilani, Goa, IN +- **Nil Goyette**, Imeka, Sherbrooke, QC, CA + Boundless collaboration is at the heart of DIPY_. We encourage everyone from anywhere in the world to join the team. You can start sharing your code `here`__. If you want to contribute but you don't know which area to focus on, please send us an e-mail. We will be more than happy to help. diff --git a/doc/examples/bundle_extraction.py b/doc/examples/bundle_extraction.py new file mode 100644 index 0000000000..3382d9fcc9 --- /dev/null +++ b/doc/examples/bundle_extraction.py @@ -0,0 +1,175 @@ +""" +================================================== +Automatic Fiber Bundle Extraction with RecoBundles +================================================== + +This example explains how we can use RecoBundles [Garyfallidis17]_ to +extract bundles from tractograms. + +First import the necessary modules. +""" + +from dipy.segment.bundles import RecoBundles +from dipy.align.streamlinear import whole_brain_slr +from dipy.viz import window, actor +from dipy.io.streamline import load_trk, save_trk + + +""" +Download and read the data for this tutorial. +""" + +from dipy.data.fetcher import (fetch_target_tractogram_hcp, + fetch_bundle_atlas_hcp842, + get_bundle_atlas_hcp842, + get_target_tractogram_hcp) + +target_file, target_folder = fetch_target_tractogram_hcp() +atlas_file, atlas_folder = fetch_bundle_atlas_hcp842() + +atlas_file, all_bundles_files = get_bundle_atlas_hcp842() +target_file = get_target_tractogram_hcp() + +atlas, atlas_header = load_trk(atlas_file) +target, target_header = load_trk(target_file) + +""" +Let's visualize the atlas and the target tractograms before registration. +""" + +interactive = False + +ren = window.Renderer() +ren.SetBackground(1, 1, 1) +ren.add(actor.line(atlas, colors=(1, 0, 1))) +ren.add(actor.line(target, colors=(1, 1, 0))) +window.record(ren, out_path='tractograms_initial.png', size=(600, 600)) +if interactive: + window.show(ren) + +""" +.. figure:: tractograms_initial.png + :align: center + + Atlas and target before registration.
+ +""" + +""" +We will register target tractogram to model atlas' space using streamlinear +registeration (SLR) [Garyfallidis15]_ +""" + +moved, transform, qb_centroids1, qb_centroids2 = whole_brain_slr( + atlas, target, x0='affine', verbose=True, progressive=True) + +""" +let's visualize atlas tractogram and target tractogram after registration +""" + +interactive = False + +ren = window.Renderer() +ren.SetBackground(1, 1, 1) +ren.add(actor.line(atlas, colors=(1,0,1))) +ren.add(actor.line(moved, colors=(1,1,0))) +window.record(ren, out_path='tractograms_after_registration.png', + size=(600, 600)) +if interactive: + window.show(ren) + +""" +.. figure:: tractograms_after_registration.png + :align: center + + Atlas and target after registration. + +""" + +""" +Read AF left and CST left bundles from already fetched atlas data to use them +as model bundles +""" + +from dipy.data.fetcher import get_two_hcp842_bundles +bundle1, bundle2 = get_two_hcp842_bundles() + +""" +Extracting bundles using recobundles [Garyfallidis17]_ +""" + +model_bundle, _ = load_trk(bundle1) + +rb = RecoBundles(moved, verbose=True) + +recognized_bundle, rec_labels = rb.recognize(model_bundle=model_bundle, + model_clust_thr=5., + reduction_thr=10, + reduction_distance='mam', + slr=True, + slr_metric='asymmetric', + pruning_distance='mam') + +""" +let's visualize extracted Arcuate Fasciculus Left bundle and model bundle +together +""" + +interactive = False + +ren = window.Renderer() +ren.SetBackground(1, 1, 1) +ren.add(actor.line(model_bundle, colors=(.1,.7,.26))) +ren.add(actor.line(recognized_bundle, colors=(.1,.1,6))) +ren.set_camera(focal_point=(320.21296692, 21.28884506, 17.2174015), + position=(2.11, 200.46, 250.44) , view_up=(0.1, -1.028, 0.18)) +window.record(ren, out_path='AF_L_recognized_bundle.png', + size=(600, 600)) +if interactive: + window.show(ren) + +""" +.. figure:: AF_L_recognized_bundle.png + :align: center + + Extracted Arcuate Fasciculus Left bundle and model bundle + +""" + +model_bundle, _ = load_trk(bundle2) + +recognized_bundle, rec_labels = rb.recognize(model_bundle=model_bundle, + model_clust_thr=5., + reduction_thr=10, + reduction_distance='mam', + slr=True, + slr_metric='asymmetric', + pruning_distance='mam') + +""" +let's visualize extracted Corticospinal Tract (CST) Left bundle and model +bundle together +""" + +interactive = False + +ren = window.Renderer() +ren.SetBackground(1, 1, 1) +ren.add(actor.line(model_bundle, colors=(.1,.7,.26))) +ren.add(actor.line(recognized_bundle, colors=(.1,.1,6))) +ren.set_camera(focal_point=(-18.17281532, -19.55606842, 6.92485857), + position=(-360.11, -340.46, -40.44), + view_up=(-0.03, 0.028, 0.89)) +window.record(ren, out_path='CST_L_recognized_bundle.png', + size=(600, 600)) +if interactive: + window.show(ren) + + +""" +.. figure:: CST_L_recognized_bundle.png + :align: center + + Extracted Corticospinal Tract (CST) Left bundle and model bundle + +""" diff --git a/doc/examples/cluster_confidence.py b/doc/examples/cluster_confidence.py new file mode 100644 index 0000000000..dbe85eded0 --- /dev/null +++ b/doc/examples/cluster_confidence.py @@ -0,0 +1,188 @@ +""" +================================== +Calculation of Outliers with Cluster Confidence Index +================================== + +This is an outlier scoring method that compares the pathways of each streamline +in a bundle (pairwise) and scores each streamline by how many other streamlines +have similar pathways. The details can be found in [Jordan_2018_plm]_. 
+ +""" + +from dipy.data import read_stanford_labels +from dipy.reconst.shm import CsaOdfModel +from dipy.data import default_sphere +from dipy.direction import peaks_from_model +from dipy.tracking.local import ThresholdTissueClassifier +from dipy.tracking import utils +from dipy.tracking.local import LocalTracking +from dipy.tracking.streamline import Streamlines +from dipy.viz import actor, window +from dipy.tracking.utils import length + +import matplotlib.pyplot as plt +import matplotlib + +from dipy.tracking.streamline import cluster_confidence + + +""" +First, we need to generate some streamlines. For a more complete +description of these steps, please refer to the CSA Probabilistic Tracking and +the Visualization of ROI Surface Rendered with Streamlines Tutorials. + """ + + +hardi_img, gtab, labels_img = read_stanford_labels() +data = hardi_img.get_data() +labels = labels_img.get_data() +affine = hardi_img.affine +white_matter = (labels == 1) | (labels == 2) +csa_model = CsaOdfModel(gtab, sh_order=6) +csa_peaks = peaks_from_model(csa_model, data, default_sphere, + relative_peak_threshold=.8, + min_separation_angle=45, + mask=white_matter) +classifier = ThresholdTissueClassifier(csa_peaks.gfa, .25) + + +""" +We will use a slice of the anatomically-based corpus callosum ROI as our +seed mask to demonstrate the method. + """ + + +# Make a corpus callosum seed mask for tracking +seed_mask = labels == 2 +seeds = utils.seeds_from_mask(seed_mask, density=[1, 1, 1], affine=affine) +# Make a streamline bundle model of the corpus callosum ROI connectivity +streamlines = LocalTracking(csa_peaks, classifier, seeds, affine, + step_size=2) +streamlines = Streamlines(streamlines) + + +""" +We do not want our results inflated by short streamlines, so we remove +streamlines shorter than 40mm prior to calculating the CCI. +""" + +lengths = list(length(streamlines)) +long_streamlines = Streamlines() +for i, sl in enumerate(streamlines): + if lengths[i] > 40: + long_streamlines.append(sl) + + +""" +Now we calculate the Cluster Confidence Index using the corpus callosum +streamline bundle and visualize them. +""" + + +cci = cluster_confidence(long_streamlines) + +# Visualize the streamlines, colored by cci +ren = window.Renderer() + +hue = [0.5, 1] +saturation = [0.0, 1.0] + +lut_cmap = actor.colormap_lookup_table(scale_range=(cci.min(), cci.max()/4), + hue_range=hue, + saturation_range=saturation) + +bar3 = actor.scalar_bar(lut_cmap) +ren.add(bar3) + +stream_actor = actor.line(long_streamlines, cci, linewidth=0.1, + lookup_colormap=lut_cmap) +ren.add(stream_actor) + + +""" +If you set interactive to True (below), the rendering will pop up in an +interactive window. +""" + + +interactive = False +if interactive: + window.show(ren) +window.record(ren, n_frames=1, out_path='cci_streamlines.png', + size=(800, 800)) + +""" +.. figure:: cci_streamlines.png + :align: center + + Cluster Confidence Index of corpus callosum dataset. + + +If you think of each streamline as a sample of a potential pathway through a +complex landscape of white matter anatomy probed via water diffusion, +intuitively we have more confidence that pathways represented by many samples +(streamlines) reflect a more stable representation of the underlying phenomenon +we are trying to model (anatomical landscape) than do lone samples. + +The CCI provides a voting system where by each streamline (within a set +tolerance) gets to vote on how much support it lends to. 
Outlier pathways score +relatively low on CCI, since they do not have many streamlines voting for them. +These outliers can be removed by thresholding on the CCI metric. + +""" + + +fig, ax = plt.subplots(1) +ax.hist(cci, bins=100, histtype='step') +ax.set_xlabel('CCI') +ax.set_ylabel('# streamlines') +fig.savefig('cci_histogram.png') + + +""" +.. figure:: cci_histogram.png + :align: center + + Histogram of Cluster Confidence Index values. + +Now we threshold the CCI, defining outliers as streamlines that score below 1. + +""" + +keep_streamlines = Streamlines() +for i, sl in enumerate(long_streamlines): + if cci[i] >= 1: + keep_streamlines.append(sl) + +# Visualize the streamlines we kept +ren = window.Renderer() + +keep_streamlines_actor = actor.line(keep_streamlines, linewidth=0.1) + +ren.add(keep_streamlines_actor) + + +interactive = False +if interactive: + window.show(ren) +window.record(ren, n_frames=1, out_path='filtered_cci_streamlines.png', + size=(800, 800)) + +""" + +.. figure:: filtered_cci_streamlines.png + :align: center + + Outliers, defined as streamlines scoring CCI < 1, were excluded. + + +References +---------- + +.. [Jordan_2018_plm] Jordan, K., Amirbekian, B., Keshavan, A., Henry, R.G. +"Cluster Confidence Index: A Streamline‐Wise Pathway Reproducibility Metric +for Diffusion‐Weighted MRI Tractography", Journal of Neuroimaging, 2017. + +.. include:: ../links_names.inc + +""" diff --git a/doc/examples/denoise_localpca.py b/doc/examples/denoise_localpca.py index 2db39f3e22..4ad5f07930 100644 --- a/doc/examples/denoise_localpca.py +++ b/doc/examples/denoise_localpca.py @@ -39,7 +39,7 @@ img, gtab = read_isbi2013_2shell() data = img.get_data() -affine = img.get_affine() +affine = img.affine print("Input Volume", data.shape) diff --git a/doc/examples/gradients_spheres.py b/doc/examples/gradients_spheres.py index a419910d0a..3c80ecaf29 100644 --- a/doc/examples/gradients_spheres.py +++ b/doc/examples/gradients_spheres.py @@ -74,7 +74,7 @@ sph = Sphere(xyz=np.vstack((hsph_updated.vertices, -hsph_updated.vertices))) window.rm_all(ren) -ren.add(actor.point(sph.vertices, actor.colors.green, point_radius=0.05)) +ren.add(actor.point(sph.vertices, window.colors.green, point_radius=0.05)) print('Saving illustration as full_sphere.png') window.record(ren, out_path='full_sphere.png', size=(300, 300)) diff --git a/doc/examples/introduction_to_basic_tracking.py b/doc/examples/introduction_to_basic_tracking.py index 39c69f2cb6..dc6cf4b253 100644 --- a/doc/examples/introduction_to_basic_tracking.py +++ b/doc/examples/introduction_to_basic_tracking.py @@ -84,9 +84,10 @@ """ from dipy.tracking import utils +import numpy as np seed_mask = labels == 2 -seeds = utils.seeds_from_mask(seed_mask, density=[2, 2, 2], affine=affine) +seeds = utils.seeds_from_mask(seed_mask, density=[2, 2, 2], affine=np.eye(4)) """ Finally, we can bring it all together using ``LocalTracking``. We will then @@ -94,24 +95,24 @@ """ from dipy.tracking.local import LocalTracking -from dipy.viz import window, actor -from dipy.viz.colormap import line_colors +from dipy.viz import window, actor, colormap as cmap, have_fury from dipy.tracking.streamline import Streamlines # Enables/disables interactive visualization interactive = False # Initialization of LocalTracking. The computation happens in the next step. 
-streamlines_generator = LocalTracking(csa_peaks, classifier, seeds, affine, step_size=.5) +streamlines_generator = LocalTracking(csa_peaks, classifier, seeds, + affine=np.eye(4), step_size=.5) # Generate streamlines object streamlines = Streamlines(streamlines_generator) # Prepare the display objects. -color = line_colors(streamlines) +color = cmap.line_colors(streamlines) -if window.have_vtk: - streamlines_actor = actor.line(streamlines, line_colors(streamlines)) +if have_fury: + streamlines_actor = actor.line(streamlines, cmap.line_colors(streamlines)) # Create the 3D display. r = window.Renderer() @@ -193,14 +194,15 @@ callosum. """ -streamlines_generator = LocalTracking(prob_dg, classifier, seeds, affine, - step_size=.5, max_cross=1) +streamlines_generator = LocalTracking(prob_dg, classifier, seeds, + affine=np.eye(4), step_size=.5, + max_cross=1) # Generate streamlines object. streamlines = Streamlines(streamlines_generator) -if window.have_vtk: - streamlines_actor = actor.line(streamlines, line_colors(streamlines)) +if have_fury: + streamlines_actor = actor.line(streamlines, cmap.line_colors(streamlines)) # Create the 3D display. r = window.Renderer() diff --git a/doc/examples/linear_fascicle_evaluation.py b/doc/examples/linear_fascicle_evaluation.py index d0bf46dd80..5651e5c1aa 100644 --- a/doc/examples/linear_fascicle_evaluation.py +++ b/doc/examples/linear_fascicle_evaluation.py @@ -5,8 +5,8 @@ Evaluating the results of tractography algorithms is one of the biggest challenges for diffusion MRI. One proposal for evaluation of tractography -results is to use a forward model that predicts the signal from each of a set of -streamlines, and then fit a linear model to these simultaneous predictions +results is to use a forward model that predicts the signal from each of a set +of streamlines, and then fit a linear model to these simultaneous predictions [Pestilli2014]_. 
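The linear model behind this evaluation is compact enough to sketch in isolation. Below is a hedged toy version (random stand-ins for per-streamline signal predictions, not dipy.tracking.life itself): the measured signal is modeled as a non-negative combination of the columns of a design matrix, and candidates whose fitted weight stays at zero are pruned, mirroring the beta > 0 selection used later in this example.

# Toy sketch of the forward linear model (random stand-ins, not
# dipy.tracking.life): column j of X is the predicted signal of candidate
# streamline j, and we seek non-negative weights beta with X.dot(beta) ~ y.
import numpy as np
from scipy.optimize import nnls

rng = np.random.RandomState(0)
n_measurements, n_candidates = 100, 10
X = rng.rand(n_measurements, n_candidates)
beta_true = np.zeros(n_candidates)
beta_true[[1, 4]] = [0.7, 0.3]             # only two candidates carry signal
y = X.dot(beta_true) + rng.normal(0, 0.01, n_measurements)

beta, _ = nnls(X, y)                       # non-negative least squares fit
print(np.round(beta, 2))                   # candidates with beta > 0 are kept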
We will use streamlines generated using probabilistic tracking on CSA @@ -55,18 +55,18 @@ """ -Let's visualize the initial candidate group of streamlines in 3D, relative to the -anatomical structure of this brain: +Let's visualize the initial candidate group of streamlines in 3D, relative to +the anatomical structure of this brain: """ -from dipy.viz.colormap import line_colors -from dipy.viz import window, actor +from dipy.viz import window, actor, colormap as cmap # Enables/disables interactive visualization interactive = False -candidate_streamlines_actor = actor.streamtube(candidate_sl, line_colors(candidate_sl)) +candidate_streamlines_actor = actor.streamtube(candidate_sl, + cmap.line_colors(candidate_sl)) cc_ROI_actor = actor.contour_from_roi(cc_slice, color=(1., 1., 0.), opacity=0.5) @@ -182,7 +182,7 @@ optimized_sl = list(np.array(candidate_sl)[np.where(fiber_fit.beta > 0)[0]]) ren = window.Renderer() -ren.add(actor.streamtube(optimized_sl, line_colors(optimized_sl))) +ren.add(actor.streamtube(optimized_sl, cmap.line_colors(optimized_sl))) ren.add(cc_ROI_actor) ren.add(vol_actor) window.record(ren, n_frames=1, out_path='life_optimized.png', diff --git a/doc/examples/particle_filtering_fiber_tracking.py b/doc/examples/particle_filtering_fiber_tracking.py index 570a5b8244..5cef4cb426 100644 --- a/doc/examples/particle_filtering_fiber_tracking.py +++ b/doc/examples/particle_filtering_fiber_tracking.py @@ -32,8 +32,7 @@ auto_response) from dipy.tracking.local import LocalTracking, ParticleFilteringTracking from dipy.tracking import utils -from dipy.viz import window, actor -from dipy.viz.colormap import line_colors +from dipy.viz import window, actor, colormap as cmap renderer = window.Renderer() @@ -43,7 +42,7 @@ data = hardi_img.get_data() labels = labels_img.get_data() -affine = hardi_img.get_affine() +affine = hardi_img.affine shape = labels.shape response, ratio = auto_response(gtab, data, roi_radius=10, fa_thr=0.7) @@ -98,13 +97,13 @@ particle_count=15, return_all=False) -#streamlines = list(pft_streamline_generator) +# streamlines = list(pft_streamline_generator) streamlines = Streamlines(pft_streamline_generator) save_trk("pft_streamline.trk", streamlines, affine, shape) renderer.clear() -renderer.add(actor.line(streamlines, line_colors(streamlines))) +renderer.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(renderer, out_path='pft_streamlines.png', size=(600, 600)) """ @@ -123,12 +122,12 @@ step_size=step_size, maxlen=1000, return_all=False) -#streamlines = list(pro) +# streamlines = list(pro) streamlines = Streamlines(prob_streamline_generator) save_trk("probabilistic_streamlines.trk", streamlines, affine, shape) renderer.clear() -renderer.add(actor.line(streamlines, line_colors(streamlines))) +renderer.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(renderer, out_path='probabilistic_streamlines.png', size=(600, 600)) diff --git a/doc/examples/path_length_map.py b/doc/examples/path_length_map.py new file mode 100644 index 0000000000..7397fdcf58 --- /dev/null +++ b/doc/examples/path_length_map.py @@ -0,0 +1,186 @@ +""" +================================== +Calculate Path Length Map +================================== + +We show how to calculate a Path Length Map for Anisotropic Radiation Therapy +Contours given a set of streamlines and a region of interest (ROI). +The Path Length Map is a volume in which each voxel's value is the shortest +distance along a streamline to a given region of interest (ROI). 
This map can +be used to anisotropically modify radiation therapy treatment contours based +on a tractography model of the local white matter anatomy, as described in +[Jordan_2018_plm]_, by executing this tutorial with the gross tumor volume +(GTV) as the ROI. + +NOTE: The background value is set to -1 by default. +""" + +from dipy.data import read_stanford_labels, fetch_stanford_t1, read_stanford_t1 +from dipy.reconst.shm import CsaOdfModel +from dipy.data import default_sphere +from dipy.direction import peaks_from_model +from dipy.tracking.local import ThresholdTissueClassifier +from dipy.tracking import utils +from dipy.tracking.local import LocalTracking +from dipy.tracking.streamline import Streamlines +from dipy.viz import actor, window, colormap as cmap +from dipy.tracking.utils import path_length +import nibabel as nib +import numpy as np +import matplotlib as mpl +from mpl_toolkits.axes_grid1 import AxesGrid + +""" +First, we need to generate some streamlines and visualize them. For a more +complete description of these steps, please refer to the :ref:`example_probabilistic_fiber_tracking` +and the Visualization of ROI Surface Rendered with Streamlines Tutorials. + +""" + +hardi_img, gtab, labels_img = read_stanford_labels() +data = hardi_img.get_data() +labels = labels_img.get_data() +affine = hardi_img.affine + +white_matter = (labels == 1) | (labels == 2) + +csa_model = CsaOdfModel(gtab, sh_order=6) +csa_peaks = peaks_from_model(csa_model, data, default_sphere, + relative_peak_threshold=.8, + min_separation_angle=45, + mask=white_matter) + +classifier = ThresholdTissueClassifier(csa_peaks.gfa, .25) + +""" +We will use an anatomically-based corpus callosum ROI as our seed mask to +demonstrate the method. In practice, this corpus callosum mask (labels == 2) +should be replaced with the desired ROI mask (e.g. gross tumor volume (GTV), +lesion mask, or electrode mask). + +""" + +# Make a corpus callosum seed mask for tracking +seed_mask = labels == 2 +seeds = utils.seeds_from_mask(seed_mask, density=[1, 1, 1], affine=affine) + +# Make a streamline bundle model of the corpus callosum ROI connectivity +streamlines = LocalTracking(csa_peaks, classifier, seeds, affine, + step_size=2) +streamlines = Streamlines(streamlines) + +# Visualize the streamlines and the Path Length Map base ROI +# (in this case also the seed ROI) + +streamlines_actor = actor.line(streamlines, cmap.line_colors(streamlines)) +surface_opacity = 0.5 +surface_color = [0, 1, 1] +seedroi_actor = actor.contour_from_roi(seed_mask, affine, + surface_color, surface_opacity) + +ren = window.Renderer() +ren.add(streamlines_actor) +ren.add(seedroi_actor) + +""" +If you set interactive to True (below), the rendering will pop up in an +interactive window. +""" + +interactive = False +if interactive: + window.show(ren) + +window.record(ren, n_frames=1, out_path='plm_roi_sls.png', + size=(800, 800)) + + +""" +.. figure:: plm_roi_sls.png + :align: center + + **A top view of corpus callosum streamlines with the blue transparent ROI in + the center**. +""" + +""" +Now we calculate the Path Length Map using the corpus callosum streamline +bundle and corpus callosum ROI. + +NOTE: the mask used to seed the tracking does not have to be the Path +Length Map base ROI, as we do here, but it often makes sense for them to be the +same ROI if we want a map of the whole brain's distance back to our ROI. +(e.g.
we could test a hypothesis about the motor system by making a streamline +bundle model of the corticospinal tract (CST) and input a lesion mask as our +Path Length Map base ROI to restrict the analysis to the CST) +""" + +path_length_map_base_roi = seed_mask + +# calculate the WMPL + +wmpl = path_length(streamlines, path_length_map_base_roi, affine) + +# save the WMPL as a nifti +path_length_img = nib.Nifti1Image(wmpl.astype(np.float32), affine) +nib.save(path_length_img, 'example_cc_path_length_map.nii.gz') + +# get the T1 to show anatomical context of the WMPL +fetch_stanford_t1() +t1 = read_stanford_t1() +t1_data = t1.get_data() + + +fig = mpl.pyplot.figure() +fig.subplots_adjust(left=0.05, right=0.95) +ax = AxesGrid(fig, 111, + nrows_ncols=(1, 3), + cbar_location="right", + cbar_mode="single", + cbar_size="10%", + cbar_pad="5%") + +""" +We will mask our WMPL to ignore values less than zero because negative numbers +indicate that no path back to the ROI was found in the provided streamlines. +""" + +wmpl_show = np.ma.masked_where(wmpl < 0, wmpl) + +slx, sly, slz = [60, 50, 35] +ax[0].matshow(np.rot90(t1_data[:, slx, :]), cmap=mpl.cm.bone) +im = ax[0].matshow(np.rot90(wmpl_show[:, slx, :]), + cmap=mpl.cm.cool, vmin=0, vmax=80) + +ax[1].matshow(np.rot90(t1_data[:, sly, :]), cmap=mpl.cm.bone) +im = ax[1].matshow(np.rot90(wmpl_show[:, sly, :]), cmap=mpl.cm.cool, + vmin=0, vmax=80) + +ax[2].matshow(np.rot90(t1_data[:, slz, :]), cmap=mpl.cm.bone) +im = ax[2].matshow(np.rot90(wmpl_show[:, slz, :]), + cmap=mpl.cm.cool, vmin=0, vmax=80) + +ax.cbar_axes[0].colorbar(im) +for lax in ax: + lax.set_xticks([]) + lax.set_yticks([]) +fig.savefig("Path_Length_Map.png") + + +""" +.. figure:: Path_Length_Map.png + :align: center + + **Path Length Map showing the shortest distance, along a streamline, + from the corpus callosum ROI with the background set to -1**. + +References +---------- + +.. [Jordan_2018_plm] Jordan K. et al., "An Open-Source Tool for Anisotropic +Radiation Therapy Planning in Neuro-oncology Using DW-MRI Tractography", +PREPRINT (biorxiv), 2018. + +.. include:: ../links_names.inc + +""" diff --git a/doc/examples/quick_start.py b/doc/examples/quick_start.py index c9ed78e725..f767bf4856 100644 --- a/doc/examples/quick_start.py +++ b/doc/examples/quick_start.py @@ -8,7 +8,7 @@ one with the b-vectors. In DIPY_ we provide tools to load and process these files and we also provide -access to publically available datasets for those who haven't acquired yet +access to publicly available datasets for those who haven't yet acquired their own datasets.
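As an aside for readers who do have their own data on disk, the loading pattern this example builds toward looks roughly like the sketch below; the file names are placeholders, not files shipped with DIPY.

# Rough loading pattern for your own data (file names are placeholders):
import nibabel as nib
from dipy.core.gradients import gradient_table
from dipy.io.gradients import read_bvals_bvecs

img = nib.load('my_dwi.nii.gz')            # 4D diffusion-weighted volume
data = img.get_data()
bvals, bvecs = read_bvals_bvecs('my_dwi.bval', 'my_dwi.bvec')
gtab = gradient_table(bvals, bvecs)        # pairs b-values with directions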
With the following commands we can download a dMRI dataset @@ -76,7 +76,7 @@ print(data.shape) """ -``(128, 128, 60, 194)`` +``(128, 128, 60, 193)`` We can also check the dimensions of each voxel in the following way: """ diff --git a/doc/examples/reconst_dki.py b/doc/examples/reconst_dki.py index 955b759d36..d1267b2d03 100644 --- a/doc/examples/reconst_dki.py +++ b/doc/examples/reconst_dki.py @@ -329,8 +329,9 @@ AWF = dki_micro_fit.awf TORT = dki_micro_fit.tortuosity - -""" These parameters are plotted below on top of the mean kurtosis maps: """ +""" +These parameters are plotted below on top of the mean kurtosis maps: +""" fig3, ax = plt.subplots(1, 2, figsize=(9, 4), subplot_kw={'xticks': [], 'yticks': []}) diff --git a/doc/examples/reconst_dsid.py b/doc/examples/reconst_dsid.py index 84e37366e1..30b04edc5a 100644 --- a/doc/examples/reconst_dsid.py +++ b/doc/examples/reconst_dsid.py @@ -14,7 +14,7 @@ import numpy as np from dipy.sims.voxel import multi_tensor, multi_tensor_odf -from dipy.data import get_data, get_sphere +from dipy.data import get_fnames, get_sphere from dipy.core.gradients import gradient_table from dipy.reconst.dsi import (DiffusionSpectrumDeconvModel, DiffusionSpectrumModel) @@ -24,7 +24,7 @@ gradient directions and 1 S0. """ -btable = np.loadtxt(get_data('dsi515btable')) +btable = np.loadtxt(get_fnames('dsi515btable')) gtab = gradient_table(btable[:, 0], btable[:, 1:]) diff --git a/doc/examples/reconst_qtdmri.py b/doc/examples/reconst_qtdmri.py new file mode 100644 index 0000000000..d160d740a4 --- /dev/null +++ b/doc/examples/reconst_qtdmri.py @@ -0,0 +1,421 @@ +# -*- coding: utf-8 -*- +""" +================================================================= +Estimating diffusion time dependent q-space indices using qt-dMRI +================================================================= +Effective representation of the four-dimensional diffusion MRI signal -- +varying over three-dimensional q-space and diffusion time -- is a sought-after +and still unsolved challenge in diffusion MRI (dMRI). We propose a functional +basis approach that is specifically designed to represent the dMRI signal in +this qtau-space [Fick2017]_. Following recent terminology, we refer to our +qtau-functional basis as ``q$\tau$-dMRI''. We use GraphNet regularization -- +imposing both signal smoothness and sparsity -- to drastically reduce the +number of diffusion-weighted images (DWIs) that is needed to represent the dMRI +signal in the qtau-space. As the main contribution, q$\tau$-dMRI provides the +framework to -- without making biophysical assumptions -- represent the +q$\tau$-space signal and estimate time-dependent q-space indices +(q$\tau$-indices), providing a new means for studying diffusion in nervous +tissue. qtau-dMRI is the first of its kind in being specifically designed to +provide open interpretation of the qtau-diffusion signal. + +q$\tau$-dMRI can be seen as a time-dependent extension of the MAP-MRI +functional basis [Ozarslan2013]_, and all the previously proposed q-space +indices can be estimated for any diffusion time. These include rotationally +invariant quantities such as the Mean Squared Displacement (MSD), Q-space +Inverse Variance (QIV) and Return-To-Origin Probability (RTOP). Also +directional indices such as the Return To the Axis Probability (RTAP) and +Return To the Plane Probability (RTPP) are available, as well as the +Orientation Distribution Function (ODF).
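As a quick reference for the free-diffusion background isolines drawn later in this example, the Gaussian-propagator values of these indices have simple closed forms; the sketch below just evaluates the same relations the plotting code uses for its D_grid backgrounds, with illustrative numbers.

# Free (Gaussian) diffusion reference values, matching the isoline formulas
# used later in this example (D in mm^2/s, tau in s; numbers illustrative):
import numpy as np

D, tau = 2e-3, 0.015
msd = 6 * D * tau                          # MSD = 6 * D * tau
rtxp = 1. / np.sqrt(4 * np.pi * D * tau)   # per-axis return probability, 1/mm
print(msd, rtxp)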
+ +In this example we illustrate how to use the qtau-dMRI to estimate +time-dependent q-space indices from a qtau-acquisition of a mouse. + +First import the necessary modules: +""" + +from dipy.data.fetcher import (fetch_qtdMRI_test_retest_2subjects, + read_qtdMRI_test_retest_2subjects) +from dipy.reconst import qtdmri, dti +import matplotlib.pyplot as plt +import numpy as np + +""" +Download and read the data for this tutorial. + +qt-dMRI requires data with multiple gradient directions, gradient strength and +diffusion times. We will use the test-retest acquisitions of two mice that were +used in the test-retest study by [Fick2017]_. The data itself is freely +available and citeable at [Wassermann2017]_. +""" + +fetch_qtdMRI_test_retest_2subjects() +data, cc_masks, gtabs = read_qtdMRI_test_retest_2subjects() + +""" +data contains 4 qt-dMRI datasets of size [80, 160, 5, 515]. The first two are +the test-retest datasets of the first mouse and the second two are those of the +second mouse. cc_masks contains 4 corresponding binary masks for the corpus +callosum voxels in the middle slice that were used in the test-retest study. +Finally, gtabs contains the qt-dMRI gradient tables for the DWIs in the dataset. + +The data consists of 515 DWIs, divided over 35 shells, with 7 "gradient +strength shells" up to 491 mT/m, 5 equally spaced "pulse separation shells" +(big_delta) between [10.8-20] ms and a pulse duration (small_delta) of 5 ms. + +To visualize qt-dMRI acquisition schemes in an intuitive way, the qtdmri module +provides a visualization function to illustrate the relationship between +gradient strength (G), pulse separation (big_delta) and b-value: +""" + +plt.figure() +qtdmri.visualise_gradient_table_G_Delta_rainbow(gtabs[0]) +plt.savefig('qt-dMRI_acquisition_scheme.png') + +""" +.. figure:: qt-dMRI_acquisition_scheme.png + :align: center + +In the figure the dots represent measured DWIs in any direction, for a given +gradient strength and pulse separation. The background isolines represent the +corresponding b-values for different combinations of G and big_delta. + +Next, we visualize the middle slices of the test-retest data sets with their +corresponding masks. To better illustrate the white matter architecture in the +data, we calculate DTI's fractional anisotropy (FA) over the whole slice and +project the corpus callosum mask on the FA image: +""" + +subplot_titles = ["Subject1 Test", "Subject1 Retest", + "Subject2 Test", "Subject2 Retest"] +fig = plt.figure() +plt.subplots(nrows=2, ncols=2) +for i, (data_, mask_, gtab_) in enumerate(zip(data, cc_masks, gtabs)): + # take the middle slice + data_middle_slice = data_[:, :, 2] + mask_middle_slice = mask_[:, :, 2] + + # estimate fractional anisotropy (FA) for this slice + tenmod = dti.TensorModel(gtab_) + tenfit = tenmod.fit(data_middle_slice, data_middle_slice[..., 0] > 0) + fa = tenfit.fa + + # set mask color to green with 0.5 opacity as overlay + mask_template = np.zeros(np.r_[mask_middle_slice.shape, 4]) + mask_template[mask_middle_slice == 1] = np.r_[0., 1., 0., .5] + + # produce the FA images with corpus callosum masks. + plt.subplot(2, 2, 1 + i) + plt.title(subplot_titles[i], fontsize=15) + plt.imshow(fa, cmap='Greys_r', origin=True, interpolation='nearest') + plt.imshow(mask_template, origin=True, interpolation='nearest') + plt.axis('off') +plt.tight_layout() +plt.savefig('qt-dMRI_datasets_fa_with_ccmasks.png') + +""" +..
figure:: qt-dMRI_datasets_fa_with_ccmasks.png + :align: center + +Next, we use qt-dMRI to estimate time-dependent q-space indices +(q$\tau$-indices) for the masked voxels in the corpus callosum of each dataset. +In particular, we estimate the Return-to-Origin, Return-to-Axis and +Return-to-Plane Probability (RTOP, RTAP and RTPP), as well as the Mean Squared +Displacement (MSD). + +In this example we don't extrapolate the data beyond the maximum diffusion +time, so we estimate q$\tau$ indices between the minimum and maximum diffusion +times of the data at 5 equally spaced points. However, it should be noted that +qt-dMRI's combined smoothness and sparsity regularization allows for smooth +interpolation at any q$\tau$ position. In other words, once the basis is +fitted to the data, its coefficients describe the entire q$\tau$-space, and +any q$\tau$-position can be freely recovered. This includes points beyond the +dataset's maximum q$\tau$ value (although this should be done with caution). +""" + +tau_min = gtabs[0].tau.min() +tau_max = gtabs[0].tau.max() +taus = np.linspace(tau_min, tau_max, 5) + +qtdmri_fits = [] +msds = [] +rtops = [] +rtaps = [] +rtpps = [] +for i, (data_, mask_, gtab_) in enumerate(zip(data, cc_masks, gtabs)): + # select the corpus callosum voxels for every dataset + cc_voxels = data_[mask_ == 1] + # initialize the qt-dMRI model. + # recommended basis orders are radial_order=6 and time_order=2. + # The combined Laplacian and l1-regularization using Generalized + # Cross-Validation (GCV) and Cross-Validation (CV) settings is most robust, + # but both can be used separately and with weightings preset to any + # positive value to optimize for speed. + qtdmri_mod = qtdmri.QtdmriModel( + gtab_, radial_order=6, time_order=2, + laplacian_regularization=True, laplacian_weighting='GCV', + l1_regularization=True, l1_weighting='CV' + ) + # fit the model. + # Here we take every 5th voxel for speed, but of course all voxels can be + # fit for a more robust result later on. + qtdmri_fit = qtdmri_mod.fit(cc_voxels[::5]) + qtdmri_fits.append(qtdmri_fit) + # We estimate MSD, RTOP, RTAP and RTPP for the chosen diffusion times. + msds.append(np.array(list(map(qtdmri_fit.msd, taus)))) + rtops.append(np.array(list(map(qtdmri_fit.rtop, taus)))) + rtaps.append(np.array(list(map(qtdmri_fit.rtap, taus)))) + rtpps.append(np.array(list(map(qtdmri_fit.rtpp, taus)))) + +""" +The estimated q$\tau$-indices, for the chosen diffusion times, are now stored +in msds, rtops, rtaps and rtpps. The trends of these q$\tau$-indices over time +say something about the restriction of diffusing particles over time, which +is currently a hot topic in the dMRI community. We evaluate the test-retest +reproducibility for the two subjects by plotting the q$\tau$-indices for each +subject together. This example will produce similar results to Fig. 10 in +[Fick2017]_. + +We first define a small function to plot the mean and standard deviation of the +q$\tau$-index trends in a subject.
+""" + + +def plot_mean_with_std(ax, time, ind1, plotcolor, ls='-', std_mult=1, + label=''): + means = np.mean(ind1, axis=1) + stds = np.std(ind1, axis=1) + ax.plot(time, means, c=plotcolor, lw=3, label=label, ls=ls) + ax.fill_between(time, + means + std_mult * stds, + means - std_mult * stds, + alpha=0.15, color=plotcolor) + ax.plot(time, means + std_mult * stds, alpha=0.25, color=plotcolor) + ax.plot(time, means - std_mult * stds, alpha=0.25, color=plotcolor) + + +""" +We start by showing the test-retest MSD of both subjects over time. We plot the +q$\tau$-indices together with q$\tau$-index trends of free diffusion with +different diffusivities as background. +""" + +# we first generate the data to produce the background index isolines. +Delta_ = np.linspace(0.005, 0.02, 100) +MSD_ = np.linspace(4e-5, 10e-5, 100) +Delta_grid, MSD_grid = np.meshgrid(Delta_, MSD_) +D_grid = MSD_grid / (6 * Delta_grid) +D_levels = np.r_[1, 5, 7, 10, 14, 23, 30] * 1e-4 + +fig = plt.figure(figsize=(10, 3)) +# start with the plot of subject 1. +ax = plt.subplot(1, 2, 1) +# first plot the background +plt.contourf(Delta_ * 1e3, 1e5 * MSD_, D_grid, levels=D_levels, cmap='Greys', + alpha=.5) + +# plot the test-retest mean MSD and standard deviation of subject 1. +plot_mean_with_std(ax, taus * 1e3, 1e5 * msds[0], 'r', 'dashdot', + label='MSD Test') +plot_mean_with_std(ax, taus * 1e3, 1e5 * msds[1], 'g', 'dashdot', + label='MSD Retest') +ax.legend(fontsize=13) +# plot some text markers to clarify the background diffusivity lines. +ax.text(.0091 * 1e3, 6.33, 'D=14e-4', fontsize=12, rotation=35) +ax.text(.0091 * 1e3, 4.55, 'D=10e-4', fontsize=12, rotation=25) +ax.set_ylim(4, 9.5) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_title(r'Test-Retest MSD($\tau$) Subject 1', fontsize=15) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_ylabel('MSD ($10^{-5}mm^2$)', fontsize=17) + +# then do the same thing for subject 2. +ax = plt.subplot(1, 2, 2) +plt.contourf(Delta_ * 1e3, 1e5 * MSD_, D_grid, levels=D_levels, cmap='Greys', + alpha=.5) +cb = plt.colorbar() +cb.set_label('Free Diffusivity ($mm^2/s$)', fontsize=18) + +plot_mean_with_std(ax, taus * 1e3, 1e5 * msds[2], 'r', 'dashdot') +plot_mean_with_std(ax, taus * 1e3, 1e5 * msds[3], 'g', 'dashdot') +ax.set_ylim(4, 9.5) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_title(r'Test-Retest MSD($\tau$) Subject 2', fontsize=15) +plt.savefig('qt_indices_msd.png') + +""" +.. figure:: qt_indices_msd.png + : align: center + +You can see that the MSD in both subjects increases over time, but also slowly +levels off as time progresses. This makes sense as diffusing particles are +becoming more restricted by surrounding tissue as time goes on. You can also +see that for Subject 1 the index trends nearly perfectly overlap, but for +subject 2 they are slightly off, which is also what we found in the paper. + +Next, we follow the same procedure to estimate the test-retest RTAP, RTOP and +RTPP over diffusion time for both subject. For ease of comparison, we will +estimate all three in the same unit [1/mm] by taking the square root of RTAP +and the cubed root of RTOP. +""" + +# Again, first we define the data for the background illustration. 
+Delta_ = np.linspace(0.005, 0.02, 100) +RTXP_ = np.linspace(1, 200, 100) +Delta_grid, RTXP_grid = np.meshgrid(Delta_, RTXP_) +D_grid = 1 / (4 * RTXP_grid ** 2 * np.pi * Delta_grid) +D_levels = np.r_[1, 2, 3, 4, 6, 9, 15, 30] * 1e-4 +D_colors = np.tile(np.linspace(.8, 0, 7), (3, 1)).T + +# We start with the RTOP illustration. +fig = plt.figure(figsize=(10, 3)) +ax = plt.subplot(1, 2, 1) +plt.contourf(Delta_ * 1e3, RTXP_, D_grid, colors=D_colors, levels=D_levels, + alpha=.5) + +plot_mean_with_std(ax, taus * 1e3, rtops[0] ** (1 / 3.), 'r', '--', + label='RTOP$^{1/3}$ Test') +plot_mean_with_std(ax, taus * 1e3, rtops[1] ** (1 / 3.), 'g', '--', + label='RTOP$^{1/3}$ Retest') +ax.legend(fontsize=13) +ax.text(.0091 * 1e3, 162, 'D=3e-4', fontsize=12, rotation=-22) +ax.text(.0091 * 1e3, 140, 'D=4e-4', fontsize=12, rotation=-20) +ax.text(.0091 * 1e3, 113, 'D=6e-4', fontsize=12, rotation=-16) +ax.set_ylim(54, 170) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_title(r'Test-Retest RTOP($\tau$) Subject 1', fontsize=15) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_ylabel('RTOP$^{1/3}$ (1/mm)', fontsize=17) + +ax = plt.subplot(1, 2, 2) +plt.contourf(Delta_ * 1e3, RTXP_, D_grid, colors=D_colors, levels=D_levels, + alpha=.5) +cb = plt.colorbar() +cb.set_label('Free Diffusivity ($mm^2/s$)', fontsize=18) + +plot_mean_with_std(ax, taus * 1e3, rtops[2] ** (1 / 3.), 'r', '--') +plot_mean_with_std(ax, taus * 1e3, rtops[3] ** (1 / 3.), 'g', '--') +ax.set_ylim(54, 170) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_title(r'Test-Retest RTOP($\tau$) Subject 2', fontsize=15) +plt.savefig('qt_indices_rtop.png') +""" +.. figure:: qt_indices_rtop.png + :align: center + +Similarly to the MSD, RTOP is related to the restriction that particles are +experiencing and is also rotationally invariant. RTOP is defined as the +probability that particles are found at the same position at the time of both +gradient pulses. As time increases, the odds become smaller that a particle +will arrive at the same position it left, which is illustrated by all RTOP +trends in the figure. Notice that the estimated RTOP trends decrease more +slowly than free diffusion, meaning that particles experience restriction over +time. Also notice that the RTOP trends in both subjects nearly perfectly +overlap. + +Next, we estimate two directional q$\tau$-indices, RTAP and RTPP, describing +particle restriction perpendicular and parallel to the orientation of the +principal diffusivity in that voxel. If the voxel describes coherent white +matter (which it does in our corpus callosum example), then they describe +properties related to restriction perpendicular and parallel to the axon +bundles. +""" + +# First, we estimate the RTAP trends.
+fig = plt.figure(figsize=(10, 3)) +ax = plt.subplot(1, 2, 1) +plt.contourf(Delta_ * 1e3, RTXP_, D_grid, colors=D_colors, levels=D_levels, + alpha=.5) + +plot_mean_with_std(ax, taus * 1e3, np.sqrt(rtaps[0]), 'r', '-', + label='RTAP$^{1/2}$ Test') +plot_mean_with_std(ax, taus * 1e3, np.sqrt(rtaps[1]), 'g', '-', + label='RTAP$^{1/2}$ Retest') +ax.legend(fontsize=13) +ax.text(.0091 * 1e3, 162, 'D=3e-4', fontsize=12, rotation=-22) +ax.text(.0091 * 1e3, 140, 'D=4e-4', fontsize=12, rotation=-20) +ax.text(.0091 * 1e3, 113, 'D=6e-4', fontsize=12, rotation=-16) +ax.set_ylim(54, 170) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_title(r'Test-Retest RTAP($\tau$) Subject 1', fontsize=15) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_ylabel('RTAP$^{1/2}$ (1/mm)', fontsize=17) + +ax = plt.subplot(1, 2, 2) +plt.contourf(Delta_ * 1e3, RTXP_, D_grid, colors=D_colors, levels=D_levels, + alpha=.5) +cb = plt.colorbar() +cb.set_label('Free Diffusivity ($mm^2/s$)', fontsize=18) + +plot_mean_with_std(ax, taus * 1e3, np.sqrt(rtaps[2]), 'r', '-') +plot_mean_with_std(ax, taus * 1e3, np.sqrt(rtaps[3]), 'g', '-') +ax.set_ylim(54, 170) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_title(r'Test-Retest RTAP($\tau$) Subject 2', fontsize=15) +plt.savefig('qt_indices_rtap.png') + + +# Finally the last one for RTPP. +fig = plt.figure(figsize=(10, 3)) +ax = plt.subplot(1, 2, 1) +plt.contourf(Delta_ * 1e3, RTXP_, D_grid, colors=D_colors, levels=D_levels, + alpha=.5) + +plot_mean_with_std(ax, taus * 1e3, rtpps[0], 'r', ':', label='RTPP Test') +plot_mean_with_std(ax, taus * 1e3, rtpps[1], 'g', ':', label='RTPP Retest') +ax.legend(fontsize=13) +ax.text(.0091 * 1e3, 113, 'D=6e-4', fontsize=12, rotation=-16) +ax.text(.0091 * 1e3, 91, 'D=9e-4', fontsize=12, rotation=-13) +ax.text(.0091 * 1e3, 69, 'D=15e-4', fontsize=12, rotation=-10) +ax.set_ylim(54, 170) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_title(r'Test-Retest RTPP($\tau$) Subject 1', fontsize=15) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_ylabel('RTPP (1/mm)', fontsize=17) + +ax = plt.subplot(1, 2, 2) +plt.contourf(Delta_ * 1e3, RTXP_, D_grid, colors=D_colors, levels=D_levels, + alpha=.5) +cb = plt.colorbar() +cb.set_label('Free Diffusivity ($mm^2/s$)', fontsize=18) + +plot_mean_with_std(ax, taus * 1e3, rtpps[2], 'r', ':') +plot_mean_with_std(ax, taus * 1e3, rtpps[3], 'g', ':') +ax.set_ylim(54, 170) +ax.set_xlim(.009 * 1e3, 0.0185 * 1e3) +ax.set_xlabel('Diffusion Time (ms)', fontsize=17) +ax.set_title(r'Test-Retest RTPP($\tau$) Subject 2', fontsize=15) +plt.savefig('qt_indices_rtpp.png') +""" +.. figure:: qt_indices_rtap.png + :align: center +.. figure:: qt_indices_rtpp.png + :align: center + +Like those of RTOP, the trends in RTAP and RTPP also decrease over time. It can +be seen that RTAP$^{1/2}$ is always bigger than RTPP, which makes sense as +particles in coherent white matter experience more restriction perpendicular to +the white matter orientation than parallel to it. Again, in both subjects the +test-retest RTAP and RTPP are nearly perfectly consistent. + +Aside from the estimation of q$\tau$-space indices, q$\tau$-dMRI also allows +for the estimation of time-dependent ODFs. Once the Qtdmri model is fitted, +they can be computed simply by calling qtdmri_fit.odf(sphere, +s=sharpening_factor). This is identical to how the mapmri module functions, +and makes it possible to study the time-dependence of ODF directionality. + +This concludes the example on qt-dMRI.
As we showed, approaches such as qt-dMRI +can help in studying the (finite-$\tau$) temporal properties of diffusion in +biological tissues. Differences in q$\tau$-index trends could be indicative +of underlying structural differences that affect the time-dependence of the +diffusion process. + +.. [Fick2017] Fick, Rutger HJ, et al. "Non-Parametric GraphNet-Regularized + Representation of dMRI in Space and Time", Medical Image Analysis, + 2017. +.. [Wassermann2017] Wassermann, Demian, et al. "Test-Retest qt-dMRI datasets + for 'Non-Parametric GraphNet-Regularized Representation of dMRI in + Space and Time' [Data set]". Zenodo. + http://doi.org/10.5281/zenodo.996889, 2017. +""" diff --git a/doc/examples/register_binary_fuzzy.py b/doc/examples/register_binary_fuzzy.py new file mode 100644 index 0000000000..758285f4a3 --- /dev/null +++ b/doc/examples/register_binary_fuzzy.py @@ -0,0 +1,171 @@ +""" +======================================================= +Diffeomorphic Registration with binary and fuzzy images +======================================================= + +This example demonstrates registration of a binary and a fuzzy image. +This could be seen as aligning a fuzzy (sensed) image to a binary +(e.g., template) image. +""" + +import numpy as np +import matplotlib.pyplot as plt +from skimage import draw, filters +from dipy.align.imwarp import SymmetricDiffeomorphicRegistration +from dipy.align.metrics import SSDMetric +from dipy.viz import regtools + +""" +Let's generate a sample template image as the combination of three ellipses. +We will generate the fuzzy (sensed) version of the image by smoothing +the reference image. +""" + + +def draw_ellipse(img, center, axis): + rr, cc = draw.ellipse(center[0], center[1], axis[0], axis[1], + shape=img.shape) + img[rr, cc] = 1 + return img + + +img_ref = np.zeros((64, 64)) +img_ref = draw_ellipse(img_ref, (25, 15), (10, 5)) +img_ref = draw_ellipse(img_ref, (20, 45), (15, 10)) +img_ref = draw_ellipse(img_ref, (50, 40), (7, 15)) + +img_in = filters.gaussian(img_ref, sigma=3) + +""" +Let's define a small visualization function. +""" + + +def show_images(img_ref, img_warp, fig_name): + fig, axarr = plt.subplots(ncols=2, figsize=(12, 5)) + axarr[0].set_title('warped image & reference contour') + axarr[0].imshow(img_warp) + axarr[0].contour(img_ref, colors='r') + ssd = np.sum((img_warp - img_ref) ** 2) + axarr[1].set_title('difference, SSD=%.02f' % ssd) + im = axarr[1].imshow(img_warp - img_ref) + plt.colorbar(im) + fig.tight_layout() + fig.savefig(fig_name + '.png') + + +show_images(img_ref, img_in, 'input') + +""" +.. figure:: input.png + :align: center + + Input images before alignment. +""" + +""" +Let's use the general registration function with some naive parameters, such +as setting `step_length` to 1 (assuming a maximal step of 1 pixel) and a +reasonably small number of iterations, since the deformation for already +aligned images should be minimal. +""" + +sdr = SymmetricDiffeomorphicRegistration(metric=SSDMetric(img_ref.ndim), + step_length=1.0, + level_iters=[50, 100], + inv_iter=50, + ss_sigma_factor=0.1, + opt_tol=1.e-3) + +""" +Perform the registration with equal images. +""" + +mapping = sdr.optimize(img_ref.astype(float), img_ref.astype(float)) +img_warp = mapping.transform(img_ref, 'linear') +show_images(img_ref, img_warp, 'output-0') +regtools.plot_2d_diffeomorphic_map(mapping, 5, 'map-0.png') + +""" +.. figure:: output-0.png + :align: center +..
figure:: map-0.png + :align: center + + Registration results for default parameters and equal images. +""" + +""" +Perform the registration with binary and fuzzy images. +""" + +mapping = sdr.optimize(img_ref.astype(float), img_in.astype(float)) +img_warp = mapping.transform(img_in, 'linear') +show_images(img_ref, img_warp, 'output-1') +regtools.plot_2d_diffeomorphic_map(mapping, 5, 'map-1.png') + +""" +.. figure:: output-1.png + :align: center +.. figure:: map-1.png + :align: center + + Registration results for a naive parameter configuration. +""" + +""" +Note that we are still using a multi-scale approach, which makes `step_length` +multiplicatively larger at the upper levels. +What happens if we set `step_length` to a rather small value? +""" + +sdr.step_length = 0.1 + +""" +Perform the registration and examine the output. +""" + +mapping = sdr.optimize(img_ref.astype(float), img_in.astype(float)) +img_warp = mapping.transform(img_in, 'linear') +show_images(img_ref, img_warp, 'output-2') +regtools.plot_2d_diffeomorphic_map(mapping, 5, 'map-2.png') + +""" +.. figure:: output-2.png + :align: center +.. figure:: map-2.png + :align: center + + Registration results for decreased step size. +""" + +""" +An alternative scenario is to use just a single scale level. +Even though the warped image may look fine, the estimated deformations show +that it is off the mark. +""" + +sdr = SymmetricDiffeomorphicRegistration(metric=SSDMetric(img_ref.ndim), + step_length=1.0, + level_iters=[100], + inv_iter=50, + ss_sigma_factor=0.1, + opt_tol=1.e-3) + +""" +Perform the registration. +""" + +mapping = sdr.optimize(img_ref.astype(float), img_in.astype(float)) +img_warp = mapping.transform(img_in, 'linear') +show_images(img_ref, img_warp, 'output-3') +regtools.plot_2d_diffeomorphic_map(mapping, 5, 'map-3.png') + +""" +.. figure:: output-3.png + :align: center +.. figure:: map-3.png + :align: center + + Registration results for a single level. +""" diff --git a/doc/examples/reslice_datasets.py b/doc/examples/reslice_datasets.py index 50e0f5dc3c..f19217bab9 100644 --- a/doc/examples/reslice_datasets.py +++ b/doc/examples/reslice_datasets.py @@ -20,14 +20,14 @@ """ from dipy.align.reslice import reslice -from dipy.data import get_data +from dipy.data import get_fnames """ We use here a very small dataset to show the basic principles but you can replace the following line with the path of your image.
""" -fimg = get_data('aniso_vox') +fimg = get_fnames('aniso_vox') """ We load the image and print the shape of the volume diff --git a/doc/examples/segment_clustering_features.py b/doc/examples/segment_clustering_features.py index b7bcd7e807..5821b8a627 100644 --- a/doc/examples/segment_clustering_features.py +++ b/doc/examples/segment_clustering_features.py @@ -22,9 +22,9 @@ def get_streamlines(): from nibabel import trackvis as tv - from dipy.data import get_data + from dipy.data import get_fnames - fname = get_data('fornix') + fname = get_fnames('fornix') streams, hdr = tv.read(fname) streamlines = [i[0] for i in streams] return streamlines diff --git a/doc/examples/segment_clustering_metrics.py b/doc/examples/segment_clustering_metrics.py index 780699aeeb..b771eb2a09 100644 --- a/doc/examples/segment_clustering_metrics.py +++ b/doc/examples/segment_clustering_metrics.py @@ -22,9 +22,9 @@ def get_streamlines(): from nibabel import trackvis as tv - from dipy.data import get_data + from dipy.data import get_fnames - fname = get_data('fornix') + fname = get_fnames('fornix') streams, hdr = tv.read(fname) streamlines = [i[0] for i in streams] return streamlines diff --git a/doc/examples/segment_extending_clustering_framework.py b/doc/examples/segment_extending_clustering_framework.py index f06fa1f8ae..50f466b832 100644 --- a/doc/examples/segment_extending_clustering_framework.py +++ b/doc/examples/segment_extending_clustering_framework.py @@ -103,10 +103,10 @@ def extract(self, streamline): import numpy as np from nibabel import trackvis as tv -from dipy.data import get_data +from dipy.data import get_fnames from dipy.viz import window, actor -fname = get_data('fornix') +fname = get_fnames('fornix') streams, hdr = tv.read(fname) streamlines = [i[0] for i in streams] @@ -210,10 +210,10 @@ def dist(self, v1, v2): import numpy as np from nibabel import trackvis as tv -from dipy.data import get_data +from dipy.data import get_fnames from dipy.viz import window, actor -fname = get_data('fornix') +fname = get_fnames('fornix') streams, hdr = tv.read(fname) streamlines = [i[0] for i in streams] diff --git a/doc/examples/segment_quickbundles.py b/doc/examples/segment_quickbundles.py index c7682fe57d..e2ced6aede 100644 --- a/doc/examples/segment_quickbundles.py +++ b/doc/examples/segment_quickbundles.py @@ -14,7 +14,7 @@ from dipy.tracking.streamline import Streamlines from dipy.segment.clustering import QuickBundles from dipy.io.pickles import save_pickle -from dipy.data import get_data +from dipy.data import get_fnames from dipy.viz import window, actor """ @@ -22,7 +22,7 @@ from neuroanatomy as the fornix. """ -fname = get_data('fornix') +fname = get_fnames('fornix') """ Load fornix streamlines. 
diff --git a/doc/examples/sfm_tracking.py b/doc/examples/sfm_tracking.py index de06c7a876..018aff517e 100644 --- a/doc/examples/sfm_tracking.py +++ b/doc/examples/sfm_tracking.py @@ -107,15 +107,14 @@ subject's T1-weighted anatomy: """ -from dipy.viz import window, actor -from dipy.viz.colormap import line_colors +from dipy.viz import window, actor, colormap as cmap from dipy.data import read_stanford_t1 from dipy.tracking.utils import move_streamlines from numpy.linalg import inv t1 = read_stanford_t1() t1_data = t1.get_data() t1_aff = t1.affine -color = line_colors(streamlines) +color = cmap.line_colors(streamlines) # Enables/disables interactive visualization interactive = False @@ -132,7 +131,7 @@ streamlines_actor = actor.streamtube( list(move_streamlines(plot_streamlines, inv(t1_aff))), - line_colors(streamlines), linewidth=0.1) + cmap.line_colors(streamlines), linewidth=0.1) vol_actor = actor.slicer(t1_data) diff --git a/doc/examples/simulate_dki.py b/doc/examples/simulate_dki.py index c2c98f121c..435090d0f1 100644 --- a/doc/examples/simulate_dki.py +++ b/doc/examples/simulate_dki.py @@ -20,7 +20,7 @@ import numpy as np import matplotlib.pyplot as plt from dipy.sims.voxel import (multi_tensor_dki, single_tensor) -from dipy.data import get_data +from dipy.data import get_fnames from dipy.io.gradients import read_bvals_bvecs from dipy.core.gradients import gradient_table from dipy.reconst.dti import (decompose_tensor, from_lower_triangular) @@ -31,7 +31,7 @@ ``small_64D``. """ -fimg, fbvals, fbvecs = get_data('small_64D') +fimg, fbvals, fbvecs = get_fnames('small_64D') bvals, bvecs = read_bvals_bvecs(fbvals, fbvecs) """ diff --git a/doc/examples/simulate_multi_tensor.py b/doc/examples/simulate_multi_tensor.py index 5070cea270..5a217fc543 100644 --- a/doc/examples/simulate_multi_tensor.py +++ b/doc/examples/simulate_multi_tensor.py @@ -8,10 +8,7 @@ """ import numpy as np -from dipy.sims.voxel import (multi_tensor, - multi_tensor_odf, - single_tensor_odf, - all_tensor_evecs) +from dipy.sims.voxel import multi_tensor, multi_tensor_odf from dipy.data import get_sphere """ diff --git a/doc/examples/streamline_formats.py b/doc/examples/streamline_formats.py index c758374ac9..0af19d7e0f 100644 --- a/doc/examples/streamline_formats.py +++ b/doc/examples/streamline_formats.py @@ -8,14 +8,15 @@ ======== DIPY_ can read and write many different file formats. In this example -we give a short introduction on how to use it for loading or saving streamlines. +we give a short introduction on how to use it for loading or saving +streamlines. Read :ref:`faq` """ import numpy as np -from dipy.data import get_data +from dipy.data import get_fnames from dipy.io.streamline import load_trk, save_trk from dipy.tracking.streamline import Streamlines @@ -23,7 +24,7 @@ 1. Read/write streamline files with DIPY. """ -fname = get_data('fornix') +fname = get_fnames('fornix') print(fname) # Read Streamlines @@ -35,8 +36,9 @@ """ -2. We also work on our HDF5 based file format which can read/write massive datasets - (as big as the size of you free disk space). With `Dpy` we can support +2. We also work on our HDF5 based file format which can read/write massive + datasets (as big as the size of your free disk space). 
With `Dpy` we can + support * direct indexing from the disk * memory usage always low diff --git a/doc/examples/streamline_tools.py b/doc/examples/streamline_tools.py index a728e69c45..1a0a9b002b 100644 --- a/doc/examples/streamline_tools.py +++ b/doc/examples/streamline_tools.py @@ -96,15 +96,15 @@ region near the center of the axial image. """ -from dipy.viz import window, actor -from dipy.viz.colormap import line_colors +from dipy.viz import window, actor, colormap as cmap # Enables/disables interactive visualization interactive = False # Make display objects -color = line_colors(cc_streamlines) -cc_streamlines_actor = actor.line(cc_streamlines, line_colors(cc_streamlines)) +color = cmap.line_colors(cc_streamlines) +cc_streamlines_actor = actor.line(cc_streamlines, + cmap.line_colors(cc_streamlines)) cc_ROI_actor = actor.contour_from_roi(cc_slice, color=(1., 1., 0.), opacity=0.5) diff --git a/doc/examples/syn_registration_2d.py b/doc/examples/syn_registration_2d.py index dabe98017c..717b5cd7d5 100644 --- a/doc/examples/syn_registration_2d.py +++ b/doc/examples/syn_registration_2d.py @@ -10,15 +10,15 @@ """ import numpy as np -from dipy.data import get_data +from dipy.data import get_fnames from dipy.align.imwarp import SymmetricDiffeomorphicRegistration from dipy.align.metrics import SSDMetric, CCMetric, EMMetric import dipy.align.imwarp as imwarp from dipy.viz import regtools -fname_moving = get_data('reg_o') -fname_static = get_data('reg_c') +fname_moving = get_fnames('reg_o') +fname_static = get_fnames('reg_c') moving = np.load(fname_moving) static = np.load(fname_static) @@ -143,7 +143,7 @@ def callback_CC(sdr, status): fetch_syn_data() t1, b0 = read_syn_data() -data = np.array(b0.get_data(), dtype = np.float64) +data = np.array(b0.get_data(), dtype=np.float64) """ We first remove the skull from the b0 volume diff --git a/doc/examples/tracking_bootstrap_peaks.py b/doc/examples/tracking_bootstrap_peaks.py index 94d017553e..09dd3a7f1e 100644 --- a/doc/examples/tracking_bootstrap_peaks.py +++ b/doc/examples/tracking_bootstrap_peaks.py @@ -5,8 +5,9 @@ This example shows how choices in direction-getter impact fiber tracking results by demonstrating the bootstrap direction getter (a type of -probabilistic tracking) and the closest peak direction getter (a type of -deterministic tracking). +probabilistic tracking, as described in [Berman2008]_) and the closest peak +direction getter (a type of deterministic tracking). +(Amirbekian, PhD thesis, 2016) Let's load the necessary modules for executing this tutorial. 
""" @@ -15,8 +16,7 @@ from dipy.tracking import utils from dipy.tracking.local import (ThresholdTissueClassifier, LocalTracking) from dipy.io.trackvis import save_trk -from dipy.viz import window, actor -from dipy.viz.colormap import line_colors +from dipy.viz import window, actor, colormap as cmap renderer = window.Renderer() @@ -34,7 +34,7 @@ hardi_img, gtab, labels_img = read_stanford_labels() data = hardi_img.get_data() labels = labels_img.get_data() -affine = hardi_img.get_affine() +affine = hardi_img.affine seed_mask = labels == 2 white_matter = (labels == 1) | (labels == 2) @@ -77,7 +77,7 @@ streamlines = Streamlines(boot_streamline_generator) renderer.clear() -renderer.add(actor.line(streamlines, line_colors(streamlines))) +renderer.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(renderer, out_path='bootstrap_dg_CSD.png', size=(600, 600)) """ @@ -108,7 +108,7 @@ streamlines = Streamlines(peak_streamline_generator) renderer.clear() -renderer.add(actor.line(streamlines, line_colors(streamlines))) +renderer.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(renderer, out_path='closest_peak_dg_CSD.png', size=(600, 600)) """ @@ -127,11 +127,9 @@ save_trk("closest_peak_dg_CSD.trk", streamlines, affine, labels.shape) """ -.. [Berman_boot] Berman, J. et al. Probabilistic streamline q-ball -tractography using the residual bootstrap - -.. [Jeurissen_boot] Jeurissen, B. et al. Probabilistic fiber tracking -using the residual bootstrap with constrained spherical deconvolution. +.. [Berman2008] Berman, J. et al., Probabilistic streamline q-ball +tractography using the residual bootstrap, NeuroImage, vol 39, no 1, 2008 +.. include:: ../links_names.inc """ diff --git a/doc/examples/tracking_eudx_odf.py b/doc/examples/tracking_eudx_odf.py index 43d5039678..33e0b46d44 100644 --- a/doc/examples/tracking_eudx_odf.py +++ b/doc/examples/tracking_eudx_odf.py @@ -13,9 +13,9 @@ In this example we do deterministic fiber tracking on fields of ODF peaks. EuDX [Garyfallidis12]_ will be used for this. -This example requires importing example `reconst_csa.py` in order to run. EuDX was -primarily made with cpu efficiency in mind. The main idea can be used with any -model that is a child of OdfModel. +This example requires importing example `reconst_csa.py` in order to run. EuDX +was primarily made with cpu efficiency in mind. The main idea can be used with +any model that is a child of OdfModel. """ @@ -23,9 +23,9 @@ import numpy as np """ -This time we will not use FA as input to EuDX but we will use GFA (generalized FA), -which is more suited for ODF functions. Tracking will stop when GFA is less -than 0.2. +This time we will not use FA as input to EuDX but we will use GFA +(generalized FA), which is more suited for ODF functions. Tracking will stop +when GFA is less than 0.2. """ from dipy.tracking.eudx import EuDX @@ -62,15 +62,14 @@ Visualize the streamlines with `dipy.viz` module (python vtk is required). """ -from dipy.viz import window, actor -from dipy.viz.colormap import line_colors +from dipy.viz import window, actor, colormap as cmap # Enables/disables interactive visualization interactive = False ren = window.Renderer() -ren.add(actor.line(csa_streamlines, line_colors(csa_streamlines))) +ren.add(actor.line(csa_streamlines, cmap.line_colors(csa_streamlines))) print('Saving illustration as tensor_tracks.png') @@ -84,8 +83,8 @@ Deterministic streamlines with EuDX on ODF peaks field modulated by GFA. 
-It is also possible to use EuDX with multiple ODF peaks, which is very helpful when -tracking in crossing areas. +It is also possible to use EuDX with multiple ODF peaks, which is very helpful +when tracking in crossing areas. """ eu = EuDX(csapeaks.peak_values, @@ -99,7 +98,8 @@ window.clear(ren) -ren.add(actor.line(csa_streamlines_mult_peaks, line_colors(csa_streamlines_mult_peaks))) +ren.add(actor.line(csa_streamlines_mult_peaks, + cmap.line_colors(csa_streamlines_mult_peaks))) print('Saving illustration as csa_tracking_mpeaks.png') diff --git a/doc/examples/tracking_eudx_tensor.py b/doc/examples/tracking_eudx_tensor.py index bdc31de671..af37533e55 100644 --- a/doc/examples/tracking_eudx_tensor.py +++ b/doc/examples/tracking_eudx_tensor.py @@ -71,7 +71,8 @@ from dipy.tracking.eudx import EuDX from dipy.tracking.streamline import Streamlines -eu = EuDX(FA.astype('f8'), peak_indices, seeds=50000, odf_vertices = sphere.vertices, a_low=0.2) +eu = EuDX(FA.astype('f8'), peak_indices, seeds=50000, + odf_vertices=sphere.vertices, a_low=0.2) tensor_streamlines = Streamlines(eu) @@ -101,7 +102,7 @@ try: from dipy.viz import window, actor except ImportError: - raise ImportError('Python vtk module is not installed') + raise ImportError('Python fury module is not installed') import sys sys.exit() @@ -115,14 +116,15 @@ Every streamline will be coloured according to its orientation """ -from dipy.viz.colormap import line_colors +from dipy.viz import colormap as cmap """ `actor.line` creates a streamline actor for streamline visualization and `ren.add` adds this actor to the scene """ -ren.add(actor.streamtube(tensor_streamlines, line_colors(tensor_streamlines))) +ren.add(actor.streamtube(tensor_streamlines, + cmap.line_colors(tensor_streamlines))) print('Saving illustration as tensor_tracks.png') diff --git a/doc/examples/tracking_tissue_classifier.py b/doc/examples/tracking_tissue_classifier.py index d87393a3e9..4b62c7d3a5 100644 --- a/doc/examples/tracking_tissue_classifier.py +++ b/doc/examples/tracking_tissue_classifier.py @@ -35,8 +35,7 @@ from dipy.tracking.local import LocalTracking from dipy.tracking.streamline import Streamlines from dipy.tracking import utils -from dipy.viz import window, actor -from dipy.viz.colormap import line_colors +from dipy.viz import window, actor, colormap as cmap, have_fury # Enables/disables interactive visualization interactive = False @@ -127,9 +126,9 @@ streamlines = Streamlines(all_streamlines_threshold_classifier) -if window.have_vtk: +if have_fury: window.clear(ren) - ren.add(actor.line(streamlines, line_colors(streamlines))) + ren.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(ren, out_path='all_streamlines_threshold_classifier.png', size=(600, 600)) if interactive: @@ -172,8 +171,9 @@ plt.xticks([]) plt.yticks([]) fig.tight_layout() -plt.imshow(white_matter[:, :, data.shape[2] // 2].T, cmap='gray', origin='lower', - interpolation='nearest') +plt.imshow(white_matter[:, :, data.shape[2] // 2].T, cmap='gray', + origin='lower', interpolation='nearest') + fig.savefig('white_matter_mask.png') """ @@ -197,9 +197,9 @@ streamlines = Streamlines(all_streamlines_binary_classifier) -if window.have_vtk: +if have_fury: window.clear(ren) - ren.add(actor.line(streamlines, line_colors(streamlines))) + ren.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(ren, out_path='all_streamlines_binary_classifier.png', size=(600, 600)) if interactive: @@ -218,9 +218,9 @@ Anatomically-constrained tractography (ACT) [Smith2012]_ uses 
information from anatomical images to determine when the tractography stops. The ``include_map`` defines when the streamline reached a 'valid' stopping region (e.g. gray -matter partial volume estimation (PVE) map) and the ``exclude_map`` defines when -the streamline reached an 'invalid' stopping region (e.g. corticospinal fluid -PVE map). The background of the anatomical image should be added to the +matter partial volume estimation (PVE) map) and the ``exclude_map`` defines +when the streamline reached an 'invalid' stopping region (e.g. cerebrospinal +fluid PVE map). The background of the anatomical image should be added to the ``include_map`` to keep streamlines exiting the brain (e.g. through the brain stem). The ACT tissue classifier uses a trilinear interpolation at the tracking position. @@ -257,13 +257,15 @@ plt.subplot(121) plt.xticks([]) plt.yticks([]) -plt.imshow(include_map[:, :, data.shape[2] // 2].T, cmap='gray', origin='lower', - interpolation='nearest') +plt.imshow(include_map[:, :, data.shape[2] // 2].T, cmap='gray', + origin='lower', interpolation='nearest') + plt.subplot(122) plt.xticks([]) plt.yticks([]) -plt.imshow(exclude_map[:, :, data.shape[2] // 2].T, cmap='gray', origin='lower', - interpolation='nearest') +plt.imshow(exclude_map[:, :, data.shape[2] // 2].T, cmap='gray', + origin='lower', interpolation='nearest') + fig.tight_layout() fig.savefig('act_maps.png') @@ -288,9 +290,9 @@ streamlines = Streamlines(all_streamlines_act_classifier) -if window.have_vtk: +if have_fury: window.clear(ren) - ren.add(actor.line(streamlines, line_colors(streamlines))) + ren.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(ren, out_path='all_streamlines_act_classifier.png', size=(600, 600)) if interactive: @@ -317,9 +319,9 @@ streamlines = Streamlines(valid_streamlines_act_classifier) -if window.have_vtk: +if have_fury: window.clear(ren) - ren.add(actor.line(streamlines, line_colors(streamlines))) + ren.add(actor.line(streamlines, cmap.line_colors(streamlines))) window.record(ren, out_path='valid_streamlines_act_classifier.png', size=(600, 600)) if interactive: @@ -336,10 +338,10 @@ """ The threshold and binary tissue classifiers use respectively a scalar map and a binary mask to stop the tracking. The ACT tissue classifier uses partial volume -fraction (PVE) maps from an anatomical image to stop the tracking. Additionally, -the ACT tissue classifier determines if the tracking stopped in expected regions -(e.g. gray matter) and allows the user to get only streamlines stopping in those -regions. +estimation (PVE) maps from an anatomical image to stop the tracking. +Additionally, the ACT tissue classifier determines if the tracking stopped in +expected regions (e.g. gray matter) and allows the user to get only +streamlines stopping in those regions.
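To make the comparison above concrete, here is a minimal sketch (an editorial addition, not part of the original example) of how the three classifiers are constructed; it assumes ``gfa``, ``white_matter``, ``include_map`` and ``exclude_map`` are the arrays computed earlier in this example:

from dipy.tracking.local import (ThresholdTissueClassifier,
                                 BinaryTissueClassifier,
                                 ActTissueClassifier)

# Stop tracking wherever the scalar map (here GFA) drops below a threshold.
threshold_classifier = ThresholdTissueClassifier(gfa, .25)

# Stop tracking as soon as a streamline leaves the binary white-matter mask.
binary_classifier = BinaryTissueClassifier(white_matter)

# Stop in 'valid' regions (include_map); streamlines reaching 'invalid'
# regions (exclude_map) can be discarded afterwards.
act_classifier = ActTissueClassifier(include_map, exclude_map)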
Notes ------ diff --git a/doc/examples/valid_examples.txt b/doc/examples/valid_examples.txt index 0b62b9a485..851ee5ddaf 100644 --- a/doc/examples/valid_examples.txt +++ b/doc/examples/valid_examples.txt @@ -15,6 +15,7 @@ reconst_dsid.py reconst_ivim.py reconst_mapmri.py + reconst_qtdmri.py kfold_xval.py reslice_datasets.py segment_quickbundles.py @@ -61,3 +62,8 @@ viz_surfaces.py viz_roi_contour.py viz_ui.py + register_binary_fuzzy.py + bundle_extraction.py + viz_timers.py + cluster_confidence.py + path_length_map.py diff --git a/doc/examples/viz_advanced.py b/doc/examples/viz_advanced.py index d73b17c920..d039faefe5 100644 --- a/doc/examples/viz_advanced.py +++ b/doc/examples/viz_advanced.py @@ -6,9 +6,8 @@ In DIPY_ we created a thin interface to access many of the capabilities available in the Visualization Toolkit framework (VTK) but tailored to the needs of structural and diffusion imaging. Initially the 3D visualization -module was named ``fvtk``, meaning functions using vtk. This is still available -for backwards compatibility but now there is a more comprehensive way to access -the main functions using the following modules. +module was named ``fvtk``, meaning functions using vtk. This is not available +anymore. """ import numpy as np @@ -171,54 +170,32 @@ """ -def change_slice_z(i_ren, obj, slider): +def change_slice_z(slider): z = int(np.round(slider.value)) image_actor_z.display_extent(0, shape[0] - 1, 0, shape[1] - 1, z, z) -def change_slice_x(i_ren, obj, slider): +def change_slice_x(slider): x = int(np.round(slider.value)) image_actor_x.display_extent(x, x, 0, shape[1] - 1, 0, shape[2] - 1) -def change_slice_y(i_ren, obj, slider): +def change_slice_y(slider): y = int(np.round(slider.value)) image_actor_y.display_extent(0, shape[0] - 1, y, y, 0, shape[2] - 1) -def change_opacity(i_ren, obj, slider): +def change_opacity(slider): slicer_opacity = slider.value image_actor_z.opacity(slicer_opacity) image_actor_x.opacity(slicer_opacity) image_actor_y.opacity(slicer_opacity) -line_slider_z.add_callback(line_slider_z.slider_disk, - "MouseMoveEvent", - change_slice_z) -line_slider_z.add_callback(line_slider_z.slider_line, - "LeftButtonPressEvent", - change_slice_z) - -line_slider_x.add_callback(line_slider_x.slider_disk, - "MouseMoveEvent", - change_slice_x) -line_slider_x.add_callback(line_slider_x.slider_line, - "LeftButtonPressEvent", - change_slice_x) - -line_slider_y.add_callback(line_slider_y.slider_disk, - "MouseMoveEvent", - change_slice_y) -line_slider_y.add_callback(line_slider_y.slider_line, - "LeftButtonPressEvent", - change_slice_y) - -opacity_slider.add_callback(opacity_slider.slider_disk, - "MouseMoveEvent", - change_opacity) -opacity_slider.add_callback(opacity_slider.slider_line, - "LeftButtonPressEvent", - change_opacity) + +line_slider_z.on_change = change_slice_z +line_slider_x.on_change = change_slice_x +line_slider_y.on_change = change_slice_y +opacity_slider.on_change = change_opacity """ We'll also create text labels to identify the sliders. 
""" @@ -233,8 +210,7 @@ def build_label(text): label.bold = False label.italic = False label.shadow = False - label.actor.GetTextProperty().SetBackgroundColor(0, 0, 0) - label.actor.GetTextProperty().SetBackgroundOpacity(0.0) + label.background = (0, 0, 0) label.color = (1, 1, 1) return label @@ -250,20 +226,20 @@ def build_label(text): """ -panel = ui.Panel2D(center=(1030, 120), - size=(300, 200), +panel = ui.Panel2D(size=(300, 200), color=(1, 1, 1), opacity=0.1, align="right") - -panel.add_element(line_slider_label_x, 'relative', (0.1, 0.75)) -panel.add_element(line_slider_x, 'relative', (0.65, 0.8)) -panel.add_element(line_slider_label_y, 'relative', (0.1, 0.55)) -panel.add_element(line_slider_y, 'relative', (0.65, 0.6)) -panel.add_element(line_slider_label_z, 'relative', (0.1, 0.35)) -panel.add_element(line_slider_z, 'relative', (0.65, 0.4)) -panel.add_element(opacity_slider_label, 'relative', (0.1, 0.15)) -panel.add_element(opacity_slider, 'relative', (0.65, 0.2)) +panel.center = (1030, 120) + +panel.add_element(line_slider_label_x, (0.1, 0.75)) +panel.add_element(line_slider_x, (0.38, 0.75)) +panel.add_element(line_slider_label_y, (0.1, 0.55)) +panel.add_element(line_slider_y, (0.38, 0.55)) +panel.add_element(line_slider_label_z, (0.1, 0.35)) +panel.add_element(line_slider_z, (0.38, 0.35)) +panel.add_element(opacity_slider_label, (0.1, 0.15)) +panel.add_element(opacity_slider, (0.38, 0.15)) show_m.ren.add(panel) diff --git a/doc/examples/viz_bundles.py b/doc/examples/viz_bundles.py index ee35c0e0c2..1a213c3d03 100644 --- a/doc/examples/viz_bundles.py +++ b/doc/examples/viz_bundles.py @@ -113,8 +113,8 @@ renderer.clear() -hue = [0.0, 0.0] # red only -saturation = [0.0, 1.0] # white to red +hue = (0.0, 0.0) # red only +saturation = (0.0, 1.0) # white to red lut_cmap = actor.colormap_lookup_table(hue_range=hue, saturation_range=saturation) @@ -170,8 +170,8 @@ lengths = length(bundle_native) -hue = [0.5, 0.5] # red only -saturation = [0.0, 1.0] # black to white +hue = (0.5, 0.5) # blue only +saturation = (0.0, 1.0) # black to white lut_cmap = actor.colormap_lookup_table( scale_range=(lengths.min(), lengths.max()), diff --git a/doc/examples/viz_roi_contour.py b/doc/examples/viz_roi_contour.py index 1554f3be95..9cac7c3459 100644 --- a/doc/examples/viz_roi_contour.py +++ b/doc/examples/viz_roi_contour.py @@ -17,8 +17,7 @@ from dipy.tracking import utils from dipy.tracking.local import LocalTracking from dipy.tracking.streamline import Streamlines -from dipy.viz import actor, window -from dipy.viz.colormap import line_colors +from dipy.viz import actor, window, colormap as cmap """ First, we need to generate some streamlines. For a more complete @@ -55,7 +54,7 @@ We will create a streamline actor from the streamlines. """ -streamlines_actor = actor.line(streamlines, line_colors(streamlines)) +streamlines_actor = actor.line(streamlines, cmap.line_colors(streamlines)) """ Next, we create a surface actor from the corpus callosum seed ROI. 
We diff --git a/doc/examples/viz_slice.py b/doc/examples/viz_slice.py index 4f6d32310c..42e578d45d 100644 --- a/doc/examples/viz_slice.py +++ b/doc/examples/viz_slice.py @@ -171,17 +171,17 @@ result_position = ui.TextBlock2D(text='') result_value = ui.TextBlock2D(text='') -panel_picking = ui.Panel2D(center=(200, 120), - size=(250, 125), +panel_picking = ui.Panel2D(size=(250, 125), + position=(20, 20), color=(0, 0, 0), opacity=0.75, align="left") -panel_picking.add_element(label_position, 'relative', (0.1, 0.55)) -panel_picking.add_element(label_value, 'relative', (0.1, 0.25)) +panel_picking.add_element(label_position, (0.1, 0.55)) +panel_picking.add_element(label_value, (0.1, 0.25)) -panel_picking.add_element(result_position, 'relative', (0.45, 0.55)) -panel_picking.add_element(result_value, 'relative', (0.45, 0.25)) +panel_picking.add_element(result_position, (0.45, 0.55)) +panel_picking.add_element(result_value, (0.45, 0.25)) show_m.ren.add(panel_picking) @@ -204,6 +204,7 @@ def left_click_callback(obj, ev): result_position.message = '({}, {}, {})'.format(str(i), str(j), str(k)) result_value.message = '%.8f' % data[i, j, k] + fa_actor.SetInterpolate(False) fa_actor.AddObserver('LeftButtonPressEvent', left_click_callback, 1.0) @@ -243,6 +244,7 @@ def left_click_callback_mosaic(obj, ev): result_position.message = '({}, {}, {})'.format(str(i), str(j), str(k)) result_value.message = '%.8f' % data[i, j, k] + """ Now we need to create two nested for loops which will set the positions of the grid of the mosaic and add the new actors to the renderer. We are going @@ -276,7 +278,7 @@ def left_click_callback_mosaic(obj, ev): break renderer.reset_camera() -renderer.zoom(1.6) +renderer.zoom(1.0) # show_m_mosaic.ren.add(panel_picking) # show_m_mosaic.start() diff --git a/doc/examples/viz_surfaces.py b/doc/examples/viz_surfaces.py index ece10613b5..8d1fe399c2 100644 --- a/doc/examples/viz_surfaces.py +++ b/doc/examples/viz_surfaces.py @@ -19,8 +19,7 @@ """ import dipy.io.vtk as io_vtk -import dipy.viz.utils as ut_vtk -from dipy.viz import window +from dipy.viz import window, utils as ut_vtk # Conditional import machinery for vtk # Allow import, but disable doctests if we don't have vtk @@ -57,7 +56,7 @@ [0, 4, 5], [0, 5, 1], [1, 5, 7], - [1, 7, 3]],dtype='i8') + [1, 7, 3]], dtype='i8') """ diff --git a/doc/examples/viz_timers.py b/doc/examples/viz_timers.py new file mode 100644 index 0000000000..36b13d7854 --- /dev/null +++ b/doc/examples/viz_timers.py @@ -0,0 +1,69 @@ +""" +=============== +Using a timer +=============== + +This example shows how to create a simple animation using a timer callback. + +We will use a sphere actor that generates many spheres of different colors, +radii and opacity. Then we will animate this actor by rotating and changing +global opacity levels from inside a user defined callback. + +The timer will call this user defined callback every 200 milliseconds. The +application will exit after the callback has been called 100 times. 
+""" + + +import numpy as np +from dipy.viz import window, actor, ui +import itertools + +xyz = 10 * np.random.rand(100, 3) +colors = np.random.rand(100, 4) +radii = np.random.rand(100) + 0.5 + +renderer = window.Renderer() + +sphere_actor = actor.sphere(centers=xyz, + colors=colors, + radii=radii) + +renderer.add(sphere_actor) + +showm = window.ShowManager(renderer, + size=(900, 768), reset_camera=False, + order_transparent=True) + +showm.initialize() + +tb = ui.TextBlock2D(bold=True) + +# use itertools to avoid global variables +counter = itertools.count() + + +def timer_callback(obj, event): + cnt = next(counter) + tb.message = "Let's count up to 100 and exit :" + str(cnt) + showm.ren.azimuth(0.05 * cnt) + sphere_actor.GetProperty().SetOpacity(cnt/100.) + showm.render() + if cnt == 100: + showm.exit() + + +renderer.add(tb) + +# Run every 200 milliseconds +showm.add_timer_callback(True, 200, timer_callback) + +showm.start() + +window.record(showm.ren, size=(900, 768), out_path="viz_timer.png") + +""" +.. figure:: viz_timer.png + :align: center + + **Showing 100 spheres of random radii and opacity levels**. +""" diff --git a/doc/examples/viz_ui.py b/doc/examples/viz_ui.py index bb697f8df5..0b1561039d 100644 --- a/doc/examples/viz_ui.py +++ b/doc/examples/viz_ui.py @@ -1,220 +1,279 @@ +# -*- coding: utf-8 -*- """ =============== User Interfaces =============== -This example shows how to use the UI API. -Currently includes button, textbox, panel, and line slider. +This example shows how to use the UI API. We will demonstrate how to create +several DIPY UI elements, then use a list box to toggle which element is shown. First, a bunch of imports. - """ import os -from dipy.data import read_viz_icons, fetch_viz_icons +from dipy.viz import read_viz_icons, fetch_viz_icons from dipy.viz import ui, window """ -3D Elements -=========== +Shapes +====== -Let's have some cubes in 3D. +Let's start by drawing some simple shapes. First, a rectangle. """ +rect = ui.Rectangle2D(size=(200, 200), position=(400, 300), color=(1, 0, 1)) -def cube_maker(color=None, size=(0.2, 0.2, 0.2), center=None): - cube = window.vtk.vtkCubeSource() - cube.SetXLength(size[0]) - cube.SetYLength(size[1]) - cube.SetZLength(size[2]) - if center is not None: - cube.SetCenter(*center) - cube_mapper = window.vtk.vtkPolyDataMapper() - cube_mapper.SetInputConnection(cube.GetOutputPort()) - cube_actor = window.vtk.vtkActor() - cube_actor.SetMapper(cube_mapper) - if color is not None: - cube_actor.GetProperty().SetColor(color) - return cube_actor +""" +Then we can draw a solid circle, or disk. +""" + +disk = ui.Disk2D(outer_radius=50, center=(500, 500), color=(1, 1, 0)) +""" +Add an inner radius to make a ring. +""" -cube_actor_1 = cube_maker((1, 0, 0), (50, 50, 50), center=(0, 0, 0)) -cube_actor_2 = cube_maker((0, 1, 0), (10, 10, 10), center=(100, 0, 0)) +ring = ui.Disk2D(outer_radius=50, inner_radius=45, center=(500, 300), + color=(0, 1, 1)) """ -Buttons -======= +Image +===== -We first fetch the icons required for making the buttons. +Now let's display an image. First we need to fetch some icons that are included +in DIPY. """ fetch_viz_icons() """ -Add the icon filenames to a dict. +Now we can create an image container. 
""" -icon_files = dict() -icon_files['stop'] = read_viz_icons(fname='stop2.png') -icon_files['play'] = read_viz_icons(fname='play3.png') -icon_files['plus'] = read_viz_icons(fname='plus.png') -icon_files['cross'] = read_viz_icons(fname='cross.png') +img = ui.ImageContainer2D(img_path=read_viz_icons(fname='home3.png'), + position=(450, 350)) """ -Create a button through our API. +Panel with buttons and text +=========================== + +Let's create some buttons and text and put them in a panel. First we'll +make the panel. """ -button_example = ui.Button2D(icon_fnames=icon_files) +panel = ui.Panel2D(size=(300, 150), color=(1, 1, 1), align="right") +panel.center = (500, 400) """ -We now add some click listeners. +Then we'll make two text labels and place them on the panel. +Note that we specifiy the position with integer numbers of pixels. """ +text = ui.TextBlock2D(text='Click me') +text2 = ui.TextBlock2D(text='Me too') +panel.add_element(text, (50, 100)) +panel.add_element(text2, (180, 100)) -def left_mouse_button_click(i_ren, obj, button): - print("Left Button Clicked") +""" +Then we'll create two buttons and add them to the panel. +Note that here we specify the positions with floats. In this case, these are +percentages of the panel size. +""" -def left_mouse_button_drag(i_ren, obj, button): - print ("Left Button Dragged") +button_example = ui.Button2D( + icon_fnames=[('square', read_viz_icons(fname='stop2.png'))]) +icon_files = [] +icon_files.append(('down', read_viz_icons(fname='circle-down.png'))) +icon_files.append(('left', read_viz_icons(fname='circle-left.png'))) +icon_files.append(('up', read_viz_icons(fname='circle-up.png'))) +icon_files.append(('right', read_viz_icons(fname='circle-right.png'))) + +second_button_example = ui.Button2D(icon_fnames=icon_files) -button_example.on_left_mouse_button_drag = left_mouse_button_drag -button_example.on_left_mouse_button_pressed = left_mouse_button_click +panel.add_element(button_example, (0.25, 0.33)) +panel.add_element(second_button_example, (0.66, 0.33)) +""" +We can add a callback to each button to perform some action. +""" -def right_mouse_button_drag(i_ren, obj, button): - print("Right Button Dragged") +def change_text_callback(i_ren, obj, button): + text.message = 'Clicked!' + i_ren.force_render() -def right_mouse_button_click(i_ren, obj, button): - print ("Right Button Clicked") +def change_icon_callback(i_ren, obj, button): + button.next_icon() + i_ren.force_render() -button_example.on_right_mouse_button_drag = right_mouse_button_drag -button_example.on_right_mouse_button_pressed = right_mouse_button_click +button_example.on_left_mouse_button_clicked = change_text_callback +second_button_example.on_left_mouse_button_pressed = change_icon_callback """ -Let's have another button. +Cube and sliders +================ + +Let's add a cube to the scene and control it with sliders. 
""" -second_button_example = ui.Button2D(icon_fnames=icon_files) + +def cube_maker(color=(1, 1, 1), size=(0.2, 0.2, 0.2), center=(0, 0, 0)): + cube = window.vtk.vtkCubeSource() + cube.SetXLength(size[0]) + cube.SetYLength(size[1]) + cube.SetZLength(size[2]) + if center is not None: + cube.SetCenter(*center) + cube_mapper = window.vtk.vtkPolyDataMapper() + cube_mapper.SetInputConnection(cube.GetOutputPort()) + cube_actor = window.vtk.vtkActor() + cube_actor.SetMapper(cube_mapper) + if color is not None: + cube_actor.GetProperty().SetColor(color) + return cube_actor + +cube = cube_maker(color=(0, 0, 1), size=(20, 20, 20), center=(15, 0, 0)) """ -This time, we will call the built in `next_icon` method -via a callback that is triggered on left click. +Now we'll add two sliders: one circular and one linear. """ +ring_slider = ui.RingSlider2D(center=(740, 400), initial_value=0, + text_template="{angle:5.1f}°") -def modify_button_callback(i_ren, obj, button): - button.next_icon() - i_ren.force_render() +line_slider = ui.LineSlider2D(center=(500, 250), initial_value=0, + min_value=-10, max_value=10) +""" +We can use a callback to rotate the cube with the ring slider. +""" -second_button_example.on_left_mouse_button_pressed = modify_button_callback -""" -Panels -====== +def rotate_cube(slider): + angle = slider.value + previous_angle = slider.previous_value + rotation_angle = angle - previous_angle + cube.RotateX(rotation_angle) -Simply create a panel and add elements to it. +ring_slider.on_change = rotate_cube + +""" +Similarly, we can translate the cube with the line slider. """ -panel = ui.Panel2D(center=(440, 90), size=(300, 150), color=(1, 1, 1), - align="right") -panel.add_element(button_example, 'relative', (0.2, 0.2)) -panel.add_element(second_button_example, 'absolute', (480, 100)) + +def translate_cube(slider): + value = slider.value + cube.SetPosition(value, 0, 0) + +line_slider.on_change = translate_cube """ -TextBox -======= +Range Slider +============ + +Finally, we can add a range slider. This element is composed of two sliders. +The first slider has two handles which let you set the range of the second. """ -text = ui.TextBox2D(height=3, width=10) +range_slider = ui.RangeSlider( + line_width=8, handle_side=25, range_slider_center=(550, 450), + value_slider_center=(550, 350), length=250, min_value=0, + max_value=10, font_size=18, range_precision=2, value_precision=4, + shape="square") + """ -2D Line Slider -============== +Select menu +============ + +We just added many examples. If we showed them all at once, they would fill the +screen. Let's make a simple menu to choose which example is shown. + +We'll first make a list of the examples. """ +examples = [[rect], [disk, ring], [img], [panel], + [ring_slider, line_slider], [range_slider]] -def translate_green_cube(i_ren, obj, slider): - value = slider.value - cube_actor_2.SetPosition(value, 0, 0) +""" +Now we'll make a function to hide all the examples. Then we'll call it so that +none are shown initially. 
+""" -line_slider = ui.LineSlider2D(initial_value=-2, - min_value=-5, max_value=5) -line_slider.add_callback(line_slider.slider_disk, - "MouseMoveEvent", - translate_green_cube) +def hide_all_examples(): + for example in examples: + for element in example: + element.set_visibility(False) + cube.SetVisibility(False) -line_slider.add_callback(line_slider.slider_line, - "LeftButtonPressEvent", - translate_green_cube) +hide_all_examples() """ -2D Disk Slider -============== +To make the menu, we'll first need to create a list of labels which correspond +with the examples. """ +values = ['Rectangle', 'Disks', 'Image', "Button Panel", + "Line and Ring Slider", "Range Slider"] -def rotate_red_cube(i_ren, obj, slider): - angle = slider.value - previous_angle = slider.previous_value - rotation_angle = angle - previous_angle - cube_actor_1.RotateY(rotation_angle) - +""" +Now we can create the menu. +""" -disk_slider = ui.DiskSlider2D() -disk_slider.set_center((200, 200)) -disk_slider.add_callback(disk_slider.handle, - "MouseMoveEvent", - rotate_red_cube) +listbox = ui.ListBox2D(values=values, position=(10, 300), size=(300, 200), + multiselection=False) -disk_slider.add_callback(disk_slider.base_disk, - "LeftButtonPressEvent", - rotate_red_cube) """ -2D File Select Menu -============== +Then we will use a callback to show the correct example when a label is +clicked. """ -file_select_menu = ui.FileSelectMenu2D(size=(500, 500), - position=(300, 300), - font_size=16, - extensions=["py", "png"], - directory_path=os.getcwd(), - parent=None) + +def display_element(): + hide_all_examples() + example = examples[values.index(listbox.selected[0])] + for element in example: + element.set_visibility(True) + if values.index(listbox.selected[0]) == 4: + cube.SetVisibility(True) + +listbox.on_change = display_element """ -Adding Elements to the ShowManager +Show Manager ================================== -Once all elements have been initialised, they have -to be added to the show manager in the following manner. +Now that all the elements have been initialised, we add them to the show +manager. """ -current_size = (600, 600) +current_size = (800, 800) show_manager = window.ShowManager(size=current_size, title="DIPY UI Example") -show_manager.ren.add(cube_actor_1) -show_manager.ren.add(cube_actor_2) -show_manager.ren.add(panel) -show_manager.ren.add(text) -show_manager.ren.add(line_slider) -show_manager.ren.add(disk_slider) -show_manager.ren.add(file_select_menu) +show_manager.ren.add(listbox) +for example in examples: + for element in example: + show_manager.ren.add(element) +show_manager.ren.add(cube) show_manager.ren.reset_camera() +show_manager.ren.set_camera(position=(0, 0, 200)) show_manager.ren.reset_clipping_range() show_manager.ren.azimuth(30) -# Uncomment this to start the visualisation -# show_manager.start() +interactive = False + +if interactive: + show_manager.start() -window.record(show_manager.ren, size=current_size, out_path="viz_ui.png") +else: + window.record(show_manager.ren, size=current_size, out_path="viz_ui.png") """ .. figure:: viz_ui.png diff --git a/doc/examples/workflow_creation.py b/doc/examples/workflow_creation.py index ad5d8c96d4..c2b42aa382 100644 --- a/doc/examples/workflow_creation.py +++ b/doc/examples/workflow_creation.py @@ -10,10 +10,8 @@ line:: dipy_nlmeans t1.nii.gz t1_denoised.nii.gz -""" -""" -First create your workflow. Usually this would be in its own python file in +First create your workflow (let's name this workflow file as my_workflow.py). 
Usually this is a Python file in the ``<../dipy/workflows>`` directory. """ @@ -29,7 +27,6 @@ ``Workflow`` is the base class that will be extended to create our workflow. """ - class AppendTextFlow(Workflow): def run(self, input_files, text_to_append='dipy', out_dir='', @@ -57,8 +54,8 @@ def run(self, input_files, text_to_append='dipy', out_dir='', text to a file. It is mandatory to have out_dir as a parameter. It is also mandatory - to put 'out_' in front of every parameter that is going to be an - output. Lastly, all out_ params needs to be at the end of the params + to put `out_` in front of every parameter that is going to be an + output. Lastly, all `out_` params need to be at the end of the params list. The ``run`` docstring is very important, you need to document every @@ -87,16 +84,24 @@ def run(self, input_files, text_to_append='dipy', out_dir='', The code in the loop is the actual workflow processing code. It can be anything. For the example, it just appends text to an input file. -""" - -""" This is it for the workflow! Now to be able to call it easily via command line, you need to add this bit of code. Usually this is in a separate executable file located in ``bin``. + +The first line imports the ``run_flow`` method from the ``flow_runner`` module. """ from dipy.workflows.flow_runner import run_flow + +""" +The second line imports the ``AppendTextFlow`` class from the newly created +``my_workflow.py`` file. In this specific case, we comment out this import +since the ``AppendTextFlow`` class is not in an external file but in the current one. +""" + +# from dipy.workflows.my_workflow import AppendTextFlow + """ This is the method that will wrap everything that is needed to make a flow command line ready then run it. @@ -104,6 +109,7 @@ def run(self, input_files, text_to_append='dipy', out_dir='', if __name__ == "__main__": run_flow(AppendTextFlow()) + """ This is the only thing needed to make your workflow available through command line. diff --git a/doc/examples_index.rst b/doc/examples_index.rst index d84a7021cf..13ef1d2516 100644 --- a/doc/examples_index.rst +++ b/doc/examples_index.rst @@ -76,6 +76,11 @@ Mean Apparent Propagator (MAP)-MRI - :ref:`example_reconst_mapmri` +Studying diffusion time-dependence using qt-dMRI +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- :ref:`example_reconst_qtdmri` + Diffusion Tensor Imaging ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -160,6 +165,8 @@ Streamline analysis and connectivity - :ref:`example_streamline_tools` - :ref:`example_streamline_length` +- :ref:`example_cluster_confidence` +- :ref:`example_path_length_map` ------------------ @@ -171,11 +178,12 @@ Image-based Registration - :ref:`example_affine_registration_3d` - :ref:`example_syn_registration_2d` - :ref:`example_syn_registration_3d` +- :ref:`example_register_binary_fuzzy` Streamline-based Registration ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - :ref:`example_bundle_registration` - +- :ref:`example_tractogram_registration` ------------ Segmentation ------------ @@ -198,6 +206,11 @@ Tissue Classification - :ref:`example_tissue_classification` +Bundle Extraction +~~~~~~~~~~~~~~~~~~~~~ + +- :ref:`example_bundle_extraction` + ----------- Simulations ----------- @@ -229,6 +242,8 @@ Visualization - :ref:`example_viz_surfaces` - :ref:`example_viz_roi_contour` - :ref:`example_viz_ui` +- :ref:`example_viz_timers` + --------------- Workflows diff --git a/doc/faq.rst b/doc/faq.rst index b56e051037..0f2fd67405 100644 --- a/doc/faq.rst +++ b/doc/faq.rst @@ -108,7 +108,7 @@ Practical 3.
**What do you use for visualization?** - For 3D visualization we use ``dipy.viz`` which depends in turn on ``python-vtk``:: + For 3D visualization we use ``dipy.viz`` which depends in turn on ``FURY``:: from dipy.viz import window, actor diff --git a/doc/index.rst b/doc/index.rst index b248e60a7c..41a56efd0d 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -13,66 +13,18 @@ visualization, and statistical analysis of MRI data. Highlights ********** -**DIPY 0.14.0** is now available. New features include: - -- RecoBundles: anatomically relevant segmentation of bundles -- New super fast clustering algorithm: QuickBundlesX -- New tracking algorithm: Particle Filtering Tracking. -- New tracking algorithm: Probabilistic Residual Bootstrap Tracking. -- Integration of the Streamlines API for reading, saving and processing tractograms. -- Fiber ORientation Estimated using Continuous Axially Symmetric Tensors (Forecast). -- New command line interfaces. -- Deprecated fvtk (old visualization framework). -- A range of new visualization improvements. +**DIPY 0.15.0** is now available. New features include: + +- Updated RecoBundles for automatic anatomical bundle segmentation. +- New Reconstruction Model: qtau-dMRI. +- New command line interfaces (e.g. dipy_slr). +- New continuous integration with AppVeyor CI. +- Nibabel Streamlines API now used almost everywhere for better memory management. +- Compatibility with Python 3.7. +- Many tutorials added or updated (5 New). - Large documentation update. - -**DIPY 0.13.0** is now available. New features include: - -- Faster local PCA implementation. -- Fixed different issues with OpenMP and Windows / OSX. -- Replacement of cvxopt by cvxpy. -- Replacement of Pytables by h5py. -- Updated API to support latest numpy version (1.14). -- New user interfaces for visualization. -- Large documentation update. - -**DIPY 0.12.0** is now available. New features include: - -- IVIM Simultaneous modeling of perfusion and diffusion. -- MAPL, tissue microstructure estimation using Laplacian-regularized MAP-MRI. -- DKI-based microstructural modelling. -- Free water diffusion tensor imaging. -- Denoising using Local PCA. -- Streamline-based registration (SLR). -- Fiber to bundle coherence (FBC) measures. -- Bayesian MRF-based tissue classification. -- New API for integrated user interfaces. -- New hdf5 file (.pam5) for saving reconstruction results. -- Interactive slicing of images, ODFs and peaks. -- Updated API to support latest numpy versions. -- New system for automatically generating command line interfaces. -- Faster computation of cross correlation for image registration. - -**DIPY 0.11.0** is now available. New features include: - -- New framework for contextual enhancement of ODFs. -- Compatibility with numpy (1.11). -- Compatibility with VTK 7.0 which supports Python 3.x. -- Faster PIESNO for noise estimation. -- Reorient gradient directions according to motion correction parameters. -- Supporting Python 3.3+ but not 3.2. -- Reduced memory usage in DTI. -- DSI now can use datasets with multiple b0s. -- Fixed different issues with Windows 64bit and Python 3.5. - -**DIPY 0.10.1** is now available. New features in this release include: - -- Compatibility with new versions of scipy (0.16) and numpy (1.10). -- New cleaner visualization API, including compatibility with VTK 6, and functions to create your own interactive visualizations. -- Diffusion Kurtosis Imaging (DKI): Google Summer of Code work by Rafael Henriques. 
-- Mean Apparent Propagator (MAP) MRI for tissue microstructure estimation. -- Anisotropic Power Maps from spherical harmonic coefficients. -- A new framework for affine registration of images. +- Moved visualization module to a new library: FURY. +- Closed 287 issues and merged 93 pull requests. See :ref:`Older Highlights `. @@ -81,16 +33,10 @@ See :ref:`Older Highlights `. Announcements ************* +- :ref:`DIPY 0.15 ` released December 12, 2018. - :ref:`DIPY 0.14 ` released May 1, 2018. - :ref:`DIPY 0.13 ` released October 24, 2017. - :ref:`DIPY 0.12 ` released June 26, 2017. -- :ref:`DIPY 0.11 ` released February 21, 2016. -- :ref:`DIPY 0.10 ` released December 4, 2015. -- :ref:`DIPY 0.9.2 ` released, March 18, 2015. -- :ref:`DIPY 0.8.0 ` released, January 6, 2015. -- DIPY_ was an official exhibitor in `OHBM 2015 `_. -- DIPY was featured in `The Scientist Magazine `_, Nov, 2014. -- `DIPY paper`_ accepted in Frontiers of Neuroinformatics, January 22nd, 2014. See some of our :ref:`Past Announcements ` diff --git a/doc/installation.rst b/doc/installation.rst index 1623e89805..f51ec28e01 100644 --- a/doc/installation.rst +++ b/doc/installation.rst @@ -27,15 +27,11 @@ Using Anaconda: On all platforms, you can use Anaconda_ to install DIPY. To do so issue the following command in a terminal:: - conda install dipy -c conda-forge + conda install -c conda-forge dipy -Some of the visualization methods require the VTK_ library and this can be installed separately (for the time being only on Python 2.7 and Python 3.6):: +Some of the visualization methods require the FURY_ library and this can be installed separately (for the time being only on Python 3.4+):: - conda install -c conda-forge vtk - -For OSX users, VTK_ is not available on conda-forge channel, so we recommend to use the following one:: - - conda install -c clinicalgraphics vtk + conda install -c conda-forge fury Using packages: =============== @@ -63,9 +59,9 @@ Windows This should work with no error. -#. Some of the visualization methods require the VTK_ library and this can be installed using Anaconda_:: +#. Some of the visualization methods require the FURY_ library and this can be installed by doing:: - conda install -c conda-forge vtk + pip install fury OSX @@ -89,9 +85,9 @@ OSX This should work with no error. -#. Some of the visualization methods require the VTK_ library and this can be installed using Anaconda_:: +#. Some of the visualization methods require the FURY_ library and this can be installed by doing:: - conda install -c clinicalgraphics vtk + pip install fury Linux ----- @@ -166,9 +162,8 @@ DIPY can process large diffusion datasets. For this reason we recommend using a Note on python versions ----------------------- -Most of the functionality in DIPY supports versions of Python from 2.6 to 3.5. -However, some visualization functionality depends on VTK_, which currently does not work with Python 3 versions. -Therefore, if you want to use the visualization functions in DIPY, please use it with Python 2. +Most DIPY functionality can be used with Python versions 2.6 and newer, including Python 3. +However, some visualization functionality depends on FURY_, which for the time being supports only Python 3.4 and newer. .. _from-source: diff --git a/doc/links_names.inc b/doc/links_names.inc index 422dafd955..56563c9bcf 100644 --- a/doc/links_names.inc +++ b/doc/links_names.inc @@ -24,7 +24,7 @@ .. _`dipy gitter`: https://gitter.im/nipy/dipy .. _neurostars: https://neurostars.org/ .. _h5py: https://www.h5py.org/ -..
_cvxpy: http://www.cvxpy.org/en/latest/ +.. _cvxpy: http://www.cvxpy.org/ .. Packaging .. _neurodebian: http://neuro.debian.net @@ -99,6 +99,7 @@ .. _pytables: http://www.pytables.org .. _python-vtk: http://www.vtk.org .. _pypi: https://pypi.python.org/pypi +.. _FURY: https://fury.gl .. Python imaging projects .. _PyMVPA: http://www.pymvpa.org diff --git a/doc/old_highlights.rst b/doc/old_highlights.rst index fdbee3240a..95c56e13fb 100644 --- a/doc/old_highlights.rst +++ b/doc/old_highlights.rst @@ -4,6 +4,54 @@ Older Highlights **************** +**DIPY 0.13.0** is now available. New features include: + +- Faster local PCA implementation. +- Fixed different issues with OpenMP and Windows / OSX. +- Replacement of cvxopt by cvxpy. +- Replacement of Pytables by h5py. +- Updated API to support latest numpy version (1.14). +- New user interfaces for visualization. +- Large documentation update. + +**DIPY 0.12.0** is now available. New features include: + +- IVIM Simultaneous modeling of perfusion and diffusion. +- MAPL, tissue microstructure estimation using Laplacian-regularized MAP-MRI. +- DKI-based microstructural modelling. +- Free water diffusion tensor imaging. +- Denoising using Local PCA. +- Streamline-based registration (SLR). +- Fiber to bundle coherence (FBC) measures. +- Bayesian MRF-based tissue classification. +- New API for integrated user interfaces. +- New hdf5 file (.pam5) for saving reconstruction results. +- Interactive slicing of images, ODFs and peaks. +- Updated API to support latest numpy versions. +- New system for automatically generating command line interfaces. +- Faster computation of cross correlation for image registration. + +**DIPY 0.11.0** is now available. New features include: + +- New framework for contextual enhancement of ODFs. +- Compatibility with numpy (1.11). +- Compatibility with VTK 7.0 which supports Python 3.x. +- Faster PIESNO for noise estimation. +- Reorient gradient directions according to motion correction parameters. +- Supporting Python 3.3+ but not 3.2. +- Reduced memory usage in DTI. +- DSI now can use datasets with multiple b0s. +- Fixed different issues with Windows 64bit and Python 3.5. + +**DIPY 0.10.1** is now available. New features in this release include: + +- Compatibility with new versions of scipy (0.16) and numpy (1.10). +- New cleaner visualization API, including compatibility with VTK 6, and functions to create your own interactive visualizations. +- Diffusion Kurtosis Imaging (DKI): Google Summer of Code work by Rafael Henriques. +- Mean Apparent Propagator (MAP) MRI for tissue microstructure estimation. +- Anisotropic Power Maps from spherical harmonic coefficients. +- A new framework for affine registration of images. + DIPY was an **official exhibitor** for OHBM 2015. .. raw :: html @@ -21,7 +69,7 @@ DIPY was an **official exhibitor** for OHBM 2015. * New experimental framework for clustering * Improvements and 10X speedup for Quickbundles * Improvements in Linear Fascicle Evaluation (LiFE) -* New implementation of Geodesic Anisotropy +* New implementation of Geodesic Anisotropy * New efficient transformation functions for registration * Sparse Fascicle Model supports acquisitions with multiple b-values diff --git a/doc/old_news.rst b/doc/old_news.rst index 63861478a9..e06f07c7fc 100644 --- a/doc/old_news.rst +++ b/doc/old_news.rst @@ -5,6 +5,15 @@ Past Announcements ********************** +- :ref:`DIPY 0.13 ` released October 24, 2017. +- :ref:`DIPY 0.12 ` released June 26, 2017. 
+- :ref:`DIPY 0.11 ` released February 21, 2016. +- :ref:`DIPY 0.10 ` released December 4, 2015. +- :ref:`DIPY 0.9.2 ` released, March 18, 2015. +- :ref:`DIPY 0.8.0 ` released, January 6, 2015. +- DIPY_ was an official exhibitor in `OHBM 2015 `_. +- DIPY was featured in `The Scientist Magazine `_, Nov, 2014. +- `DIPY paper`_ accepted in Frontiers of Neuroinformatics, January 22nd, 2014. - **DIPY 0.7.1** is available for :ref:`download ` with **3X** more tutorials than 0.6.0! In addition, a `journal paper`_ focusing on @@ -20,6 +29,6 @@ Past Announcements - **DIPY 0.6.0** Released!, 30 March, 2013. - **DIPY 3rd Sprint**, Berkeley, CA, 8-18 April, 2013. - **IEEE ISBI HARDI challenge** 2013 chooses **DIPY**, February, 2013. - + .. include:: links_names.inc \ No newline at end of file diff --git a/doc/release0.15.rst b/doc/release0.15.rst new file mode 100644 index 0000000000..8111b2ed5a --- /dev/null +++ b/doc/release0.15.rst @@ -0,0 +1,338 @@ +.. _release0.15: + +==================================== + Release notes for DIPY version 0.15 +==================================== + +GitHub stats for 2018/05/01 - 2018/12/12 (tag: 0.14.0) + +These lists are automatically generated, and may be incomplete or contain duplicates. + +The following 30 authors contributed 676 commits. + +* Ariel Rokem +* Bramsh Qamar +* Chris Filo Gorgolewski +* David Reagan +* Demian Wassermann +* Eleftherios Garyfallidis +* Enes Albay +* Gabriel Girard +* Guillaume Theaud +* Javier Guaje +* Jean-Christophe Houde +* Jiri Borovec +* Jon Haitz Legarreta Gorroño +* Karandeep +* Kesshi Jordan +* Marc-Alexandre Côté +* Matt Cieslak +* Matthew Brett +* Parichit Sharma +* Ricci Woo +* Rutger Fick +* Serge Koudoro +* Shreyas Fadnavis +* Chandan Gangwar +* Daniel Enrico Cahall +* David Hunt +* Francois Rheault +* Jacob Wasserthal + + +We closed a total of 287 issues, 93 pull requests and 194 regular issues; +this is the full list (generated with the script +:file:`tools/github_stats.py`): + +Pull Requests (93): + +* :ghpull:`1684`: [FIX] testing line-based target function +* :ghpull:`1686`: Standardize workflow +* :ghpull:`1685`: [Fix] Typo on examples +* :ghpull:`1663`: Stats, SNR_in_CC workflow +* :ghpull:`1681`: fixed issue with cst orientation in bundle_extraction example +* :ghpull:`1680`: [Fix] workflow variable string +* :ghpull:`1683`: test for new error in IVIM +* :ghpull:`1667`: Changing the default b0_threshold in gtab +* :ghpull:`1677`: [FIX] workflow help msg +* :ghpull:`1678`: Numpy matrix deprecation +* :ghpull:`1676`: [FIX] Example Update +* :ghpull:`1283`: get_data consistence +* :ghpull:`1670`: fixed RecoBundle workflow, SLR reference, and updated fetcher.py +* :ghpull:`1669`: Flow csd sh order +* :ghpull:`1659`: From dipy.viz to FURY +* :ghpull:`1621`: workflows : warn user for strange b0 threshold +* :ghpull:`1657`: DOC: Add spherical harmonics basis documentation. +* :ghpull:`1660`: OPT - moved the tolerance check outside of the for loop +* :ghpull:`1658`: STYLE: Honor 'descoteaux'and 'tournier' SH basis naming. 
+* :ghpull:`1281`: Representing qtau- signal attenuation using qtau-dMRI functional basis +* :ghpull:`1651`: Add save/load tck +* :ghpull:`1656`: Link to the dipy tag on neurostars +* :ghpull:`1624`: NF: Outlier scoring +* :ghpull:`1655`: [Fix] decrease tolerance on forecast +* :ghpull:`1650`: Increase codecov tolerance +* :ghpull:`1649`: Path Length Map example rebase +* :ghpull:`1556`: RecoBundles and SLR workflows +* :ghpull:`1645`: Fix worflows creation tutorial error +* :ghpull:`1647`: DOC: Fix duplicate link and AppVeyor badge. +* :ghpull:`1644`: Adds an Appveyor badge +* :ghpull:`1643`: Add hash for SCIL b0 file +* :ghpull:`787`: TST: Add an appveyor starter file. +* :ghpull:`1642`: Test that you can use the 724 symmetric sphere in PAM. +* :ghpull:`1641`: changed vertices to float64 in evenly_distributed_sphere_642.npz +* :ghpull:`1564`: Added scroll bar to ListBox2D +* :ghpull:`1636`: Fixed broken link. +* :ghpull:`1584`: Added Examples +* :ghpull:`1554`: Checking if the input file or directory exists when running a workflow +* :ghpull:`1528`: Show spheres with different radii, colors and opacities + add timers + add exit a + resolve issue with imread +* :ghpull:`1526`: Eigenvalue - eigenvector array compatibility check +* :ghpull:`1628`: Adding python 3.7 on travis +* :ghpull:`1623`: NF: Convert between 4D DEC FA and 3D 24 bit representation. +* :ghpull:`1622`: [Fix] viz slice example +* :ghpull:`1626`: RF - removed duplicate tests +* :ghpull:`1619`: [DOC] update VTK version +* :ghpull:`1592`: Added File Menu element to viz.ui +* :ghpull:`1559`: Checkbox and RadioButton elements for viz.ui +* :ghpull:`1583`: Fix the relative SF threshold Issue +* :ghpull:`1602`: Fix random seed in tracking +* :ghpull:`1609`: [DOC] update dependencies file +* :ghpull:`1560`: Removed affine matrices from tracking. +* :ghpull:`1593`: Removed event.abort for release events +* :ghpull:`1597`: Upgrade nibabel minimum version +* :ghpull:`1601`: Fix: Decrease Nosetest warning +* :ghpull:`1515`: RF: Use the new Streamlines API for orienting of streamlines. +* :ghpull:`1590`: Revert 1570 file menu +* :ghpull:`1589`: Fix calculation of highest order for a sh basis set +* :ghpull:`1580`: Allow PRE=1 job to fail +* :ghpull:`1533`: Show message if number of arguments mismatch between the doc string and the run method. +* :ghpull:`1523`: Showing help when no input parameters are given and suppress warnings for cmds +* :ghpull:`1543`: Update the default out_strategy to create the output in the current working directory +* :ghpull:`1574`: Fixed Bug in PR #1547 +* :ghpull:`1561`: add example SDR for binary and fuzzy images +* :ghpull:`1578`: BF - bad condition in maximum dg +* :ghpull:`1570`: Added File Menu element to viz.ui +* :ghpull:`1563`: Replacing major_version in viz.ui +* :ghpull:`1557`: Range slider element for viz.ui +* :ghpull:`1547`: Changed the icon set in Button2D from Dictionary to List of Tuples +* :ghpull:`1555`: Fix bug in actor.label +* :ghpull:`1522`: Image element in dipy.viz.ui +* :ghpull:`1355`: WIP: ENH: UI Listbox +* :ghpull:`1540`: fix potential zero division in demon regist. +* :ghpull:`1548`: Fixed references per request of @garyfallidis. 
+* :ghpull:`1542`: fix for using cvxpy solver +* :ghpull:`1546`: References to reference +* :ghpull:`1545`: Adding a reference in README.rst +* :ghpull:`1492`: Enh ui components positioning (with code refactoring) +* :ghpull:`1538`: Explanation that is mistakenly rendered as code fixed in example of DKI +* :ghpull:`1536`: DOC: Update Rafael's current institution. +* :ghpull:`1537`: removed unncessary importd from sims example +* :ghpull:`1530`: Wrong default value for parameter 'symmetric' connectivity_matrix function +* :ghpull:`1529`: minor typo fix in quickstart +* :ghpull:`1520`: Updating the documentation for the workflow creation tutorial. +* :ghpull:`1524`: Values from streamlines object +* :ghpull:`1521`: Moved some older highlights and announcements to the old news files. +* :ghpull:`1518`: DOC: updated some developers affiliations. +* :ghpull:`1517`: Dev info update +* :ghpull:`1516`: [DOC] Installation instruction update +* :ghpull:`1514`: Adding pep8speak config file +* :ghpull:`1513`: fix typo in example of quick_start +* :ghpull:`1510`: copyright updated to 2008-2018 +* :ghpull:`1508`: Adds whitespace, to appease the sphinx. +* :ghpull:`1506`: moving to 0.15.0 dev + +Issues (194): + +* :ghissue:`1684`: [FIX] testing line-based target function +* :ghissue:`1679`: Intermittent issue in testing line-based target function +* :ghissue:`1220`: RF: Replaces 1997 definitions of tensor geometric params with 1999 definitions. +* :ghissue:`1686`: Standardize workflow +* :ghissue:`746`: New fetcher returns filenames as dictionary keys in a tuple +* :ghissue:`1685`: [Fix] Typo on examples +* :ghissue:`1663`: Stats, SNR_in_CC workflow +* :ghissue:`1637`: Advice for saving results from MAPMRI +* :ghissue:`1673`: CST Image in bundle extraction is not oriented well +* :ghissue:`1681`: fixed issue with cst orientation in bundle_extraction example +* :ghissue:`1680`: [Fix] workflow variable string +* :ghissue:`1338`: Variable string input does not work with self.get_io_iterator() in workflows +* :ghissue:`1683`: test for new error in IVIM +* :ghissue:`1682`: Add tests for IVIM for new Error +* :ghissue:`634`: BinaryTissueClassifier segfaults on corner case +* :ghissue:`742`: LinAlgError on tracking quickstart, with python 3.4 +* :ghissue:`852`: Problem with spherical harmonics computations on some Anaconda python versions +* :ghissue:`1667`: Changing the default b0_threshold in gtab +* :ghissue:`1500`: Updating streamlines API in streamlinear.py +* :ghissue:`944`: Slicer fix +* :ghissue:`1111`: WIP: A lightweight UI for medical visualizations based on VTK-Python +* :ghissue:`1099`: Needed PRs for merging recobundles into Dipy's master +* :ghissue:`1544`: Plans for viz module +* :ghissue:`641`: Tests raise a deprecation warning +* :ghissue:`643`: Use appveyor for Windows CI? +* :ghissue:`400`: Add travis-ci test without matplotlib installed +* :ghissue:`1677`: [FIX] workflow help msg +* :ghissue:`1674`: Workflows should print out help per default +* :ghissue:`1678`: Numpy matrix deprecation +* :ghissue:`1397`: Running dipy 'Intro to Basic Tracking' code and keep getting error. 
On Linux Centos +* :ghissue:`1676`: [FIX] Example Update +* :ghissue:`10`: data.get_data() should be consistent across datasets +* :ghissue:`1283`: get_data consistence +* :ghissue:`1670`: fixed RecoBundle workflow, SLR reference, and updated fetcher.py +* :ghissue:`1669`: Flow csd sh order +* :ghissue:`1668`: One issue on handling HCP data -- HCP b vectors raise NaN in the gradient table +* :ghissue:`1662`: Remove the points added oustide of a mask. Fix the related tests. +* :ghissue:`1659`: From dipy.viz to FURY +* :ghissue:`1621`: workflows : warn user for strange b0 threshold +* :ghissue:`1657`: DOC: Add spherical harmonics basis documentation. +* :ghissue:`1296`: Need of a travis bot that runs ana/mini/conda and vtk=7.1.0+ +* :ghissue:`1660`: OPT - moved the tolerance check outside of the for loop +* :ghissue:`1658`: STYLE: Honor 'descoteaux'and 'tournier' SH basis naming. +* :ghissue:`1281`: Representing qtau- signal attenuation using qtau-dMRI functional basis +* :ghissue:`1653`: STYLE: Honor 'descoteaux' SH basis naming. +* :ghissue:`1651`: Add save/load tck +* :ghissue:`1656`: Link to the dipy tag on neurostars +* :ghissue:`1624`: NF: Outlier scoring +* :ghissue:`1655`: [Fix] decrease tolerance on forecast +* :ghissue:`1654`: Test failure in FORECAST +* :ghissue:`1414`: [WIP] Switching tests to pytest and removing nose dependencies +* :ghissue:`1650`: Increase codecov tolerance +* :ghissue:`1093`: WIP: Add functionality to clip streamlines between ROIs in `orient_by_rois` +* :ghissue:`1611`: Preloader element for viz.ui +* :ghissue:`1615`: Color Picker element for viz.ui +* :ghissue:`1631`: Path Length Map example +* :ghissue:`1649`: Path Length Map example rebase +* :ghissue:`1556`: RecoBundles and SLR workflows +* :ghissue:`1645`: Fix worflows creation tutorial error +* :ghissue:`1647`: DOC: Fix duplicate link and AppVeyor badge. +* :ghissue:`1644`: Adds an Appveyor badge +* :ghissue:`1638`: Fetcher downloads data every time it is called +* :ghissue:`1643`: Add hash for SCIL b0 file +* :ghissue:`1600`: NODDIx 2 fibers crossing +* :ghissue:`1618`: viz.ui.FileMenu2D +* :ghissue:`1569`: viz.ui.ListBoxItem2D text overflow +* :ghissue:`1532`: dipy test failed on mac osx sierra with ananoda python. +* :ghissue:`1420`: window.record() resolution limit +* :ghissue:`1396`: Visualization problem with tensors ? +* :ghissue:`1295`: Reorienting peak_slicer and ODF_slicer +* :ghissue:`1232`: With VTK 6.3, streamlines color map bar text disappears when using streamtubes +* :ghissue:`928`: dipy.viz.colormap crash on single fibers +* :ghissue:`923`: change size of colorbar in viz module +* :ghissue:`854`: VTK and Python 3 support in fvtk +* :ghissue:`759`: How to resolve python-vtk6 link issues in Ubuntu +* :ghissue:`647`: fvtk contour function ignores voxsz parameter +* :ghissue:`646`: Dipy visualization with missing (?) affine parameter +* :ghissue:`645`: Dipy visualization (fvtk) crash when saving series of images +* :ghissue:`353`: fvtk.label won't show up if called twice +* :ghissue:`787`: TST: Add an appveyor starter file. +* :ghissue:`1642`: Test that you can use the 724 symmetric sphere in PAM. 
+* :ghissue:`1641`: changed vertices to float64 in evenly_distributed_sphere_642.npz +* :ghissue:`1203`: Some bots might need a newer version of nibabel +* :ghissue:`1156`: Deterministic tracking workflow +* :ghissue:`642`: WIP - NF parallel framework +* :ghissue:`1135`: WIP : Multiprocessing - implemented a parallel_voxel_fit decorator +* :ghissue:`387`: References do not render correctly in SHORE example +* :ghissue:`442`: Allow length and set_number_of_points to work with generators +* :ghissue:`558`: Allow setting of the zoom on fvtk ren objects +* :ghissue:`1236`: bundle visualisation using nibabel API: wrong colormap +* :ghissue:`1389`: VTK 8: minimal version? +* :ghissue:`1519`: Scipy stopped supporting scipy.misc.imread +* :ghissue:`1596`: Reproducibility in PFT tracking +* :ghissue:`1614`: for GSoC NODDIx_PR +* :ghissue:`1576`: [WIP] Needs Optimization and Cleaning +* :ghissue:`1564`: Added scroll bar to ListBox2D +* :ghissue:`1636`: Fixed broken link. +* :ghissue:`1584`: Added Examples +* :ghissue:`1568`: Multi_io axis out of bounds error +* :ghissue:`1554`: Checking if the input file or directory exists when running a workflow +* :ghissue:`1528`: Show spheres with different radii, colors and opacities + add timers + add exit a + resolve issue with imread +* :ghissue:`1108`: Local PCA Slow Version +* :ghissue:`1526`: Eigenvalue - eigenvector array compatibility check +* :ghissue:`1628`: Adding python 3.7 on travis +* :ghissue:`1623`: NF: Convert between 4D DEC FA and 3D 24 bit representation. +* :ghissue:`1622`: [Fix] viz slice example +* :ghissue:`1629`: [WIP][fix] remove Userwarning message +* :ghissue:`1591`: PRE is failing : module 'cvxpy' has no attribute 'utilities' +* :ghissue:`1626`: RF - removed duplicate tests +* :ghissue:`1582`: SF threshold in PMF is not relative +* :ghissue:`1575`: Website: warning about python versions +* :ghissue:`1619`: [DOC] update VTK version +* :ghissue:`1592`: Added File Menu element to viz.ui +* :ghissue:`1559`: Checkbox and RadioButton elements for viz.ui +* :ghissue:`1583`: Fix the relative SF threshold Issue +* :ghissue:`1602`: Fix random seed in tracking +* :ghissue:`1620`: 3.7 wheels +* :ghissue:`1598`: Apply Transform workflow for transforming a collection of moving images. +* :ghissue:`1595`: Workflow for visualizing the quality of the registered data with DIPY +* :ghissue:`1581`: Image registration Workflow with quality metrices +* :ghissue:`1588`: Dipy.reconst.shm.calculate_max_order only works on specific cases. +* :ghissue:`1608`: Parallelized affine registration +* :ghissue:`1610`: Tortoise - sub +* :ghissue:`1607`: Reminder to add in the docs that users will need to update nibabel to 2.3.0 during the next release +* :ghissue:`1609`: [DOC] update dependencies file +* :ghissue:`1560`: Removed affine matrices from tracking. +* :ghissue:`1593`: Removed event.abort for release events +* :ghissue:`1586`: Slider breaks interaction in viz_advanced example +* :ghissue:`1597`: Upgrade nibabel minimum version +* :ghissue:`1601`: Fix: Decrease Nosetest warning +* :ghissue:`1515`: RF: Use the new Streamlines API for orienting of streamlines. 
+* :ghissue:`1585`: Add a random seed for reproducibility
+* :ghissue:`1594`: Integrating the support for the visualization in Affine registration
+* :ghissue:`1590`: Revert 1570 file menu
+* :ghissue:`1589`: Fix calculation of highest order for a sh basis set
+* :ghissue:`1577`: Revert "Added File Menu element to viz.ui"
+* :ghissue:`1571`: WIP: multi-threaded on affine registration
+* :ghissue:`1580`: Allow PRE=1 job to fail
+* :ghissue:`1533`: Show message if number of arguments mismatch between the doc string and the run method.
+* :ghissue:`1523`: Showing help when no input parameters are given and suppress warnings for cmds
+* :ghissue:`1579`: Error on PRE=1 (cython / numpy)
+* :ghissue:`1543`: Update the default out_strategy to create the output in the current working directory
+* :ghissue:`1433`: New version of h5py messing with us?
+* :ghissue:`1541`: demon registration, unstable?
+* :ghissue:`1574`: Fixed Bug in PR #1547
+* :ghissue:`1573`: Failure in test_ui_listbox_2d
+* :ghissue:`1561`: add example SDR for binary and fuzzy images
+* :ghissue:`1578`: BF - bad condition in maximum dg
+* :ghissue:`1566`: Bad condition in local tracking
+* :ghissue:`1570`: Added File Menu element to viz.ui
+* :ghissue:`1572`: [WIP]
+* :ghissue:`1567`: WIP: NF: multi-threaded on affine registration
+* :ghissue:`1563`: Replacing major_version in viz.ui
+* :ghissue:`1557`: Range slider element for viz.ui
+* :ghissue:`1547`: Changed the icon set in Button2D from Dictionary to List of Tuples
+* :ghissue:`1555`: Fix bug in actor.label
+* :ghissue:`1551`: Actor.label not working anymore
+* :ghissue:`1522`: Image element in dipy.viz.ui
+* :ghissue:`1549`: CVXPY installation on >3.5
+* :ghissue:`1355`: WIP: ENH: UI Listbox
+* :ghissue:`1562`: Should we retire our Python 3.5 travis builds?
+* :ghissue:`1550`: Memory error when running rigid transform
+* :ghissue:`1540`: fix potential zero division in demon regist.
+* :ghissue:`1548`: Fixed references per request of @garyfallidis.
+* :ghissue:`1527`: New version of CVXPY changes API
+* :ghissue:`1542`: fix for using cvxpy solver
+* :ghissue:`1534`: Changed the icon set in Button2D from Dictionary to List of Tuples
+* :ghissue:`1546`: References to reference
+* :ghissue:`1545`: Adding a reference in README.rst
+* :ghissue:`1492`: Enh ui components positioning (with code refactoring)
+* :ghissue:`1538`: Explanation that is mistakenly rendered as code fixed in example of DKI
+* :ghissue:`1536`: DOC: Update Rafael's current institution.
+* :ghissue:`1487`: Commit for updated check_scratch.py script.
+* :ghissue:`1486`: Parichit dipy flows
+* :ghissue:`1539`: Changing the default behavior of the workflows to create the output file(s) in the current working directory.
+* :ghissue:`1537`: removed unnecessary imports from sims example
+* :ghissue:`1535`: removed some unnecessary imports from sims example
+* :ghissue:`1530`: Wrong default value for parameter 'symmetric' connectivity_matrix function
+* :ghissue:`1529`: minor typo fix in quickstart
+* :ghissue:`1520`: Updating the documentation for the workflow creation tutorial.
+* :ghissue:`1524`: Values from streamlines object
+* :ghissue:`1521`: Moved some older highlights and announcements to the old news files.
+* :ghissue:`1518`: DOC: updated some developers affiliations.
+* :ghissue:`1517`: Dev info update
+* :ghissue:`1516`: [DOC] Installation instruction update
+* :ghissue:`1514`: Adding pep8speaks config file
+* :ghissue:`1507`: Mathematical expressions are not rendered correctly in reference page
+* :ghissue:`1513`: fix typo in example of quick_start
+* :ghissue:`1510`: copyright updated to 2008-2018
+* :ghissue:`1508`: Adds whitespace, to appease the sphinx.
+* :ghissue:`1512`: Fix typo in example of quick_start
+* :ghissue:`1511`: Fix typo in example quick_start
+* :ghissue:`1509`: DOC: fix math rendering for some dki functions
+* :ghissue:`1506`: moving to 0.15.0 dev
diff --git a/doc/stateoftheart.rst b/doc/stateoftheart.rst
index 7b8e46e4f8..9d76fbfd21 100644
--- a/doc/stateoftheart.rst
+++ b/doc/stateoftheart.rst
@@ -28,6 +28,7 @@ For a full list of the features implemented in the most recent release cycle, ch
 .. toctree::
    :maxdepth: 1
 
+   release0.15
    release0.14
    release0.13
    release0.12
diff --git a/doc/theory/index.rst b/doc/theory/index.rst
index 4ef234ab28..8d04bb722b 100644
--- a/doc/theory/index.rst
+++ b/doc/theory/index.rst
@@ -8,3 +8,4 @@ Contents:
    :maxdepth: 2
 
    spherical
+   sh_basis
diff --git a/doc/theory/sh_basis.rst b/doc/theory/sh_basis.rst
new file mode 100644
index 0000000000..1b355c4daa
--- /dev/null
+++ b/doc/theory/sh_basis.rst
@@ -0,0 +1,85 @@
+.. _sh-basis:
+
+========================
+Spherical Harmonic bases
+========================
+
+Spherical Harmonics (SH) are functions defined on the sphere. A collection of
+SH can be used as a basis to represent and reconstruct any function on the
+surface of the unit sphere.
+
+Spherical harmonics are orthonormal functions defined by:
+
+.. math::
+
+   Y_l^m(\theta, \phi) = (-1)^m \sqrt{\frac{2l + 1}{4 \pi} \frac{(l - m)!}{(l + m)!}} P_l^m(\cos \theta) e^{i m \phi}
+
+where $l$ is the band index, $m$ is the order, $P_l^m$ is the associated
+Legendre polynomial of degree $l$ and order $m$, and $(\theta, \phi)$ is the
+representation of the direction vector in spherical coordinates.
+
+A function $f(\theta, \phi)$ can be represented in a spherical harmonics
+basis through its spherical harmonics coefficients $a_l^m$, which can be
+computed using the expression:
+
+.. math::
+
+   a_l^m = \int_S f(\theta, \phi) Y_l^m(\theta, \phi) ds
+
+Once the coefficients are computed, the function $f(\theta, \phi)$ can be
+approximated as:
+
+.. math::
+
+   f(\theta, \phi) = \sum_{l = 0}^{\infty} \sum_{m = -l}^{l} a_l^m Y_l^m(\theta, \phi)
+
+In HARDI, the Orientation Distribution Function (ODF) is a function on the
+sphere.
+
+Several Spherical Harmonics bases have been proposed in the diffusion imaging
+literature for the computation of the ODF. DIPY implements two of these in the
+:mod:`~dipy.reconst.shm` module:
+
+- The basis proposed by Descoteaux *et al.* [1]_:
+
+.. math::
+
+   Y_i(\theta, \phi) =
+    \begin{cases}
+    \sqrt{2} \Re(Y_l^m(\theta, \phi)) & -l \leq m < 0, \\
+    Y_l^0(\theta, \phi) & m = 0, \\
+    \sqrt{2} \Im(Y_l^m(\theta, \phi)) & 0 < m \leq l
+    \end{cases}
+
+- The basis proposed by Tournier *et al.* [2]_:
+
+.. math::
+
+   Y_i(\theta, \phi) =
+    \begin{cases}
+    \Re(Y_l^m(\theta, \phi)) & -l \leq m < 0, \\
+    Y_l^0(\theta, \phi) & m = 0, \\
+    \Im(Y_l^{|m|}(\theta, \phi)) & 0 < m \leq l
+    \end{cases}
+
+In both cases, $\Re$ denotes the real part of the spherical harmonic basis, and
+$\Im$ denotes the imaginary part.
+
+In practice, the expansion is truncated at a maximum even order $k$, keeping
+only the terms with $l \leq k$. The choice of an even order is motivated by
+the symmetry of the diffusion process around the origin.
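+
+As a minimal sketch of how these bases are used in practice (assuming the
+:func:`~dipy.reconst.shm.sf_to_sh` and :func:`~dipy.reconst.shm.sh_to_sf`
+helpers and the ``descoteaux07`` / ``tournier07`` basis names adopted in this
+release), a function sampled on a sphere can be projected onto its SH
+coefficients and evaluated back:
+
+.. code-block:: python
+
+   import numpy as np
+
+   from dipy.data import get_sphere
+   from dipy.reconst.shm import sf_to_sh, sh_to_sf
+
+   sphere = get_sphere('symmetric724')
+   # An antipodally symmetric test function: cos^2(theta), i.e. z^2.
+   sf = sphere.vertices[:, 2] ** 2
+
+   # Project onto the real, symmetric basis up to (even) order k = 4,
+   # which has (k + 1)(k + 2)/2 = 15 coefficients.
+   sh = sf_to_sh(sf, sphere, sh_order=4, basis_type='descoteaux07')
+
+   # Evaluate the truncated expansion back on the sphere vertices.
+   sf_approx = sh_to_sf(sh, sphere, sh_order=4, basis_type='descoteaux07')
+   print(np.abs(sf - sf_approx).max())  # small: z^2 only needs l <= 2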
+
+Descoteaux *et al.* [1]_ use the Q-ball Imaging (QBI) formalization to recover
+the ODF, while Tournier *et al.* [2]_ use the Spherical Deconvolution (SD)
+framework.
+
+
+References
+----------
+.. [1] Descoteaux, M., Angelino, E., Fitzgibbons, S. and Deriche, R.
+   Regularized, Fast, and Robust Analytical Q-ball Imaging.
+   Magn. Reson. Med. 2007;58:497-510.
+.. [2] Tournier J.D., Calamante F. and Connelly A. Robust determination
+   of the fibre orientation distribution in diffusion MRI:
+   Non-negativity constrained super-resolved spherical deconvolution.
+   NeuroImage. 2007;35(4):1459-1472.
diff --git a/requirements.txt b/requirements.txt
index 59a6b1e4cb..7c021300a6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,5 +2,5 @@
 cython>=0.25.1
 numpy>=1.7.1
 scipy>=0.9
-nibabel>=2.1.0
+nibabel>=2.3.0
 h5py>=2.4.0
diff --git a/scratch/restore_dti_simulations.py b/scratch/restore_dti_simulations.py
index ab1f9bf063..585c9f4055 100644
--- a/scratch/restore_dti_simulations.py
+++ b/scratch/restore_dti_simulations.py
@@ -6,7 +6,7 @@ import dipy.core.gradients as grad
 b0 = 1000.
-bvecs, bval = dpd.read_bvec_file(dpd.get_data('55dir_grad.bvec'))
+bvecs, bval = dpd.read_bvec_file(dpd.get_fnames('55dir_grad.bvec'))
 gtab = grad.gradient_table(bval, bvecs)
 B = bval[1]
diff --git a/tools/run_with_env.cmd b/tools/run_with_env.cmd
new file mode 100644
index 0000000000..5da547c499
--- /dev/null
+++ b/tools/run_with_env.cmd
@@ -0,0 +1,88 @@
+:: To build extensions for 64 bit Python 3, we need to configure environment
+:: variables to use the MSVC 2010 C++ compilers from GRMSDKX_EN_DVD.iso of:
+:: MS Windows SDK for Windows 7 and .NET Framework 4 (SDK v7.1)
+::
+:: To build extensions for 64 bit Python 2, we need to configure environment
+:: variables to use the MSVC 2008 C++ compilers from GRMSDKX_EN_DVD.iso of:
+:: MS Windows SDK for Windows 7 and .NET Framework 3.5 (SDK v7.0)
+::
+:: 32-bit builds, and 64-bit builds for 3.5 and beyond, do not require specific
+:: environment configurations.
+::
+:: Note: this script needs to be run with the /E:ON and /V:ON flags for the
+:: cmd interpreter, at least for (SDK v7.0)
+::
+:: More details at:
+:: https://github.com/cython/cython/wiki/64BitCythonExtensionsOnWindows
+:: http://stackoverflow.com/a/13751649/163740
+::
+:: Author: Olivier Grisel
+:: License: CC0 1.0 Universal: http://creativecommons.org/publicdomain/zero/1.0/
+::
+:: Notes about batch files for Python people:
+::
+:: Quotes in values are literally part of the values:
+::     SET FOO="bar"
+:: FOO is now five characters long: " b a r "
+:: If you don't want quotes, don't include them on the right-hand side.
+::
+:: The CALL lines at the end of this file look redundant, but if you move them
+:: outside of the IF clauses, they do not run properly in the SET_SDK_64==Y
+:: case; I don't know why.
+@ECHO OFF
+
+SET COMMAND_TO_RUN=%*
+SET WIN_SDK_ROOT=C:\Program Files\Microsoft SDKs\Windows
+SET WIN_WDK=c:\Program Files (x86)\Windows Kits\10\Include\wdf
+
+:: Extract the major and minor versions, and allow for the minor version to be
+:: more than 9. This requires the version number to have two dots in it.
+SET MAJOR_PYTHON_VERSION=%PYTHON_VERSION:~0,1%
+IF "%PYTHON_VERSION:~3,1%" == "." (
+    SET MINOR_PYTHON_VERSION=%PYTHON_VERSION:~2,1%
+) ELSE (
+    SET MINOR_PYTHON_VERSION=%PYTHON_VERSION:~2,2%
+)
+
+:: Based on the Python version, determine what SDK version to use, and whether
+:: to set the SDK for 64-bit.
+IF %MAJOR_PYTHON_VERSION% == 2 (
+    SET WINDOWS_SDK_VERSION="v7.0"
+    SET SET_SDK_64=Y
+) ELSE (
+    IF %MAJOR_PYTHON_VERSION% == 3 (
+        SET WINDOWS_SDK_VERSION="v7.1"
+        IF %MINOR_PYTHON_VERSION% LEQ 4 (
+            SET SET_SDK_64=Y
+        ) ELSE (
+            SET SET_SDK_64=N
+            IF EXIST "%WIN_WDK%" (
+                :: See: https://connect.microsoft.com/VisualStudio/feedback/details/1610302/
+                REN "%WIN_WDK%" 0wdf
+            )
+        )
+    ) ELSE (
+        ECHO Unsupported Python version: "%MAJOR_PYTHON_VERSION%"
+        EXIT 1
+    )
+)
+
+IF %PYTHON_ARCH% == 64 (
+    IF %SET_SDK_64% == Y (
+        ECHO Configuring Windows SDK %WINDOWS_SDK_VERSION% for Python %MAJOR_PYTHON_VERSION% on a 64 bit architecture
+        SET DISTUTILS_USE_SDK=1
+        SET MSSdk=1
+        "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Setup\WindowsSdkVer.exe" -q -version:%WINDOWS_SDK_VERSION%
+        "%WIN_SDK_ROOT%\%WINDOWS_SDK_VERSION%\Bin\SetEnv.cmd" /x64 /release
+        ECHO Executing: %COMMAND_TO_RUN%
+        call %COMMAND_TO_RUN% || EXIT 1
+    ) ELSE (
+        ECHO Using default MSVC build environment for 64 bit architecture
+        ECHO Executing: %COMMAND_TO_RUN%
+        call %COMMAND_TO_RUN% || EXIT 1
+    )
+) ELSE (
+    ECHO Using default MSVC build environment for 32 bit architecture
+    ECHO Executing: %COMMAND_TO_RUN%
+    call %COMMAND_TO_RUN% || EXIT 1
+)
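+
+:: Example usage (a sketch; the exact build command will vary). As noted
+:: above, the script must be run with the /E:ON and /V:ON cmd flags, and it
+:: expects the PYTHON_VERSION and PYTHON_ARCH environment variables to be
+:: set by the CI configuration, e.g.:
+::     cmd /E:ON /V:ON /C tools\run_with_env.cmd python setup.py build_ext --inplace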