diff --git a/fiftyone/core/dataset.py b/fiftyone/core/dataset.py
index c01582ada0..8555e4d713 100644
--- a/fiftyone/core/dataset.py
+++ b/fiftyone/core/dataset.py
@@ -7,6 +7,7 @@
 """
 from collections import defaultdict
+from concurrent.futures import ThreadPoolExecutor
 import contextlib
 from datetime import datetime
 import fnmatch
@@ -17,9 +18,9 @@ import random
 import string
+
 from bson import json_util, ObjectId, DBRef
 import cachetools
-from deprecated import deprecated
 import mongoengine.errors as moe
 from pymongo import (
     DeleteMany,
@@ -28,7 +29,6 @@
     UpdateMany,
     UpdateOne,
 )
-from pymongo.collection import Collection
 from pymongo.errors import CursorNotFound, BulkWriteError
 
 import eta.core.serial as etas
@@ -2884,6 +2884,96 @@ def delete_group_slice(self, name):
         self.save()
 
+    def map_samples(
+        self,
+        map_func,
+        map_func_args=None,
+        aggregate_func=None,
+        num_workers=4,
+        progress=False,
+        autosave=False,
+        batch_size=None,
+        batching_strategy=None,
+        skip_failures=True,
+        warn_failures=False,
+    ):
+        """Maps a function over the samples in the dataset.
+
+        Args:
+            map_func: a function that accepts a
+                :class:`fiftyone.core.sample.Sample` and ``map_func_args`` as
+                input and returns a result
+            map_func_args (None): optional additional arguments to pass to
+                each ``map_func`` call
+            aggregate_func (None): an optional function that accepts the list
+                of per-sample results and returns the final result. By
+                default, the results are returned as a list
+            num_workers (4): the number of worker threads to use
+            progress (False): whether to render a progress bar (True/False)
+            autosave (False): whether to automatically save modified samples
+            batch_size (None): the batch size to use when autosaving samples
+            batching_strategy (None): the batching strategy to use for each
+                save operation when autosaving samples
+            skip_failures (True): whether to gracefully continue without
+                raising an error if processing fails for a sample
+            warn_failures (False): whether to log a warning if processing
+                fails for a sample
+
+        Returns:
+            the result of the mapping operation, which is a list of the
+            per-sample results if ``aggregate_func`` is not provided
+        """
+        if not callable(map_func):
+            raise ValueError("map_func must be callable")
+
+        # Create a thread pool
+        with ThreadPoolExecutor(max_workers=num_workers) as executor:
+            futures = []
+            result = []
+            max_queue = 10000
+
+            # Submit samples to the worker pool
+            for sample in self.iter_samples(
+                progress=progress,
+                autosave=autosave,
+                batch_size=batch_size,
+                batching_strategy=batching_strategy,
+            ):
+                future = executor.submit(map_func, sample, map_func_args)
+                futures.append(future)
+                if len(futures) > max_queue:
+                    self._process_future_result(
+                        futures, result, skip_failures, warn_failures
+                    )
+                    futures = []
+
+            # Process remaining results
+            self._process_future_result(
+                futures, result, skip_failures, warn_failures
+            )
+
+        if aggregate_func is not None:
+            if callable(aggregate_func):
+                result = aggregate_func(result)
+            else:
+                raise ValueError("aggregate_func must be callable")
+
+        return result
+
+    def _process_future_result(
+        self, futures, result, skip_failures, warn_failures
+    ):
+        for future in futures:
+            try:
+                result.append(future.result())
+            except Exception as e:
+                if not skip_failures:
+                    raise RuntimeError(
+                        "Worker failed while processing sample"
+                    ) from e
+
+                if warn_failures:
+                    logger.warning("Error processing sample: %s", e)
+
+                result.append(None)
+
     def iter_samples(
         self,
         progress=False,
diff --git a/fiftyone/core/view.py b/fiftyone/core/view.py
index 0bb5f6618e..8493f8b4c8 100644
--- a/fiftyone/core/view.py
+++ b/fiftyone/core/view.py
@@ -6,10 +6,14 @@
 |
 """
 from collections import defaultdict, OrderedDict
+from concurrent.futures import ThreadPoolExecutor
 import contextlib
 from copy import copy, deepcopy
 import itertools
+import logging
 import numbers
+import queue
+import threading
 
 from bson import ObjectId
 from pymongo.errors import CursorNotFound
