-
Notifications
You must be signed in to change notification settings - Fork 590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added map_samples implementation #5444
base: develop
Are you sure you want to change the base?
Changes from all commits
88aecb8
c20a10f
7b1a754
dc7fea9
5f5ebc5
86f3f48
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -7,6 +7,7 @@ | |||||||||||
""" | ||||||||||||
|
||||||||||||
from collections import defaultdict | ||||||||||||
from concurrent.futures import ThreadPoolExecutor | ||||||||||||
import contextlib | ||||||||||||
from datetime import datetime | ||||||||||||
import fnmatch | ||||||||||||
|
@@ -17,9 +18,9 @@ | |||||||||||
import random | ||||||||||||
import string | ||||||||||||
|
||||||||||||
|
||||||||||||
from bson import json_util, ObjectId, DBRef | ||||||||||||
import cachetools | ||||||||||||
from deprecated import deprecated | ||||||||||||
import mongoengine.errors as moe | ||||||||||||
from pymongo import ( | ||||||||||||
DeleteMany, | ||||||||||||
|
@@ -28,7 +29,6 @@ | |||||||||||
UpdateMany, | ||||||||||||
UpdateOne, | ||||||||||||
) | ||||||||||||
from pymongo.collection import Collection | ||||||||||||
from pymongo.errors import CursorNotFound, BulkWriteError | ||||||||||||
|
||||||||||||
import eta.core.serial as etas | ||||||||||||
|
@@ -2884,6 +2884,96 @@ def delete_group_slice(self, name): | |||||||||||
|
||||||||||||
self.save() | ||||||||||||
|
||||||||||||
def map_samples(
    self,
    map_func,
    map_func_args=None,
    aggregate_func=None,
    num_workers=4,
    progress=False,
    autosave=False,
    batch_size=None,
    batching_strategy=None,
    skip_failures=True,
    warn_failures=False,
):
    """Maps a function over the samples in the dataset.

    Args:
        map_func: a function that accepts a
            :class:`fiftyone.core.sample.Sample` as input and returns a
            result
        map_func_args (None): an optional extra argument to pass to
            ``map_func`` as its second positional argument
        aggregate_func (None): an optional function that accepts the list
            of per-sample results and returns the final result. By
            default, the results are returned as a list
        num_workers (4): the number of worker threads to use
        progress (False): whether to render a progress bar
        autosave (False): whether to automatically save modified samples
        batch_size (None): the batch size to use when autosaving samples
        batching_strategy (None): the batching strategy to use for each
            save operation
        skip_failures (True): whether to gracefully continue without
            raising an error if processing fails for a sample
        warn_failures (False): whether to log a warning if processing
            fails for a sample

    Returns:
        the list of per-sample results (``None`` for skipped failures),
        or the output of ``aggregate_func`` applied to that list
    """
    if not callable(map_func):
        raise ValueError("map_func must be callable")

    # Fail fast, before any work is done, if the aggregator is invalid
    if aggregate_func is not None and not callable(aggregate_func):
        raise ValueError("aggregate_func must be callable")

    # Cap the number of pending futures so memory stays bounded on
    # large collections
    max_queue = 10000

    result = []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = []

        for sample in self.iter_samples(
            progress=progress,
            autosave=autosave,
            batch_size=batch_size,
            batching_strategy=batching_strategy,
        ):
            # Only pass `map_func_args` when provided, so that
            # single-argument map functions work as documented
            if map_func_args is None:
                future = executor.submit(map_func, sample)
            else:
                future = executor.submit(map_func, sample, map_func_args)

            futures.append(future)

            if len(futures) > max_queue:
                # Drain the backlog before submitting more work
                self._process_future_result(
                    futures, result, skip_failures, warn_failures
                )
                futures = []

        # Process any remaining results
        self._process_future_result(
            futures, result, skip_failures, warn_failures
        )

    if aggregate_func is not None:
        result = aggregate_func(result)

    return result
|
||||||||||||
def _process_future_result( | ||||||||||||
self, futures, result, skip_failures, warn_failures | ||||||||||||
): | ||||||||||||
for future in futures: | ||||||||||||
try: | ||||||||||||
result.append(future.result()) | ||||||||||||
except Exception as e: | ||||||||||||
if not skip_failures: | ||||||||||||
raise RuntimeError( | ||||||||||||
"Worker failed while processing sample" | ||||||||||||
) from e | ||||||||||||
|
||||||||||||
if warn_failures: | ||||||||||||
logger.warning("Error processing sample: %s", e) | ||||||||||||
|
||||||||||||
result.append(None) | ||||||||||||
|
||||||||||||
def iter_samples( | ||||||||||||
self, | ||||||||||||
progress=False, | ||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -6,10 +6,14 @@ | |||||||||
| | ||||||||||
""" | ||||||||||
from collections import defaultdict, OrderedDict | ||||||||||
from concurrent.futures import ThreadPoolExecutor | ||||||||||
import contextlib | ||||||||||
from copy import copy, deepcopy | ||||||||||
import itertools | ||||||||||
import logging | ||||||||||
import numbers | ||||||||||
import queue | ||||||||||
import threading | ||||||||||
|
||||||||||
from bson import ObjectId | ||||||||||
from pymongo.errors import CursorNotFound | ||||||||||
|
@@ -29,6 +33,9 @@ | |||||||||
fost = fou.lazy_import("fiftyone.core.stages") | ||||||||||
|
||||||||||
|
||||||||||
logger = logging.getLogger(__name__) | ||||||||||
|
||||||||||
|
||||||||||
class DatasetView(foc.SampleCollection): | ||||||||||
"""A view into a :class:`fiftyone.core.dataset.Dataset`. | ||||||||||
|
||||||||||
|
@@ -447,6 +454,128 @@ def view(self): | |||||||||
""" | ||||||||||
return copy(self) | ||||||||||
|
||||||||||
def map_samples(
    self,
    map_func,
    map_func_args=None,
    aggregate_func=None,
    num_workers=4,
    progress=False,
    autosave=False,
    batch_size=None,
    batching_strategy=None,
    skip_failures=True,
    warn_failures=False,
):
    """Maps a function over the samples in the view.

    Args:
        map_func: a function that accepts a
            :class:`fiftyone.core.sample.Sample` as input and returns a
            result
        map_func_args (None): an optional extra argument to pass to
            ``map_func`` as its second positional argument
        aggregate_func (None): an optional function that accepts the list
            of per-sample results and returns the final result. By
            default, the results are returned as a list
        num_workers (4): the number of worker threads to use
        progress (False): whether to render a progress bar
        autosave (False): whether to automatically save modified samples
        batch_size (None): the batch size to use when autosaving samples
        batching_strategy (None): the batching strategy to use for each
            save operation
        skip_failures (True): whether to gracefully continue without
            raising an error if processing fails for a sample
        warn_failures (False): whether to log a warning if processing
            fails for a sample

    Returns:
        the list of per-sample results (in completion order), or the
        output of ``aggregate_func`` applied to that list
    """
    if not callable(map_func):
        raise ValueError("map_func must be callable")

    # Fail fast, before any work is done, if the aggregator is invalid
    if aggregate_func is not None and not callable(aggregate_func):
        raise ValueError("aggregate_func must be callable")

    max_queue_size = 10000

    # Sentinel that tells a worker to exit. Using a sentinel (rather
    # than `get_nowait()` + queue.Empty) prevents the race where a
    # worker drains a momentarily-empty queue and exits while the
    # producer is still iterating samples
    _stop = object()

    work_queue = queue.Queue(maxsize=max_queue_size)
    result_queue = queue.Queue()
    abort = threading.Event()  # set when a fatal error should stop work

    def _worker():
        # Consumes samples until the stop sentinel is received
        while True:
            item = work_queue.get()
            try:
                if item is _stop:
                    return

                if abort.is_set():
                    # Fatal error elsewhere; drain without processing
                    continue

                try:
                    # Only pass `map_func_args` when provided, so that
                    # single-argument map functions work as documented
                    if map_func_args is None:
                        value = map_func(item)
                    else:
                        value = map_func(item, map_func_args)

                    result_queue.put(("success", value))
                except Exception as e:
                    if skip_failures:
                        if warn_failures:
                            logger.warning(
                                "Failed to process sample: %s", e
                            )

                        result_queue.put(("failure", None))
                    else:
                        result_queue.put(("error", e))
                        abort.set()
            finally:
                work_queue.task_done()

    threads = []
    sample_count = 0

    try:
        # Start the worker threads
        for i in range(num_workers):
            thread = threading.Thread(
                target=_worker,
                name=f"SampleMapper-{i}",
                daemon=True,
            )
            thread.start()
            threads.append(thread)

        # Produce samples for the workers
        for sample in self.iter_samples(
            progress=progress,
            autosave=autosave,
            batch_size=batch_size,
            batching_strategy=batching_strategy,
        ):
            if abort.is_set():
                # A worker hit a fatal error; stop producing
                break

            work_queue.put(sample)  # blocks while the queue is full
            sample_count += 1
    finally:
        # Deliver one sentinel per worker so every thread exits, then
        # wait for them all. No join timeout is needed because the
        # sentinels guarantee each worker terminates
        for _ in threads:
            work_queue.put(_stop)

        for thread in threads:
            thread.join()

    # All workers have exited, so every outcome is now in the queue;
    # collecting results before the join would miss in-flight work
    results = []
    first_error = None
    while not result_queue.empty():
        status, value = result_queue.get()
        if status == "success":
            results.append(value)
        elif status == "error" and first_error is None:
            first_error = value

    # Only produced when skip_failures=False; re-raise the worker error
    if first_error is not None:
        raise first_error

    # Failed samples (with skip_failures=True) contribute no result, so
    # a shortfall means some samples were not successfully processed
    if len(results) != sample_count:
        message = f"Worker threads did not process all samples, getting {len(results)} results out of {sample_count}"
        if not skip_failures:
            raise RuntimeError(message)

        if warn_failures:
            logger.warning(message)

    if aggregate_func is not None:
        results = aggregate_func(results)

    return results
|
||||||||||
def _process_future_result(
    self, futures, result, skip_failures, warn_failures
):
    """Appends the outcome of each future in ``futures`` to ``result``.

    Args:
        futures: an iterable of ``concurrent.futures.Future`` instances
        result: the list to which outcomes are appended in-place
        skip_failures: whether to append ``None`` for a failed future
            instead of raising
        warn_failures: whether to log a warning when a future failed and
            ``skip_failures`` is True
    """
    for future in futures:
        try:
            result.append(future.result())
        except Exception as e:
            if not skip_failures:
                # Surface the failure, chaining the original exception
                # so the worker's traceback is preserved
                raise RuntimeError(
                    "Worker failed while processing sample"
                ) from e

            if warn_failures:
                logger.warning("Error processing sample: %s", e)

            # Keep one entry per future so result stays aligned with
            # the submitted work
            result.append(None)
|
||||||||||
def iter_samples( | ||||||||||
self, | ||||||||||
progress=False, | ||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider adding input validation for num_workers.
The
num_workers
parameter should be validated to ensure it's a positive integer. A value of 0 or negative would cause issues with ThreadPoolExecutor.