From 7bec450c0fbcdeb71c58d92f1a47207be8845b5a Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 4 Feb 2025 06:17:54 -0800 Subject: [PATCH] [Dataflow Streaming] Optimize failed key processing by indexing workitems by sharding key (#33755) --- .../worker/streaming/ActiveWorkState.java | 141 +++++++-------- .../worker/streaming/ComputationState.java | 8 +- .../WorkHeartbeatResponseProcessor.java | 16 +- .../streaming/WorkIdWithShardingKey.java | 32 ++++ .../worker/StreamingDataflowWorkerTest.java | 117 +++++++----- .../worker/streaming/ActiveWorkStateTest.java | 170 +++++++++++++++--- 6 files changed, 319 insertions(+), 165 deletions(-) create mode 100644 runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkIdWithShardingKey.java diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index 2fd5e5457629..e430f6c8f638 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -18,14 +18,11 @@ package org.apache.beam.runners.dataflow.worker.streaming; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList; -import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap; import java.io.PrintWriter; -import java.util.ArrayDeque; -import java.util.Collection; -import java.util.Deque; import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; @@ -36,14 +33,13 @@ import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; +import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap; import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; @@ -63,11 +59,11 @@ public final class ActiveWorkState { private static final int MAX_PRINTABLE_COMMIT_PENDING_KEYS = 50; /** - * Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link - * Queue} is actively processing. + * Map from shardingKey to {@link Work} for the key. The first item in the {@link LinkedHashMap} + * is actively processing. 
   */
  @GuardedBy("this")
-  private final Map<ShardedKey, Deque<ExecutableWork>> activeWork;
+  private final Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork;
 
  @GuardedBy("this")
  private final WindmillStateCache.ForComputation computationStateCache;
@@ -81,8 +77,8 @@ public final class ActiveWorkState {
  private GetWorkBudget activeGetWorkBudget;
 
  private ActiveWorkState(
-      Map<ShardedKey, Deque<ExecutableWork>> activeWork,
-      WindmillStateCache.ForComputation computationStateCache) {
+      Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
+      ForComputation computationStateCache) {
    this.activeWork = activeWork;
    this.computationStateCache = computationStateCache;
    this.activeGetWorkBudget = GetWorkBudget.noBudget();
@@ -94,7 +90,7 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState
 
  @VisibleForTesting
  static ActiveWorkState forTesting(
-      Map<ShardedKey, Deque<ExecutableWork>> activeWork,
+      Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
      WindmillStateCache.ForComputation computationStateCache) {
    return new ActiveWorkState(activeWork, computationStateCache);
  }
@@ -124,28 +120,30 @@ private static String elapsedString(Instant start, Instant end) {
   */
  synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork) {
    ShardedKey shardedKey = executableWork.work().getShardedKey();
-    Deque<ExecutableWork> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
+    long shardingKey = shardedKey.shardingKey();
+    LinkedHashMap<WorkId, ExecutableWork> workQueue =
+        activeWork.computeIfAbsent(shardingKey, (unused) -> new LinkedHashMap<>());
    // This key does not have any work queued up on it. Create one, insert Work, and mark the work
    // to be executed.
-    if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
-      workQueue.addLast(executableWork);
-      activeWork.put(shardedKey, workQueue);
+    if (workQueue.isEmpty()) {
+      workQueue.put(executableWork.id(), executableWork);
      incrementActiveWorkBudget(executableWork.work());
      return ActivateWorkResult.EXECUTE;
    }
 
    // Check to see if we have this work token queued.
-    Iterator<ExecutableWork> workIterator = workQueue.iterator();
+    Iterator<Entry<WorkId, ExecutableWork>> workIterator = workQueue.entrySet().iterator();
    while (workIterator.hasNext()) {
-      ExecutableWork queuedWork = workIterator.next();
+      ExecutableWork queuedWork = workIterator.next().getValue();
      if (queuedWork.id().equals(executableWork.id())) {
        return ActivateWorkResult.DUPLICATE;
      }
-      if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()) {
+      if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()
+          && queuedWork.work().getShardedKey().equals(executableWork.work().getShardedKey())) {
        if (executableWork.id().workToken() > queuedWork.id().workToken()) {
          // Check to see if the queuedWork is active. We only want to remove it if it is NOT
          // currently active.
-          if (!queuedWork.equals(workQueue.peek())) {
+          if (!queuedWork.equals(Preconditions.checkNotNull(firstValue(workQueue)))) {
            workIterator.remove();
            decrementActiveWorkBudget(queuedWork.work());
          }
@@ -157,7 +155,7 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
    }
 
    // Queue the work for later processing.
-    workQueue.addLast(executableWork);
+    workQueue.put(executableWork.id(), executableWork);
    incrementActiveWorkBudget(executableWork.work());
    return ActivateWorkResult.QUEUED;
  }
@@ -167,54 +165,29 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
   *
-   * @param failedWork a map from sharding_key to tokens for the corresponding work.
+   * @param failedWork a list of (shardingKey, WorkId) pairs identifying the work to fail.
   */
-  synchronized void failWorkForKey(Multimap<Long, WorkId> failedWork) {
-    // Note we can't construct a ShardedKey and look it up in activeWork directly since
-    // HeartbeatResponse doesn't include the user key.
-    for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
-      Collection<WorkId> failedWorkIds = failedWork.get(entry.getKey().shardingKey());
-      for (WorkId failedWorkId : failedWorkIds) {
-        for (ExecutableWork queuedWork : entry.getValue()) {
-          WorkItem workItem = queuedWork.work().getWorkItem();
-          if (workItem.getWorkToken() == failedWorkId.workToken()
-              && workItem.getCacheToken() == failedWorkId.cacheToken()) {
-            LOG.debug(
-                "Failing work "
-                    + computationStateCache.getComputation()
-                    + " "
-                    + entry.getKey().shardingKey()
-                    + " "
-                    + failedWorkId.workToken()
-                    + " "
-                    + failedWorkId.cacheToken()
-                    + ". The work will be retried and is not lost.");
-            queuedWork.work().setFailed();
-            break;
-          }
-        }
-      }
+  synchronized void failWorkForKey(ImmutableList<WorkIdWithShardingKey> failedWork) {
+    for (WorkIdWithShardingKey failedId : failedWork) {
+      @Nullable
+      LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(failedId.shardingKey());
+      if (workQueue == null) {
+        // Work could complete/fail before heartbeat response arrives
+        continue;
+      }
+      @Nullable ExecutableWork executableWork = workQueue.get(failedId.workId());
+      if (executableWork == null) {
+        continue;
+      }
+      executableWork.work().setFailed();
+      LOG.debug(
+          "Failing work {} {}. The work will be retried and is not lost.",
+          computationStateCache.getComputation(),
+          failedId);
    }
  }
 
-  /**
-   * Returns a read only view of current active work.
-   *
-   * @implNote Do not return a reference to the underlying workQueue as iterations over it will
-   *     cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data
-   *     structure.
-   */
-  synchronized ImmutableListMultimap<ShardedKey, RefreshableWork> getReadOnlyActiveWork() {
-    return activeWork.entrySet().stream()
-        .collect(
-            flatteningToImmutableListMultimap(
-                Entry::getKey,
-                e ->
-                    e.getValue().stream()
-                        .map(executableWork -> (RefreshableWork) executableWork.work())));
-  }
-
  synchronized ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
    return activeWork.values().stream()
-        .flatMap(Deque::stream)
+        .flatMap(workMap -> workMap.values().stream())
        .map(ExecutableWork::work)
        .filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline))
        .collect(toImmutableList());
@@ -236,7 +209,8 @@ private synchronized void decrementActiveWorkBudget(Work work) {
   */
  synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
      ShardedKey shardedKey, WorkId workId) {
-    @Nullable Queue<ExecutableWork> workQueue = activeWork.get(shardedKey);
+    @Nullable
+    LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(shardedKey.shardingKey());
    if (workQueue == null) {
      // Work may have been completed due to clearing of stuck commits.
      LOG.warn(
@@ -251,14 +225,15 @@ synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
  }
 
  private synchronized void removeCompletedWorkFromQueue(
-      Queue<ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
-    @Nullable ExecutableWork completedWork = workQueue.peek();
-    if (completedWork == null) {
+      LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
+    Iterator<Entry<WorkId, ExecutableWork>> completedWorkIterator =
+        workQueue.entrySet().iterator();
+    if (!completedWorkIterator.hasNext()) {
      // Work may have been completed due to clearing of stuck commits.
      LOG.warn("Active key {} without work, expected token {}", shardedKey, workId);
      return;
    }
+    ExecutableWork completedWork = completedWorkIterator.next().getValue();
    if (!completedWork.id().equals(workId)) {
      // Work may have been completed due to clearing of stuck commits.
      LOG.warn(
@@ -271,19 +246,18 @@ private synchronized void removeCompletedWorkFromQueue(
          completedWork.id());
      return;
    }
-    // We consumed the matching work item.
-    workQueue.remove();
+    completedWorkIterator.remove();
    decrementActiveWorkBudget(completedWork.work());
  }
 
+  @SuppressWarnings("ReferenceEquality")
  private synchronized Optional<ExecutableWork> getNextWork(
-      Queue<ExecutableWork> workQueue, ShardedKey shardedKey) {
-    Optional<ExecutableWork> nextWork = Optional.ofNullable(workQueue.peek());
+      LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey) {
+    Optional<ExecutableWork> nextWork = Optional.ofNullable(firstValue(workQueue));
    if (!nextWork.isPresent()) {
-      Preconditions.checkState(workQueue == activeWork.remove(shardedKey));
+      Preconditions.checkState(workQueue == activeWork.remove(shardedKey.shardingKey()));
    }
-
    return nextWork;
  }
 
@@ -302,22 +276,26 @@ synchronized void invalidateStuckCommits(
    }
  }
 
+  private static @Nullable ExecutableWork firstValue(Map<WorkId, ExecutableWork> map) {
+    Iterator<Entry<WorkId, ExecutableWork>> iterator = map.entrySet().iterator();
+    return iterator.hasNext() ? iterator.next().getValue() : null;
+  }
+
  private synchronized ImmutableMap<ShardedKey, WorkId> getStuckCommitsAt(
      Instant stuckCommitDeadline) {
    // Determine the stuck commit keys but complete them outside the loop iterating over
    // activeWork as completeWork may delete the entry from activeWork.
    ImmutableMap.Builder<ShardedKey, WorkId> stuckCommits = ImmutableMap.builder();
-    for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
-      ShardedKey shardedKey = entry.getKey();
-      @Nullable ExecutableWork executableWork = entry.getValue().peek();
+    for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
+      @Nullable ExecutableWork executableWork = firstValue(entry.getValue());
      if (executableWork != null) {
        Work work = executableWork.work();
        if (work.isStuckCommittingAt(stuckCommitDeadline)) {
          LOG.error(
              "Detected key {} stuck in COMMITTING state since {}, completing it with error.",
-              shardedKey,
+              work.getShardedKey(),
              work.getStateStartTime());
-          stuckCommits.put(shardedKey, work.id());
+          stuckCommits.put(work.getShardedKey(), work.id());
        }
      }
    }
@@ -353,9 +331,10 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
    // Use StringBuilder because we are appending in loop.
    StringBuilder activeWorkStatus = new StringBuilder();
    int commitsPendingCount = 0;
-    for (Map.Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
-      Queue<ExecutableWork> workQueue = Preconditions.checkNotNull(entry.getValue());
-      Work activeWork = Preconditions.checkNotNull(workQueue.peek()).work();
+    for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
+      LinkedHashMap<WorkId, ExecutableWork> workQueue =
+          Preconditions.checkNotNull(entry.getValue());
+      Work activeWork = Preconditions.checkNotNull(firstValue(workQueue)).work();
      WorkItem workItem = activeWork.getWorkItem();
      if (activeWork.isCommitPending()) {
        if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) {
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
index c9be13835c09..3886d4fbc01b 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationState.java
@@ -29,9 +29,7 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
 import org.joda.time.Instant;
 
 /**
@@ -120,7 +118,7 @@ public boolean activateWork(ExecutableWork executableWork) {
    }
  }
 
-  public void failWork(Multimap<Long, WorkId> failedWork) {
+  public void failWork(ImmutableList<WorkIdWithShardingKey> failedWork) {
    activeWorkState.failWorkForKey(failedWork);
  }
 
@@ -146,10 +144,6 @@ private void forceExecute(ExecutableWork executableWork) {
    executor.forceExecute(executableWork, executableWork.work().getSerializedWorkItemSize());
  }
 
-  public ImmutableListMultimap<ShardedKey, RefreshableWork> currentActiveWorkReadOnly() {
-    return activeWorkState.getReadOnlyActiveWork();
-  }
-
  public ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
    return activeWorkState.getRefreshableWork(refreshDeadline);
  }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java
index 341f434cefa4..fe35ac801d46 100644
--- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkHeartbeatResponseProcessor.java
@@ -24,8 +24,7 @@
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatResponse;
 import org.apache.beam.sdk.annotations.Internal;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 
 /**
  * Processes {@link ComputationHeartbeatResponse}(s). Marks {@link Work} that is invalid from
@@ -34,6 +33,7 @@
 @Internal
 public final class WorkHeartbeatResponseProcessor
    implements Consumer<List<ComputationHeartbeatResponse>> {
+
  /** Fetches a {@link ComputationState} for a computationId. */
  private final Function<String, Optional<ComputationState>> computationStateFetcher;
@@ -46,23 +46,23 @@ public WorkHeartbeatResponseProcessor(
  @Override
  public void accept(List<ComputationHeartbeatResponse> responses) {
    for (ComputationHeartbeatResponse computationHeartbeatResponse : responses) {
-      // Maps sharding key to (work token, cache token) for work that should be marked failed.
-      Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
+      ImmutableList.Builder<WorkIdWithShardingKey> failedWorkBuilder = ImmutableList.builder();
      for (HeartbeatResponse heartbeatResponse :
          computationHeartbeatResponse.getHeartbeatResponsesList()) {
        if (heartbeatResponse.getFailed()) {
-          failedWork.put(
-              heartbeatResponse.getShardingKey(),
+          WorkId workId =
              WorkId.builder()
                  .setWorkToken(heartbeatResponse.getWorkToken())
                  .setCacheToken(heartbeatResponse.getCacheToken())
-                  .build());
+                  .build();
+          failedWorkBuilder.add(
+              WorkIdWithShardingKey.create(heartbeatResponse.getShardingKey(), workId));
        }
      }
      computationStateFetcher
          .apply(computationHeartbeatResponse.getComputationId())
-          .ifPresent(state -> state.failWork(failedWork));
+          .ifPresent(state -> state.failWork(failedWorkBuilder.build()));
    }
  }
 }
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkIdWithShardingKey.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkIdWithShardingKey.java
new file mode 100644
index 000000000000..84f883ee9f21
--- /dev/null
+++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/WorkIdWithShardingKey.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import com.google.auto.value.AutoValue; + +@AutoValue +abstract class WorkIdWithShardingKey { + + public static WorkIdWithShardingKey create(long shardingKey, WorkId workId) { + return new AutoValue_WorkIdWithShardingKey(shardingKey, workId); + } + + public abstract long shardingKey(); + + public abstract WorkId workId(); +} diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 9eeb7738dcca..209932965111 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -85,6 +85,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -216,6 +217,7 @@ // released (2.11.0) @SuppressWarnings({"unused", "deprecation"}) public class StreamingDataflowWorkerTest { + private static final Logger LOG = LoggerFactory.getLogger(StreamingDataflowWorkerTest.class); private static final IntervalWindow DEFAULT_WINDOW = new IntervalWindow(new Instant(1234), Duration.millis(1000)); @@ -1042,65 +1044,59 @@ public void testIgnoreRetriedKeys() throws Exception { makeWorker(defaultWorkerParams().setInstructions(instructions).publishCounters().build()); worker.start(); - for (int i = 0; i < numIters; ++i) { + for (int i = 1; i <= numIters; ++i) { server .whenGetWorkCalled() - .thenReturn( - makeInput( - i, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), DEFAULT_SHARDING_KEY)) + .thenReturn(makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), i)) // Also add work for a different shard of the same key. .thenReturn( makeInput( - i + 1000, - TimeUnit.MILLISECONDS.toMicros(i), - keyStringForIndex(i), - DEFAULT_SHARDING_KEY + 1)); + i + 1000, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), i + 1000)); } // Wait for keys to schedule. They will be blocked. - BlockingFn.counter.acquire(numIters * 2); + BlockingFn.counter().acquire(numIters * 2); // Re-add the work, it should be ignored due to the keys being active. - for (int i = 0; i < numIters; ++i) { + for (int i = 1; i <= numIters; ++i) { // Same work token. server .whenGetWorkCalled() - .thenReturn(makeInput(i, TimeUnit.MILLISECONDS.toMicros(i))) + .thenReturn(makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), i)) .thenReturn( makeInput( - i + 1000, - TimeUnit.MILLISECONDS.toMicros(i), - keyStringForIndex(i), - DEFAULT_SHARDING_KEY + 1)); + i + 1000, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), i + 1000)); } // Give all added calls a chance to run. server.waitForEmptyWorkQueue(); - for (int i = 0; i < numIters; ++i) { + for (int i = 1; i <= numIters; ++i) { // Different work token same keys. 
      server
          .whenGetWorkCalled()
          .thenReturn(
-              makeInput(
-                  i + numIters,
-                  TimeUnit.MILLISECONDS.toMicros(i),
-                  keyStringForIndex(i),
-                  DEFAULT_SHARDING_KEY));
+              makeInput(i + numIters, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), i));
    }
 
    // Give all added calls a chance to run.
    server.waitForEmptyWorkQueue();
 
    // Release the blocked calls.
-    BlockingFn.blocker.countDown();
+    BlockingFn.blocker().countDown();
 
    // Verify the output
    Map<Long, Windmill.WorkItemCommitRequest> result = server.waitForAndGetCommits(numIters * 3);
-    for (int i = 0; i < numIters; ++i) {
+    for (int i = 1; i <= numIters; ++i) {
      assertTrue(result.containsKey((long) i));
      assertEquals(
-          makeExpectedOutput(i, TimeUnit.MILLISECONDS.toMicros(i)).build(),
+          makeExpectedOutput(
+                  i,
+                  TimeUnit.MILLISECONDS.toMicros(i),
+                  keyStringForIndex(i),
+                  i,
+                  keyStringForIndex(i))
+              .build(),
          removeDynamicFields(result.get((long) i)));
      assertTrue(result.containsKey((long) i + 1000));
      assertEquals(
@@ -1108,7 +1104,7 @@ public void testIgnoreRetriedKeys() throws Exception {
              i + 1000,
              TimeUnit.MILLISECONDS.toMicros(i),
              keyStringForIndex(i),
-              DEFAULT_SHARDING_KEY + 1,
+              i + 1000,
              keyStringForIndex(i))
              .build(),
          removeDynamicFields(result.get((long) i + 1000)));
@@ -1118,33 +1114,30 @@ public void testIgnoreRetriedKeys() throws Exception {
              i + numIters,
              TimeUnit.MILLISECONDS.toMicros(i),
              keyStringForIndex(i),
-              DEFAULT_SHARDING_KEY,
+              i,
              keyStringForIndex(i))
              .build(),
          removeDynamicFields(result.get((long) i + numIters)));
    }
 
    // Re-add the work, it should process due to the keys no longer being active.
-    for (int i = 0; i < numIters; ++i) {
+    for (int i = 1; i <= numIters; ++i) {
      server
          .whenGetWorkCalled()
          .thenReturn(
              makeInput(
-                  i + numIters * 2,
-                  TimeUnit.MILLISECONDS.toMicros(i),
-                  keyStringForIndex(i),
-                  DEFAULT_SHARDING_KEY));
+                  i + numIters * 2, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), i));
    }
    result = server.waitForAndGetCommits(numIters);
    worker.stop();
 
-    for (int i = 0; i < numIters; ++i) {
+    for (int i = 1; i <= numIters; ++i) {
      assertTrue(result.containsKey((long) i + numIters * 2));
      assertEquals(
          makeExpectedOutput(
                  i + numIters * 2,
                  TimeUnit.MILLISECONDS.toMicros(i),
                  keyStringForIndex(i),
-                  DEFAULT_SHARDING_KEY,
+                  i,
                  keyStringForIndex(i))
              .build(),
          removeDynamicFields(result.get((long) i + numIters * 2)));
@@ -1168,27 +1161,33 @@ public void testNumberOfWorkerHarnessThreadsIsHonored() throws Exception {
            .build());
    worker.start();
 
-    for (int i = 0; i < expectedNumberOfThreads * 2; ++i) {
-      server.whenGetWorkCalled().thenReturn(makeInput(i, TimeUnit.MILLISECONDS.toMicros(i)));
+    for (int i = 1; i <= expectedNumberOfThreads * 2; ++i) {
+      server
+          .whenGetWorkCalled()
+          .thenReturn(makeInput(i, TimeUnit.MILLISECONDS.toMicros(i), keyStringForIndex(i), i));
    }
 
    // This will fail to complete if the number of threads is less than the amount of work.
    // Forcing this test to timeout.
-    BlockingFn.counter.acquire(expectedNumberOfThreads);
+    BlockingFn.counter().acquire(expectedNumberOfThreads);
 
    // Attempt to acquire an additional permit, if we were able to then that means
    // too many items were being processed concurrently.
-    if (BlockingFn.counter.tryAcquire(500, TimeUnit.MILLISECONDS)) {
+    if (BlockingFn.counter().tryAcquire(500, TimeUnit.MILLISECONDS)) {
      fail(
          "Expected number of threads "
              + expectedNumberOfThreads
              + " does not match actual "
              + "number of work items processed concurrently "
-              + BlockingFn.callCounter.get()
+              + BlockingFn.callCounter().get()
              + ".");
    }
 
-    BlockingFn.blocker.countDown();
+    BlockingFn.blocker().countDown();
+
+    // Wait for semaphore to be released by all harness threads
+    assertTrue(
+        BlockingFn.counter().tryAcquire(expectedNumberOfThreads, 500, TimeUnit.MILLISECONDS));
  }
 
  @Test
@@ -3499,7 +3498,7 @@ public void testActiveWorkFailure() throws Exception {
    server.sendFailedHeartbeats(Collections.singletonList(failedHeartbeat.build()));
 
    // Release the blocked calls.
-    BlockingFn.blocker.countDown();
+    BlockingFn.blocker().countDown();
 
    Map<Long, Windmill.WorkItemCommitRequest> commits =
        server.waitForAndGetCommitsWithTimeout(2, Duration.standardSeconds((5)));
    assertEquals(1, commits.size());
@@ -4073,27 +4072,43 @@ public void testDefaultNumCommitThreads() {
 
  static class BlockingFn extends DoFn implements TestRule {
 
-    public static CountDownLatch blocker = new CountDownLatch(1);
-    public static Semaphore counter = new Semaphore(0);
+    public static AtomicReference<CountDownLatch> blocker =
+        new AtomicReference<>(new CountDownLatch(1));
+    public static AtomicReference<Semaphore> counter = new AtomicReference<>(new Semaphore(0));
    public static AtomicInteger callCounter = new AtomicInteger(0);
 
    @ProcessElement
    public void processElement(ProcessContext c) throws InterruptedException {
      callCounter.incrementAndGet();
-      counter.release();
-      blocker.await();
+      counter().release();
+      blocker().await();
      c.output(c.element());
    }
 
+    public static CountDownLatch blocker() {
+      return blocker.get();
+    }
+
+    public static AtomicInteger callCounter() {
+      return callCounter;
+    }
+
+    public static Semaphore counter() {
+      return counter.get();
+    }
+
    @Override
    public Statement apply(final Statement base, final Description description) {
      return new Statement() {
        @Override
        public void evaluate() throws Throwable {
-          blocker = new CountDownLatch(1);
-          counter = new Semaphore(0);
-          callCounter = new AtomicInteger();
-          base.evaluate();
+          try {
+            base.evaluate();
+          } finally {
+            blocker.set(new CountDownLatch(1));
+            counter.set(new Semaphore(0));
+            callCounter.set(0);
+          }
        }
      };
    }
@@ -4267,6 +4282,7 @@ public void processElement(ProcessContext c) throws Exception {
  }
 
  static class FakeClock implements Supplier<Instant> {
+
    private final PriorityQueue<Job> jobs = new PriorityQueue<>();
    private Instant now = Instant.now();
@@ -4305,6 +4321,7 @@ private synchronized void schedule(Duration fromNow, Runnable work) {
  }
 
  private static class Job implements Comparable<Job> {
+
    final Instant when;
    final Runnable work;
@@ -4320,6 +4337,7 @@ public int compareTo(Job job) {
  }
 
  private class FakeScheduledExecutor implements ScheduledExecutorService {
+
    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
      return true;
    }
@@ -4427,6 +4445,7 @@ public void run() {
  }
 
  private static class FakeSlowDoFn extends DoFn {
+
    private static FakeClock clock; // A static variable keeps this DoFn serializable.
    private final Duration sleep;
@@ -4444,6 +4463,7 @@ public void processElement(ProcessContext c) throws Exception {
  }
 
  // Aggregates LatencyAttribution data from active work refresh requests.
  static class ActiveWorkRefreshSink {
+
    private final Function<GetDataRequest, GetDataResponse> responder;
    private final Map<Long, EnumMap<LatencyAttribution.State, Duration>> totalDurations =
        new HashMap<>();
@@ -4518,6 +4538,7 @@ GetDataResponse getData(GetDataRequest request) {
 
  // A DoFn that triggers a GetData request.
  static class ReadingDoFn extends DoFn {
+
    @StateId("int")
    private final StateSpec<ValueState<Integer>> counter = StateSpecs.value(VarIntCoder.of());
diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
index c6d985c07c53..c0cb8241d73e 100644
--- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
+++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java
@@ -25,11 +25,15 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 
+import java.util.Arrays;
 import java.util.Collections;
-import java.util.Deque;
 import java.util.HashMap;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Optional;
+import java.util.concurrent.ThreadLocalRandom;
 import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
 import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient;
@@ -37,6 +41,7 @@
 import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
 import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
 import org.joda.time.Instant;
 import org.junit.Before;
 import org.junit.Rule;
@@ -47,10 +52,11 @@
 
 @RunWith(JUnit4.class)
 public class ActiveWorkStateTest {
+
  private final WindmillStateCache.ForComputation computationStateCache =
      mock(WindmillStateCache.ForComputation.class);
  @Rule public transient Timeout globalTimeout = Timeout.seconds(600);
-  private Map<ShardedKey, Deque<ExecutableWork>> readOnlyActiveWork;
+  private Map<Long, LinkedHashMap<WorkId, ExecutableWork>> readOnlyActiveWork;
 
  private ActiveWorkState activeWorkState;
 
@@ -101,7 +107,7 @@ private static Windmill.WorkItem createWorkItem(
 
  @Before
  public void setup() {
-    Map<ShardedKey, Deque<ExecutableWork>> readWriteActiveWorkMap = new HashMap<>();
+    Map<Long, LinkedHashMap<WorkId, ExecutableWork>> readWriteActiveWorkMap = new HashMap<>();
    // Only use readOnlyActiveWork to verify internal behavior in reaction to exposed API calls.
    readOnlyActiveWork = Collections.unmodifiableMap(readWriteActiveWorkMap);
    activeWorkState = ActiveWorkState.forTesting(readWriteActiveWorkMap, computationStateCache);
@@ -131,7 +137,7 @@ public void testActivateWorkForKey_EXECUTE_emptyWorkQueueForKey() {
 
    assertEquals(ActivateWorkResult.EXECUTE, activateWorkResult);
    assertEquals(Optional.empty(), nextWorkForKey);
-    assertThat(readOnlyActiveWork).doesNotContainKey(shardedKey);
+    assertThat(readOnlyActiveWork).doesNotContainKey(shardedKey.shardingKey());
  }
 
  @Test
@@ -182,8 +188,8 @@ public void testCompleteWorkAndGetNextWorkForKey_noWorkQueueForKey() {
        activeWorkState.completeWorkAndGetNextWorkForKey(
            shardedKey, workId(otherWorkToken, cacheToken));
 
-    assertEquals(1, readOnlyActiveWork.get(shardedKey).size());
-    assertEquals(workInQueue, readOnlyActiveWork.get(shardedKey).peek());
+    assertEquals(1, readOnlyActiveWork.get(shardedKey.shardingKey()).size());
+    assertEquals(workInQueue, firstValue(readOnlyActiveWork.get(shardedKey.shardingKey())));
  }
 
  @Test
@@ -197,9 +203,9 @@ public void testCompleteWorkAndGetNextWorkForKey_removesWorkFromQueueWhenComplet
    activeWorkState.activateWorkForKey(nextWork);
    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, activeWork.id());
 
-    assertEquals(nextWork, readOnlyActiveWork.get(shardedKey).peek());
-    assertEquals(1, readOnlyActiveWork.get(shardedKey).size());
-    assertFalse(readOnlyActiveWork.get(shardedKey).contains(activeWork));
+    assertEquals(nextWork, firstValue(readOnlyActiveWork.get(shardedKey.shardingKey())));
+    assertEquals(1, readOnlyActiveWork.get(shardedKey.shardingKey()).size());
+    assertFalse(readOnlyActiveWork.get(shardedKey.shardingKey()).containsValue(activeWork));
  }
 
  @Test
@@ -211,7 +217,7 @@ public void testCompleteWorkAndGetNextWorkForKey_removesQueueIfNoWorkPresent() {
    activeWorkState.activateWorkForKey(workInQueue);
    activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, workInQueue.id());
 
-    assertFalse(readOnlyActiveWork.containsKey(shardedKey));
+    assertFalse(readOnlyActiveWork.containsKey(shardedKey.shardingKey()));
  }
 
  @Test
@@ -235,7 +241,7 @@ public void testCompleteWorkAndGetNextWorkForKey_returnsWorkIfPresent() {
        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, nextWork.id());
 
    assertFalse(endOfWorkQueue.isPresent());
-    assertFalse(readOnlyActiveWork.containsKey(shardedKey));
+    assertFalse(readOnlyActiveWork.containsKey(shardedKey.shardingKey()));
  }
 
  @Test
@@ -346,7 +352,7 @@ public void testInvalidateStuckCommits() {
    ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(secondWork);
 
    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
-    assertTrue(readOnlyActiveWork.get(shardedKey).contains(secondWork));
+    assertTrue(readOnlyActiveWork.get(shardedKey.shardingKey()).containsValue(secondWork));
 
    Optional<ExecutableWork> nextWork =
        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, differentWorkTokenWork.id());
@@ -373,8 +379,8 @@ public void testInvalidateStuckCommits() {
    ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(secondWork);
 
    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
-    assertEquals(firstWork, readOnlyActiveWork.get(shardedKey).peek());
-    assertTrue(readOnlyActiveWork.get(shardedKey).contains(secondWork));
+    assertEquals(firstWork, firstValue(readOnlyActiveWork.get(shardedKey.shardingKey())));
+    assertTrue(readOnlyActiveWork.get(shardedKey.shardingKey()).containsValue(secondWork));
 
    Optional<ExecutableWork> nextWork =
        activeWorkState.completeWorkAndGetNextWorkForKey(shardedKey, firstWork.id());
    assertTrue(nextWork.isPresent());
@@ -397,8 +403,8 @@ public void testInvalidateStuckCommits() {
 
    // newWork should be queued and queuedWork should not be removed since it is currently active.
    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
-    assertTrue(readOnlyActiveWork.get(shardedKey).contains(newWork));
-    assertEquals(queuedWork, readOnlyActiveWork.get(shardedKey).peek());
+    assertTrue(readOnlyActiveWork.get(shardedKey.shardingKey()).containsValue(newWork));
+    assertEquals(queuedWork, firstValue(readOnlyActiveWork.get(shardedKey.shardingKey())));
  }
 
  @Test
@@ -420,9 +426,10 @@ public void testInvalidateStuckCommits() {
    ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(newWork);
 
    assertEquals(ActivateWorkResult.QUEUED, activateWorkResult);
-    assertTrue(readOnlyActiveWork.get(shardedKey).contains(newWork));
-    assertFalse(readOnlyActiveWork.get(shardedKey).contains(queuedWork));
-    assertEquals(differentWorkTokenWork, readOnlyActiveWork.get(shardedKey).peek());
+    assertTrue(readOnlyActiveWork.get(shardedKey.shardingKey()).containsValue(newWork));
+    assertFalse(readOnlyActiveWork.get(shardedKey.shardingKey()).containsValue(queuedWork));
+    assertEquals(
+        differentWorkTokenWork, firstValue(readOnlyActiveWork.get(shardedKey.shardingKey())));
  }
 
  @Test
@@ -439,7 +446,128 @@ public void testActivateWorkForKey_matchingCacheTokens_newWorkTokenLesser_STALE(
    ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(newWork);
 
    assertEquals(ActivateWorkResult.STALE, activateWorkResult);
-    assertFalse(readOnlyActiveWork.get(shardedKey).contains(newWork));
-    assertEquals(queuedWork, readOnlyActiveWork.get(shardedKey).peek());
+    assertFalse(readOnlyActiveWork.get(shardedKey.shardingKey()).containsValue(newWork));
+    assertEquals(queuedWork, firstValue(readOnlyActiveWork.get(shardedKey.shardingKey())));
+  }
+
+  @Test
+  public void testFailWork() {
+    {
+      long workToken = 0L;
+      for (long shardingKey : Arrays.asList(1L, 2L)) {
+        for (String key : Arrays.asList("key1", "key2")) {
+          for (long cacheToken : Arrays.asList(5L, 6L)) {
+            ShardedKey shardedKey = shardedKey(key, shardingKey);
+            ExecutableWork work = createWork(createWorkItem(++workToken, cacheToken, shardedKey));
+            ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(work);
+            assertThat(activateWorkResult)
+                .isAnyOf(ActivateWorkResult.EXECUTE, ActivateWorkResult.QUEUED);
+          }
+        }
+      }
+    }
+    {
+      long workToken = 0L;
+      for (long shardingKey : Arrays.asList(1L, 2L)) {
+        for (String unusedKey : Arrays.asList("key1", "key2")) {
+          for (long cacheToken : Arrays.asList(5L, 6L)) {
+            WorkId workId =
+                WorkId.builder().setWorkToken(++workToken).setCacheToken(cacheToken).build();
+            boolean shouldFail = ThreadLocalRandom.current().nextBoolean();
+            if (shouldFail) {
+              activeWorkState.failWorkForKey(
+                  ImmutableList.of(WorkIdWithShardingKey.create(shardingKey, workId)));
+            }
+            LinkedHashMap<WorkId, ExecutableWork> workIdExecutableWorkLinkedHashMap =
+                readOnlyActiveWork.get(shardingKey);
+            assertEquals(
+                shouldFail, workIdExecutableWorkLinkedHashMap.get(workId).work().isFailed());
+          }
+        }
+      }
+    }
+  }
+
+  @Test
+  public void testFailWork_batchFail() {
+
+    ImmutableList.Builder<WorkIdWithShardingKey> toFailBuilder1 = ImmutableList.builder();
+    ImmutableList.Builder<WorkIdWithShardingKey> toFailBuilder2 = ImmutableList.builder();
+
+    {
+      long workToken = 0L;
+      for (long shardingKey : Arrays.asList(1L, 2L)) {
+        for (String key : Arrays.asList("key1", "key2")) {
+          for (long cacheToken : Arrays.asList(5L, 6L)) {
+            ++workToken;
+            ShardedKey shardedKey = shardedKey(key, shardingKey);
+            ExecutableWork work = createWork(createWorkItem(workToken, cacheToken, shardedKey));
+            ActivateWorkResult activateWorkResult = activeWorkState.activateWorkForKey(work);
+            assertThat(activateWorkResult)
+                .isAnyOf(ActivateWorkResult.EXECUTE, ActivateWorkResult.QUEUED);
+
+            WorkId workId =
+                WorkId.builder().setWorkToken(workToken).setCacheToken(cacheToken).build();
+            WorkIdWithShardingKey workIdWithShardingKey =
+                WorkIdWithShardingKey.create(shardingKey, workId);
+            if (ThreadLocalRandom.current().nextBoolean()) {
+              toFailBuilder1.add(workIdWithShardingKey);
+            } else {
+              toFailBuilder2.add(workIdWithShardingKey);
+            }
+          }
+        }
+      }
+    }
+    ImmutableList<WorkIdWithShardingKey> toFail1 = toFailBuilder1.build();
+    ImmutableList<WorkIdWithShardingKey> toFail2 = toFailBuilder2.build();
+
+    activeWorkState.failWorkForKey(toFail1);
+
+    for (WorkIdWithShardingKey workIdWithShardingKey : toFail1) {
+      assertTrue(
+          readOnlyActiveWork
+              .get(workIdWithShardingKey.shardingKey())
+              .get(workIdWithShardingKey.workId())
+              .work()
+              .isFailed());
+    }
+
+    for (WorkIdWithShardingKey workIdWithShardingKey : toFail2) {
+      assertFalse(
+          readOnlyActiveWork
+              .get(workIdWithShardingKey.shardingKey())
+              .get(workIdWithShardingKey.workId())
+              .work()
+              .isFailed());
+    }
+
+    activeWorkState.failWorkForKey(toFail2);
+
+    for (WorkIdWithShardingKey workIdWithShardingKey : toFail1) {
+      assertTrue(
+          readOnlyActiveWork
+              .get(workIdWithShardingKey.shardingKey())
+              .get(workIdWithShardingKey.workId())
+              .work()
+              .isFailed());
+    }
+
+    for (WorkIdWithShardingKey workIdWithShardingKey : toFail2) {
+      assertTrue(
+          readOnlyActiveWork
+              .get(workIdWithShardingKey.shardingKey())
+              .get(workIdWithShardingKey.workId())
+              .work()
+              .isFailed());
+    }
+  }
+
+  private static ExecutableWork firstValue(Map<WorkId, ExecutableWork> map) {
+    Iterator<Entry<WorkId, ExecutableWork>> iterator = map.entrySet().iterator();
+    if (iterator.hasNext()) {
+      return iterator.next().getValue();
+    }
+    throw new NullPointerException();
+  }
 }
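
---
Reviewer notes with illustrative sketches (simplified stand-ins, not the patch's actual classes).

The core change re-keys activeWork from ShardedKey -> Deque<ExecutableWork> to
shardingKey -> LinkedHashMap<WorkId, ExecutableWork>. A heartbeat failure carries only
(shardingKey, workToken, cacheToken) and no user key, so the old layout forced a scan of
every active key; the new one needs two hash lookups. A minimal self-contained sketch of
the scheme, with String standing in for the worker's Work type and WorkId:

    import java.util.HashMap;
    import java.util.LinkedHashMap;
    import java.util.Map;

    final class ActiveWorkIndexSketch {
      // shardingKey -> (workId -> work); insertion order makes the first entry the
      // actively-processing item, like Deque.peek() in the old layout.
      private final Map<Long, LinkedHashMap<String, String>> activeWork = new HashMap<>();

      void activate(long shardingKey, String workId, String work) {
        activeWork.computeIfAbsent(shardingKey, unused -> new LinkedHashMap<>()).put(workId, work);
      }

      // Mirrors the shape of the new failWorkForKey: O(1) per failed id instead of a
      // scan over all active keys; a missing queue or entry means the work already
      // completed before the heartbeat response arrived.
      boolean fail(long shardingKey, String workId) {
        LinkedHashMap<String, String> queue = activeWork.get(shardingKey);
        return queue != null && queue.containsKey(workId);
      }

      public static void main(String[] args) {
        ActiveWorkIndexSketch index = new ActiveWorkIndexSketch();
        index.activate(42L, "token-1", "work-1");
        System.out.println(index.fail(42L, "token-1")); // true: direct lookup
        System.out.println(index.fail(7L, "token-9"));  // false: nothing active
      }
    }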
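
Why LinkedHashMap can replace the Deque: its iteration order is insertion order, so the
first entry plays the role of Deque.peek()/remove() while get() and remove() by key stay
O(1). A runnable sketch of a generic firstValue helper in the spirit of the one the patch
adds (the patch's version is specialized to WorkId/ExecutableWork):

    import java.util.Iterator;
    import java.util.LinkedHashMap;
    import java.util.Map;

    public class FirstValueSketch {
      // Returns the oldest entry's value, or null when the map is empty.
      static <K, V> V firstValue(Map<K, V> map) {
        Iterator<Map.Entry<K, V>> it = map.entrySet().iterator();
        return it.hasNext() ? it.next().getValue() : null;
      }

      public static void main(String[] args) {
        LinkedHashMap<String, String> queue = new LinkedHashMap<>();
        queue.put("w1", "first");
        queue.put("w2", "second");
        System.out.println(firstValue(queue)); // "first" - the active item

        // Completing the head: remove via the entry iterator, as the patch does in
        // removeCompletedWorkFromQueue.
        Iterator<Map.Entry<String, String>> it = queue.entrySet().iterator();
        it.next();
        it.remove();
        System.out.println(firstValue(queue)); // "second" - next work to execute
      }
    }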
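
The WorkHeartbeatResponseProcessor change flattens failed heartbeats into one immutable
list of (shardingKey, WorkId) pairs instead of a Multimap<Long, WorkId>. A sketch of that
flow under simplified assumptions - plain records stand in for the Windmill proto and
AutoValue types, and List.copyOf stands in for ImmutableList.Builder:

    import java.util.ArrayList;
    import java.util.List;

    public class HeartbeatFailureSketch {
      // Hypothetical stand-ins for HeartbeatResponse, WorkId, WorkIdWithShardingKey.
      record HeartbeatResponse(long shardingKey, long workToken, long cacheToken, boolean failed) {}
      record WorkId(long workToken, long cacheToken) {}
      record WorkIdWithShardingKey(long shardingKey, WorkId workId) {}

      static List<WorkIdWithShardingKey> collectFailures(List<HeartbeatResponse> responses) {
        List<WorkIdWithShardingKey> failed = new ArrayList<>();
        for (HeartbeatResponse r : responses) {
          if (r.failed()) {
            failed.add(
                new WorkIdWithShardingKey(r.shardingKey(), new WorkId(r.workToken(), r.cacheToken())));
          }
        }
        return List.copyOf(failed); // immutable result, like ImmutableList.Builder.build()
      }

      public static void main(String[] args) {
        List<HeartbeatResponse> responses =
            List.of(
                new HeartbeatResponse(1L, 10L, 100L, true),
                new HeartbeatResponse(2L, 20L, 200L, false));
        System.out.println(collectFailures(responses)); // only the failed (1, 10, 100) pair
      }
    }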
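
The BlockingFn test refactor is also worth noting: the shared static latch/semaphore now
live in AtomicReferences that are swapped for fresh instances in a finally block, so a
test that never releases the latch cannot poison the next test. A minimal sketch of that
pattern against the JUnit 4 TestRule API (names are illustrative):

    import java.util.concurrent.CountDownLatch;
    import java.util.concurrent.atomic.AtomicReference;
    import org.junit.rules.TestRule;
    import org.junit.runner.Description;
    import org.junit.runners.model.Statement;

    class ResettableLatchRule implements TestRule {
      static final AtomicReference<CountDownLatch> blocker =
          new AtomicReference<>(new CountDownLatch(1));

      // Accessor mirrors the patch's BlockingFn.blocker() indirection.
      static CountDownLatch blocker() {
        return blocker.get();
      }

      @Override
      public Statement apply(Statement base, Description description) {
        return new Statement() {
          @Override
          public void evaluate() throws Throwable {
            try {
              base.evaluate();
            } finally {
              // Reset in finally, as the patch does, so failing tests still reset state.
              blocker.set(new CountDownLatch(1));
            }
          }
        };
      }
    }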