Added map_samples implementation #5444

Open · wants to merge 6 commits into develop
94 changes: 92 additions & 2 deletions fiftyone/core/dataset.py
@@ -7,6 +7,7 @@
"""

from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
import contextlib
from datetime import datetime
import fnmatch
@@ -17,9 +18,9 @@
import random
import string


from bson import json_util, ObjectId, DBRef
import cachetools
from deprecated import deprecated
import mongoengine.errors as moe
from pymongo import (
DeleteMany,
@@ -28,7 +29,6 @@
UpdateMany,
UpdateOne,
)
from pymongo.collection import Collection
from pymongo.errors import CursorNotFound, BulkWriteError

import eta.core.serial as etas
@@ -2884,6 +2884,96 @@ def delete_group_slice(self, name):

self.save()

def map_samples(
self,
map_func,
map_func_args=None,
aggregate_func=None,
num_workers=4,
progress=False,
autosave=False,
batch_size=None,
batching_strategy=None,
skip_failures=True,
warn_failures=False,
):
Comment on lines +2887 to +2899
Contributor

🛠️ Refactor suggestion

Consider adding input validation for num_workers.

The num_workers parameter should be validated to ensure it's a positive integer. A value of 0 or negative would cause issues with ThreadPoolExecutor.

def map_samples(
    self,
    map_func,
    map_func_args=None,
    aggregate_func=None,
    num_workers=4,
    progress=False,
    autosave=False,
    batch_size=None,
    batching_strategy=None,
    skip_failures=True,
    warn_failures=False,
):
+   if not isinstance(num_workers, int) or num_workers < 1:
+       raise ValueError("num_workers must be a positive integer")

Committable suggestion skipped: line range outside the PR's diff.

"""Maps a function over the samples in the dataset.

Args:
Contributor

Where parameters are the same as iter_samples(), just copy the whole description from its docstring.

map_func: a function that accepts a :class:`fiftyone.core.sample.Sample` as
an input and returns a result
map_func_args (None): additional arguments to pass to the map_func
aggregate_func (None): an optional function that accepts a list of
the results of the mapping operation and returns the final
result. By default, the results are returned as a list
num_workers (4): the number of worker threads to use
progress (False): whether to render a progress bar (True/False)
autosave (False): whether to automatically save the results
Contributor
@brimoor Jan 29, 2025

Note that there's no point in map_samples(autosave=False). If you are applying changes to samples, you need to be saving them.

The only reason that iter_samples(autosave=False) makes sense is that the user is assumed to be calling sample.save() inside the for loop manually:

for sample in dataset.iter_samples():
    fcn(sample)
    sample.save()

Contributor Author

Is there a scenario where map_func does not modify the samples? For example, we could just compute some large-scale statistics and aggregate the results at the end without modifying the samples.

Contributor

Yeah, I think that's right, Minh. We should make the default autosave=True, but leave the option in, because if the caller says they don't need to save samples then we can elide that whole process.

Contributor

Ah you are actually supporting map-reduce here, not just map. In that case we might want to consider a slightly smarter default to make the syntax for common cases as clean as possible:

  • save=True by default if no aggregate_func is provided, because there's no point in mapping without saving
  • save=False by default if an aggregate_func is provided, to optimize for the case where you're computing a reduction and not actually editing samples

Per above, we may want to rename autosave -> save because there's no "manual" save option in this case.

We may also want to consider renaming the function, e.g. map_reduce(), as we have things like ViewExpression.map() that are just map functions, not reductions.
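
A minimal sketch of the defaulting behavior suggested in this comment (the save parameter and this helper are hypothetical, not part of the PR):

def _resolve_save(save, aggregate_func):
    # Hypothetical helper: if the caller did not specify `save`, persist edits
    # only for pure map calls, and skip saving when a reduction is requested
    if save is None:
        return aggregate_func is None
    return save

# _resolve_save(None, None)  -> True   (map only: persist edits)
# _resolve_save(None, sum)   -> False  (map-reduce: nothing to save)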

batch_size (None): the batch size to use when autosaving samples
batching_strategy (None): the batching strategy to use for each save
skip_failures (True): whether to gracefully continue without raising an
error if processing fails for a sample
warn_failures (False): whether to log a warning if processing fails for
a sample

Returns:
the result of the mapping operation, which is a list of the results
if ``aggregate_func`` is not provided
"""
if not callable(map_func):
raise ValueError("map_func must be callable")

# Create a thread pool
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = []
result = []
max_queue = 10000

# Submit samples to the worker pool
for sample in self.iter_samples(
Contributor
@brimoor Jan 29, 2025

I haven't run this code myself, but in a static review of it, I'm confused how this is even working when autosave=True.

iter_samples(autosave=True) immediately calls sample.save() on each sample as soon as control is returned to the for statement:

for sample in samples:
    yield sample
    if autosave:
        save_context.save(sample)

Isn't it possible for the thread pool here to accumulate a backlog and not actually apply map_func to a sample before the next loop iteration starts? Indeed, that would be the whole point of a pool, I assume.

One explanation for the reported speedup in the PR is that no DB updates are actually happening because sample.save() is getting called too early, before there are any edits to save!

I fundamentally don't understand how the approach in this PR could ever be much faster than iter_samples(autosave=True), given that it relies on a single _aggregate() to load data, relies on a single SaveContext to save data, and uses a thread pool to apply the functions, which isn't parallelized due to the GIL.

Maybe I'm misunderstanding something here, but I'm just very confused.
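
A minimal illustration of the GIL point above (timings are machine-dependent; the function is a stand-in for a CPU-bound map_func, not code from this PR):

import time
from concurrent.futures import ThreadPoolExecutor

def cpu_bound(_):
    # Pure-Python arithmetic holds the GIL, so threads cannot run it in parallel
    return sum(i * i for i in range(10**6))

start = time.perf_counter()
results = [cpu_bound(i) for i in range(8)]
print(f"serial:    {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=4) as executor:
    results = list(executor.map(cpu_bound, range(8)))
print(f"4 threads: {time.perf_counter() - start:.2f}s")  # roughly the same as serial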

Contributor Author

Yes, I am not 100% certain about the details of how saving works with iter_samples, so we are just hoping it will work; I think the next step is to check that the results are correct (same stats, samples with correct fields, etc.).

However, we are testing two batching strategies (latency and static) and the numbers look fairly consistent, so I think there is a good chance that the batch updates work (but I'm not 100% sure).

Can we check whether a sample has an unsaved update? I can run a validation at the end of the parallel processing to figure out if there are any unsaved updates.

Contributor Author

Another solution is to call sample.save(...) manually, or to integrate the thread implementation with an auto-save batcher. In that case, the updates would be performed correctly.
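
A minimal sketch of the manual-save alternative mentioned above (hypothetical wrapper; assumes map_func edits the sample in place):

def _map_and_save(map_func, sample, map_func_args):
    # Apply the user's function first, then persist the edits from the same
    # worker thread, so the save cannot happen before the edit
    result = map_func(sample, map_func_args)
    sample.save()
    return result

# submitted to the pool instead of map_func directly, e.g.:
# future = executor.submit(_map_and_save, map_func, sample, map_func_args)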

progress=progress,
autosave=autosave,
batch_size=batch_size,
batching_strategy=batching_strategy,
):
future = executor.submit(map_func, sample, map_func_args)
futures.append(future)
if len(futures) > max_queue:
self._process_future_result(
futures, result, skip_failures, warn_failures
)
futures = []

# Process remaining results
self._process_future_result(
futures, result, skip_failures, warn_failures
)

if aggregate_func is not None:
if callable(aggregate_func):
result = aggregate_func(result)
else:
raise ValueError("aggregate_func must be callable")

return result
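
A usage sketch of the proposed API against the zoo quickstart dataset (field names are illustrative, and the map-only case is subject to the save-ordering concern discussed in the review above):

import fiftyone.zoo as foz

dataset = foz.load_zoo_dataset("quickstart")

# Map only: tag every sample (autosave=True so the edits are persisted)
def add_tag(sample, _):
    sample.tags.append("processed")

dataset.map_samples(add_tag, autosave=True, num_workers=4)

# Map-reduce: compute a statistic without modifying samples
def num_detections(sample, _):
    return len(sample.ground_truth.detections)

total = dataset.map_samples(num_detections, aggregate_func=sum)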

def _process_future_result(
self, futures, result, skip_failures, warn_failures
):
for future in futures:
try:
result.append(future.result())
except Exception as e:
if not skip_failures:
raise RuntimeError(
"Worker failed while processing sample"
) from e

if warn_failures:
logger.warning("Error processing sample: %s", e)

result.append(None)

def iter_samples(
self,
progress=False,
Expand Down
129 changes: 129 additions & 0 deletions fiftyone/core/view.py
@@ -6,10 +6,14 @@
|
"""
from collections import defaultdict, OrderedDict
from concurrent.futures import ThreadPoolExecutor
import contextlib
from copy import copy, deepcopy
import itertools
import logging
import numbers
import queue
import threading

from bson import ObjectId
from pymongo.errors import CursorNotFound
@@ -29,6 +33,9 @@
fost = fou.lazy_import("fiftyone.core.stages")


logger = logging.getLogger(__name__)


class DatasetView(foc.SampleCollection):
"""A view into a :class:`fiftyone.core.dataset.Dataset`.

@@ -447,6 +454,128 @@ def view(self):
"""
return copy(self)

def map_samples(
self,
map_func,
map_func_args=None,
aggregate_func=None,
num_workers=4,
Contributor

Why 4? From your local testing, 2 seemed faster.

Contributor Author

Ah, I just put a random number there until I can run a bit more testing.

Contributor

🛠️ Refactor suggestion

Validate and adjust the number of workers.

The hardcoded value of 4 workers may not be optimal. Consider:

  1. Using multiprocessing.cpu_count() to set a default based on available CPUs
  2. Adding validation to ensure num_workers > 0
-        num_workers=4,
+        num_workers=None,

Add this validation at the start of the method:

# assumes: import multiprocessing as mp
if num_workers is None:
    num_workers = max(1, mp.cpu_count() - 1)
elif num_workers <= 0:
    raise ValueError("num_workers must be greater than 0")

progress=False,
autosave=False,
batch_size=None,
batching_strategy=None,
skip_failures=True,
warn_failures=False,
):
if not callable(map_func):
raise ValueError("map_func must be callable")

# Create shared queues with max size
MAX_QUEUE_SIZE = 10000
work_queue = queue.Queue(maxsize=MAX_QUEUE_SIZE)
result_queue = queue.Queue()

def worker(thread_name):
"""Worker thread function that processes samples from the queue"""
while True:
try:
sample = work_queue.get_nowait()
try:
result = map_func(sample, map_func_args)
result_queue.put(("success", result))
except Exception as e:
if skip_failures:
if warn_failures:
print(
f"Warning: Failed to process sample: {e}"
)
Comment on lines +489 to +491
Contributor

🛠️ Refactor suggestion

Use logger instead of print statements.

Replace print statements with proper logging calls since a logger is already configured.

-                                print(
-                                    f"Warning: Failed to process sample: {e}"
-                                )
+                                logger.warning("Failed to process sample: %s", e)

result_queue.put(("failure", None))
else:
result_queue.put(("error", e))
work_queue.task_done()
except queue.Empty:
print(f"Worker thread {thread_name} exiting")
Contributor

🛠️ Refactor suggestion

Use logger for worker thread messages.

Replace print statement with proper logging for consistency.

-                    print(f"Worker thread {thread_name} exiting")
+                    logger.debug("Worker thread %s exiting", thread_name)

break

results = []
sample_count = 0
error_occurred = False
threads = []

try:
# Start worker threads
for i in range(num_workers):
thread = threading.Thread(
target=worker,
name=f"SampleMapper-{i}",
args=(f"SampleMapper-{i}",),
daemon=True,
)
thread.start()
threads.append(thread)

# Process samples
for sample in self.iter_samples(
progress=progress,
autosave=autosave,
batch_size=batch_size,
batching_strategy=batching_strategy,
):
if error_occurred and not skip_failures:
break

# Simply block until space is available in the queue
work_queue.put(sample)
sample_count += 1

# Collect all results
while not result_queue.empty():
status, result = result_queue.get()
if status == "error" and not skip_failures:
raise result
elif status == "success":
results.append(result)

finally:
# Clean up threads
for thread in threads:
thread.join(timeout=1.0)

# Check result count
if len(results) != sample_count:
message = f"Worker threads did not process all samples, getting {len(results)} results out of {sample_count}"
if not skip_failures:
raise RuntimeError(message)
else:
if warn_failures:
logger.warning(message)

# Apply aggregation if specified
if aggregate_func is not None:
if callable(aggregate_func):
results = aggregate_func(results)
else:
raise ValueError("aggregate_func must be callable")

return results
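
A usage sketch of the same API on a view (the match filter and field names are illustrative):

import fiftyone.zoo as foz
from fiftyone import ViewField as F

dataset = foz.load_zoo_dataset("quickstart")

# Reduce over a filtered view without editing samples
view = dataset.match(F("ground_truth.detections").length() > 5)

def num_detections(sample, _):
    return len(sample.ground_truth.detections)

mean_count = view.map_samples(
    num_detections,
    aggregate_func=lambda results: sum(results) / len(results),
    num_workers=4,
)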

def _process_future_result(
self, futures, result, skip_failures, warn_failures
):
for future in futures:
try:
result.append(future.result())
except Exception as e:
if not skip_failures:
raise RuntimeError(
"Worker failed while processing sample"
) from e

if warn_failures:
logger.warning("Error processing sample: %s", e)

result.append(None)

def iter_samples(
self,
progress=False,
Expand Down