Optimize tests, add another synchronization event in test_mpfuture_done_callback
mryab committed Nov 3, 2024
1 parent 30ad882 commit c8dee55
Showing 2 changed files with 14 additions and 9 deletions.
8 changes: 5 additions & 3 deletions tests/test_start_server.py
@@ -31,8 +31,10 @@ def test_cli_run_server_identity_path():
# overriding the loglevel to prevent debug print statements
cloned_env["HIVEMIND_LOGLEVEL"] = "INFO"

+ common_server_args = ["--hidden_dim", "4", "--num_handlers", "1"]
+
server_1_proc = Popen(
["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
stderr=PIPE,
text=True,
encoding="utf-8",
@@ -51,7 +53,7 @@ def test_cli_run_server_identity_path():
assert len(ids_1) == 1

server_2_proc = Popen(
["hivemind-server", "--num_experts", "1", "--identity_path", id_path],
["hivemind-server", "--num_experts", "1", "--identity_path", id_path] + common_server_args,
stderr=PIPE,
text=True,
encoding="utf-8",
@@ -67,7 +69,7 @@ def test_cli_run_server_identity_path():
assert len(ids_2) == 1

server_3_proc = Popen(
["hivemind-server", "--num_experts", "1"],
["hivemind-server", "--num_experts", "1"] + common_server_args,
stderr=PIPE,
text=True,
encoding="utf-8",
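
Note: the factored-out common_server_args list is simply appended to each server command. A minimal sketch of the resulting invocation, with an illustrative identity file name (server1.id is not taken from the diff):

from subprocess import PIPE, Popen

common_server_args = ["--hidden_dim", "4", "--num_handlers", "1"]

# Effectively runs:
#   hivemind-server --num_experts 1 --identity_path server1.id --hidden_dim 4 --num_handlers 1
server_proc = Popen(
    ["hivemind-server", "--num_experts", "1", "--identity_path", "server1.id"] + common_server_args,
    stderr=PIPE,
    text=True,
    encoding="utf-8",
)

The small hidden dimension and single handler presumably keep each spawned server lightweight, which is what speeds the test up.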
15 changes: 9 additions & 6 deletions tests/test_util_modules.py
@@ -233,7 +233,7 @@ def _future_creator():
@pytest.mark.forked
def test_mpfuture_done_callback():
receiver, sender = mp.Pipe(duplex=False)
- events = [mp.Event() for _ in range(6)]
+ events = [mp.Event() for _ in range(7)]

def _future_creator():
future1, future2, future3 = hivemind.MPFuture(), hivemind.MPFuture(), hivemind.MPFuture()
@@ -250,7 +250,7 @@ def _check_result_and_set(future):

sender.send((future1, future2))
future2.cancel() # trigger future2 callback from the same process
-
+ events[6].set()
events[0].wait()
future1.add_done_callback(
lambda future: events[4].set()
@@ -262,6 +262,7 @@ def _check_result_and_set(future):

future1, future2 = receiver.recv()
future1.set_result(123)
+ events[6].wait()

with pytest.raises(RuntimeError):
future1.add_done_callback(lambda future: (1, 2, 3))
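
The new seventh event closes a race: the child process sets events[6] right after cancelling future2, and the parent blocks on it before continuing, so the parent's subsequent checks cannot race ahead of the cancellation (and its done callback) in the child. A minimal sketch of that handshake pattern, with illustrative names not taken from the test:

import multiprocessing as mp

done_in_child = mp.Event()

def child():
    # ... perform the side effect to be checked (e.g. cancel a future) ...
    done_in_child.set()  # the test's analogue is events[6].set()

if __name__ == "__main__":
    proc = mp.Process(target=child)
    proc.start()
    done_in_child.wait()  # the test's analogue is events[6].wait()
    # only now is it safe to assert on the child's side effects
    proc.join()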
@@ -514,21 +515,23 @@ async def test_async_context_flooding():
Here's how the test below works: suppose that the thread pool has at most N workers;
If at least N + 1 coroutines await lock1 concurrently, N of them occupy workers and the rest are awaiting workers;
- When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep(1e-2);
+ When the first of N workers acquires lock1, it lets coroutine A inside lock1 and into await sleep();
During that sleep, one of the worker-less coroutines will take up the worker freed by coroutine A.
Finally, coroutine A finishes sleeping and immediately gets stuck at lock2, because there are no free workers.
Thus, every single coroutine is either awaiting an already acquired lock, or awaiting for free workers in executor.
"""
+ total_sleep_time = 1
lock1, lock2 = mp.Lock(), mp.Lock()

+ num_coros = max(33, mp.cpu_count() * 5 + 1)

async def coro():
async with enter_asynchronously(lock1):
- await asyncio.sleep(1e-2)
+ await asyncio.sleep(total_sleep_time/(num_coros*2))
async with enter_asynchronously(lock2):
- await asyncio.sleep(1e-2)
+ await asyncio.sleep(total_sleep_time/(num_coros*2))

- num_coros = max(33, mp.cpu_count() * 5 + 1)
await asyncio.wait({asyncio.create_task(coro()) for _ in range(num_coros)})


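The speed-up in test_async_context_flooding comes from scaling each sleep to the number of coroutines: every coroutine sleeps twice for total_sleep_time / (num_coros * 2), so the sum of all sleeps stays at total_sleep_time (one second) no matter how many coroutines the machine spawns. A quick sanity check of that arithmetic (a standalone sketch, not part of the test):

total_sleep_time = 1
num_coros = 33  # lower bound of max(33, mp.cpu_count() * 5 + 1)

per_sleep = total_sleep_time / (num_coros * 2)  # ~0.0152 s per sleep call
assert abs(per_sleep * 2 * num_coros - total_sleep_time) < 1e-9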