@@ -29,6 +33,9 @@
 
 fost = fou.lazy_import("fiftyone.core.stages")
 
+logger = logging.getLogger(__name__)
+
+
 class DatasetView(foc.SampleCollection):
     """A view into a :class:`fiftyone.core.dataset.Dataset`.
 
@@ -447,6 +454,128 @@ def view(self):
         """
         return copy(self)
 
+    def map_samples(
+        self,
+        map_func,
+        map_func_args=None,
+        aggregate_func=None,
+        num_workers=4,
+        progress=False,
+        autosave=False,
+        batch_size=None,
+        batching_strategy=None,
+        skip_failures=True,
+        warn_failures=False,
+    ):
+        """Maps a function over the samples in the view.
+
+        See :meth:`fiftyone.core.dataset.Dataset.map_samples` for details on
+        the supported arguments.
+        """
+        if not callable(map_func):
+            raise ValueError("map_func must be callable")
+
+        # Create shared queues with max size
+        MAX_QUEUE_SIZE = 10000
+        work_queue = queue.Queue(maxsize=MAX_QUEUE_SIZE)
+        result_queue = queue.Queue()
+
+        def worker(thread_name):
+            """Worker thread function that processes samples from the queue"""
+            while True:
+                try:
+                    sample = work_queue.get_nowait()
+                    try:
+                        result = map_func(sample, map_func_args)
+                        result_queue.put(("success", result))
+                    except Exception as e:
+                        if skip_failures:
+                            if warn_failures:
+                                logger.warning(
+                                    "Failed to process sample: %s", e
+                                )
+                            result_queue.put(("failure", None))
+                        else:
+                            result_queue.put(("error", e))
+                    work_queue.task_done()
+                except queue.Empty:
+                    logger.debug("Worker thread %s exiting", thread_name)
+                    break
+
+        results = []
+        sample_count = 0
+        error_occurred = False
+        threads = []
+
+        try:
+            # Start worker threads
+            for i in range(num_workers):
+                thread = threading.Thread(
+                    target=worker,
+                    name=f"SampleMapper-{i}",
+                    args=(f"SampleMapper-{i}",),
+                    daemon=True,
+                )
+                thread.start()
+                threads.append(thread)
+
+            # Process samples
+            for sample in self.iter_samples(
+                progress=progress,
+                autosave=autosave,
+                batch_size=batch_size,
+                batching_strategy=batching_strategy,
+            ):
+                if error_occurred and not skip_failures:
+                    break
+
+                # Simply block until space is available in the queue
+                work_queue.put(sample)
+                sample_count += 1
+
+            # Collect all results
+            while not result_queue.empty():
+                status, result = result_queue.get()
+                if status == "error" and not skip_failures:
+                    raise result
+                elif status == "success":
+                    results.append(result)
+
+        finally:
+            # Clean up threads
+            for thread in threads:
+                thread.join(timeout=1.0)
+
+        # Check result count
+        if len(results) != sample_count:
+            message = (
+                "Worker threads did not process all samples; received "
+                f"{len(results)} results for {sample_count} samples"
+            )
+            if not skip_failures:
+                raise RuntimeError(message)
+            elif warn_failures:
+                logger.warning(message)
+
+        # Apply aggregation if specified
+        if aggregate_func is not None:
+            if callable(aggregate_func):
+                results = aggregate_func(results)
+            else:
+                raise ValueError("aggregate_func must be callable")
+
+        return results
+
+    def _process_future_result(
+        self, futures, result, skip_failures, warn_failures
+    ):
+        for future in futures:
+            try:
+                result.append(future.result())
+            except Exception as e:
+                if not skip_failures:
+                    raise RuntimeError(
+                        "Worker failed while processing sample"
+                    ) from e
+
+                if warn_failures:
+                    logger.warning("Error processing sample: %s", e)
+
+                result.append(None)
+
     def iter_samples(
         self,
         progress=False,
diff --git a/fiftyone/utils/eval/detection.py b/fiftyone/utils/eval/detection.py
index 9c086ab587..b58e6e63e2 100644
--- a/fiftyone/utils/eval/detection.py
+++ b/fiftyone/utils/eval/detection.py
@@ -30,6 +30,56 @@
 logger = logging.getLogger(__name__)
 
 
+def _process_sample(sample, args):
+    """Processes a single sample or its frames."""
+    (
+        eval_method,
+        eval_key,
+        processing_frames,
+        save,
+        tp_field,
+        fp_field,
+        fn_field,
+    ) = args
+    if processing_frames:
+        docs = sample.frames.values()
+    else:
+        docs = [sample]
+
+    matches = []
+    sample_tp = sample_fp = sample_fn = 0
+
+    for doc in docs:
+        doc_matches = eval_method.evaluate(doc, eval_key=eval_key)
+        matches.extend(doc_matches)
+        tp, fp, fn = _tally_matches(doc_matches)
+        sample_tp += tp
+        sample_fp += fp
+        sample_fn += fn
+
+        if processing_frames and eval_key is not None:
+            doc[f"{eval_key}_tp"] = tp
+            doc[f"{eval_key}_fp"] = fp
+            doc[f"{eval_key}_fn"] = fn
+
+    if save:
+        sample[tp_field] = sample_tp
+        sample[fp_field] = sample_fp
+        sample[fn_field] = sample_fn
+
+    return matches
+
+
+def _aggregate_results(results):
+    """Aggregates the per-sample matches into a single list."""
+    matches = []
+
+    for sample_matches in results:
+        matches.extend(sample_matches)
+
+    return matches
+
+
 def evaluate_detections(
     samples,
     pred_field,
@@ -45,6 +95,8 @@ def evaluate_detections(
     dynamic=True,
     custom_metrics=None,
     progress=None,
+    num_workers=None,
+    batch_size=None,
     **kwargs,
 ):
     """Evaluates the predicted detections in the given samples with respect to
@@ -141,6 +193,11 @@
         progress (None): whether to render a progress bar (True/False), use
             the default value ``fiftyone.config.show_progress_bars`` (None),
             or a progress callback function to invoke instead
+        num_workers (None): the number of worker threads to use to compute
+            the evaluation. If not provided, the samples are processed
+            sequentially in the main thread
+        batch_size (None): an optional batch size to use when autosaving
+            sample updates. If not provided, the default batching strategy
+            is used
         **kwargs: optional keyword arguments for the constructor of the
             :class:`DetectionEvaluationConfig` being used
@@ -183,44 +240,40 @@
     processing_frames = samples._is_frame_field(pred_field)
     save = eval_key is not None
 
-    if save:
-        tp_field = "%s_tp" % eval_key
-        fp_field = "%s_fp" % eval_key
-        fn_field = "%s_fn" % eval_key
+    tp_field = "%s_tp" % eval_key if save else None
+    fp_field = "%s_fp" % eval_key if save else None
+    fn_field = "%s_fn" % eval_key if save else None
 
     if config.requires_additional_fields:
         _samples = samples
     else:
         _samples = samples.select_fields([gt_field, pred_field])
 
-    matches = []
+    args = (
+        eval_method,
+        eval_key,
+        processing_frames,
+        save,
+        tp_field,
+        fp_field,
+        fn_field,
+    )
+
     logger.info("Evaluating detections...")
-    for sample in _samples.iter_samples(progress=progress, autosave=save):
-        if processing_frames:
-            docs = sample.frames.values()
-        else:
-            docs = [sample]
-
-        sample_tp = 0
-        sample_fp = 0
-        sample_fn = 0
-        for doc in docs:
-            doc_matches = eval_method.evaluate(doc, eval_key=eval_key)
-            matches.extend(doc_matches)
-            tp, fp, fn = _tally_matches(doc_matches)
-            sample_tp += tp
-            sample_fp += fp
-            sample_fn += fn
-
-            if processing_frames and save:
-                doc[tp_field] = tp
-                doc[fp_field] = fp
-                doc[fn_field] = fn
-
-        if save:
-            sample[tp_field] = sample_tp
-            sample[fp_field] = sample_fp
-            sample[fn_field] = sample_fn
+    if num_workers:
+        matches = _samples.map_samples(
+            _process_sample,
+            map_func_args=args,
+            aggregate_func=_aggregate_results,
+            num_workers=num_workers,
+            progress=progress,
+            autosave=save,
+            batch_size=batch_size,
+        )
+    else:
+        matches = []
+        for sample in _samples.iter_samples(progress=progress, autosave=save):
+            matches.extend(_process_sample(sample, args))
 
     results = eval_method.generate_results(
         samples,