Commit 7bec450

[Dataflow Streaming] Optimize failed key processing by indexing work items by sharding key (#33755)
arunpandianp authored Feb 4, 2025
1 parent 9064743 commit 7bec450
Showing 6 changed files with 319 additions and 165 deletions.
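The core change: activeWork in ActiveWorkState was a Map<ShardedKey, Deque<ExecutableWork>>, so failing work reported by a heartbeat (which carries only a sharding key plus work/cache tokens, never the user key) required scanning every key's queue. It is now a Map<Long, LinkedHashMap<WorkId, ExecutableWork>> keyed by sharding key, so a failed item is found with two hash lookups. A minimal sketch of the lookup this enables (simplified signatures, not the exact Beam code):

    // Sketch: mark one reported failure, assuming the two-level index below.
    Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork = new HashMap<>();

    void markFailed(long shardingKey, WorkId workId) {
      LinkedHashMap<WorkId, ExecutableWork> queue = activeWork.get(shardingKey); // O(1)
      if (queue == null) {
        return; // work finished before the heartbeat response arrived
      }
      ExecutableWork work = queue.get(workId); // O(1), replaces a linear queue scan
      if (work != null) {
        work.work().setFailed();
      }
    }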
ActiveWorkState.java
@@ -18,14 +18,11 @@
 package org.apache.beam.runners.dataflow.worker.streaming;

 import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList.toImmutableList;
-import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap;

 import java.io.PrintWriter;
-import java.util.ArrayDeque;
-import java.util.Collection;
-import java.util.Deque;
 import java.util.HashMap;
 import java.util.Iterator;
+import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
@@ -36,14 +33,13 @@
 import javax.annotation.concurrent.ThreadSafe;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
 import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
+import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache.ForComputation;
 import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
 import org.apache.beam.sdk.annotations.Internal;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
 import org.joda.time.Duration;
 import org.joda.time.Instant;
 import org.slf4j.Logger;
@@ -63,11 +59,11 @@ public final class ActiveWorkState {
   private static final int MAX_PRINTABLE_COMMIT_PENDING_KEYS = 50;

   /**
-   * Map from {@link ShardedKey} to {@link Work} for the key. The first item in the {@link
-   * Queue<Work>} is actively processing.
+   * Map from shardingKey to {@link Work} for the key. The first item in the {@link LinkedHashMap}
+   * is actively processing.
    */
   @GuardedBy("this")
-  private final Map<ShardedKey, Deque<ExecutableWork>> activeWork;
+  private final Map<Long /*shardingKey*/, LinkedHashMap<WorkId, ExecutableWork>> activeWork;

   @GuardedBy("this")
   private final WindmillStateCache.ForComputation computationStateCache;
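A LinkedHashMap iterates in insertion order, so it can stand in for the old Deque while also supporting O(1) lookup and removal by WorkId. A small standalone illustration (plain JDK types, not Beam code):

    import java.util.LinkedHashMap;

    class LinkedHashMapAsQueueDemo {
      public static void main(String[] args) {
        LinkedHashMap<String, String> queue = new LinkedHashMap<>();
        queue.put("w1", "first"); // insertion order is preserved
        queue.put("w2", "second");
        // The first entry plays the role of Deque.peek(): the actively-processing item.
        System.out.println(queue.values().iterator().next()); // prints "first"
        // Unlike a Deque, membership tests and removal by key are O(1).
        System.out.println(queue.containsKey("w2")); // true
        queue.remove("w1"); // complete the active item; "w2" becomes the head
      }
    }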
@@ -81,8 +77,8 @@ public final class ActiveWorkState {
   private GetWorkBudget activeGetWorkBudget;

   private ActiveWorkState(
-      Map<ShardedKey, Deque<ExecutableWork>> activeWork,
-      WindmillStateCache.ForComputation computationStateCache) {
+      Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
+      ForComputation computationStateCache) {
     this.activeWork = activeWork;
     this.computationStateCache = computationStateCache;
     this.activeGetWorkBudget = GetWorkBudget.noBudget();
@@ -94,7 +90,7 @@ static ActiveWorkState create(WindmillStateCache.ForComputation computationState

   @VisibleForTesting
   static ActiveWorkState forTesting(
-      Map<ShardedKey, Deque<ExecutableWork>> activeWork,
+      Map<Long, LinkedHashMap<WorkId, ExecutableWork>> activeWork,
       WindmillStateCache.ForComputation computationStateCache) {
     return new ActiveWorkState(activeWork, computationStateCache);
   }
@@ -124,28 +120,30 @@ private static String elapsedString(Instant start, Instant end) {
    */
   synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork) {
     ShardedKey shardedKey = executableWork.work().getShardedKey();
-    Deque<ExecutableWork> workQueue = activeWork.getOrDefault(shardedKey, new ArrayDeque<>());
+    long shardingKey = shardedKey.shardingKey();
+    LinkedHashMap<WorkId, ExecutableWork> workQueue =
+        activeWork.computeIfAbsent(shardingKey, (unused) -> new LinkedHashMap<>());
     // This key does not have any work queued up on it. Create one, insert Work, and mark the work
     // to be executed.
-    if (!activeWork.containsKey(shardedKey) || workQueue.isEmpty()) {
-      workQueue.addLast(executableWork);
-      activeWork.put(shardedKey, workQueue);
+    if (workQueue.isEmpty()) {
+      workQueue.put(executableWork.id(), executableWork);
       incrementActiveWorkBudget(executableWork.work());
       return ActivateWorkResult.EXECUTE;
     }

     // Check to see if we have this work token queued.
-    Iterator<ExecutableWork> workIterator = workQueue.iterator();
+    Iterator<Entry<WorkId, ExecutableWork>> workIterator = workQueue.entrySet().iterator();
     while (workIterator.hasNext()) {
-      ExecutableWork queuedWork = workIterator.next();
+      ExecutableWork queuedWork = workIterator.next().getValue();
       if (queuedWork.id().equals(executableWork.id())) {
         return ActivateWorkResult.DUPLICATE;
       }
-      if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()) {
+      if (queuedWork.id().cacheToken() == executableWork.id().cacheToken()
+          && queuedWork.work().getShardedKey().equals(executableWork.work().getShardedKey())) {
         if (executableWork.id().workToken() > queuedWork.id().workToken()) {
           // Check to see if the queuedWork is active. We only want to remove it if it is NOT
           // currently active.
-          if (!queuedWork.equals(workQueue.peek())) {
+          if (!queuedWork.equals(Preconditions.checkNotNull(firstValue(workQueue)))) {
             workIterator.remove();
             decrementActiveWorkBudget(queuedWork.work());
           }
@@ -157,7 +155,7 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
     }

     // Queue the work for later processing.
-    workQueue.addLast(executableWork);
+    workQueue.put(executableWork.id(), executableWork);
     incrementActiveWorkBudget(executableWork.work());
     return ActivateWorkResult.QUEUED;
   }
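Because queues are now keyed by the bare sharding key, two different user keys that hash to the same sharding key can share one queue; the added getShardedKey() equality check keeps the newer-work-token eviction from crossing user keys. The method's three outcomes, roughly (an illustrative trace, not taken from the source):

    // activate(key=K1, workToken=1, cacheToken=9) -> EXECUTE   (queue was empty; runs immediately)
    // activate(key=K1, workToken=1, cacheToken=9) -> DUPLICATE (identical WorkId already queued)
    // activate(key=K1, workToken=2, cacheToken=9) -> QUEUED    (newer retry for the same cache
    //                                               token; a queued, non-active older item
    //                                               for K1 is evicted first)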
@@ -167,54 +165,29 @@ synchronized ActivateWorkResult activateWorkForKey(ExecutableWork executableWork
    *
    * @param failedWork a map from sharding_key to tokens for the corresponding work.
    */
-  synchronized void failWorkForKey(Multimap<Long, WorkId> failedWork) {
-    // Note we can't construct a ShardedKey and look it up in activeWork directly since
-    // HeartbeatResponse doesn't include the user key.
-    for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
-      Collection<WorkId> failedWorkIds = failedWork.get(entry.getKey().shardingKey());
-      for (WorkId failedWorkId : failedWorkIds) {
-        for (ExecutableWork queuedWork : entry.getValue()) {
-          WorkItem workItem = queuedWork.work().getWorkItem();
-          if (workItem.getWorkToken() == failedWorkId.workToken()
-              && workItem.getCacheToken() == failedWorkId.cacheToken()) {
-            LOG.debug(
-                "Failing work "
-                    + computationStateCache.getComputation()
-                    + " "
-                    + entry.getKey().shardingKey()
-                    + " "
-                    + failedWorkId.workToken()
-                    + " "
-                    + failedWorkId.cacheToken()
-                    + ". The work will be retried and is not lost.");
-            queuedWork.work().setFailed();
-            break;
-          }
-        }
-      }
-    }
-  }
+  synchronized void failWorkForKey(ImmutableList<WorkIdWithShardingKey> failedWork) {
+    for (WorkIdWithShardingKey failedId : failedWork) {
+      @Nullable
+      LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(failedId.shardingKey());
+      if (workQueue == null) {
+        // Work could complete/fail before heartbeat response arrives
+        continue;
+      }
+      @Nullable ExecutableWork executableWork = workQueue.get(failedId.workId());
+      if (executableWork == null) {
+        continue;
+      }
+      executableWork.work().setFailed();
+      LOG.debug(
+          "Failing work {} {}. The work will be retried and is not lost.",
+          computationStateCache.getComputation(),
+          failedId);
+    }
+  }

-  /**
-   * Returns a read only view of current active work.
-   *
-   * @implNote Do not return a reference to the underlying workQueue as iterations over it will
-   *     cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data
-   *     structure.
-   */
-  synchronized ImmutableListMultimap<ShardedKey, RefreshableWork> getReadOnlyActiveWork() {
-    return activeWork.entrySet().stream()
-        .collect(
-            flatteningToImmutableListMultimap(
-                Entry::getKey,
-                e ->
-                    e.getValue().stream()
-                        .map(executableWork -> (RefreshableWork) executableWork.work())));
-  }
-
   synchronized ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
     return activeWork.values().stream()
-        .flatMap(Deque::stream)
+        .flatMap(workMap -> workMap.values().stream())
         .map(ExecutableWork::work)
         .filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline))
         .collect(toImmutableList());
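The deleted comment above recorded the original constraint: a HeartbeatResponse identifies work only by sharding key and tokens, so a ShardedKey could never be reconstructed for a direct activeWork lookup, and every entry had to be scanned. Indexing by the bare sharding key removes that constraint. As a rough cost model (illustrative only, with K active keys, F failed ids per response, and Q items per queue):

    // Before: for each of K keys, for each of F failed ids, scan up to Q queued items
    //         => O(K * F * Q) comparisons per heartbeat response.
    // After:  for each of F failed ids, two expected-O(1) hash lookups
    //         => O(F), independent of K and Q.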
@@ -236,7 +209,8 @@ private synchronized void decrementActiveWorkBudget(Work work) {
    */
   synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
       ShardedKey shardedKey, WorkId workId) {
-    @Nullable Queue<ExecutableWork> workQueue = activeWork.get(shardedKey);
+    @Nullable
+    LinkedHashMap<WorkId, ExecutableWork> workQueue = activeWork.get(shardedKey.shardingKey());
     if (workQueue == null) {
       // Work may have been completed due to clearing of stuck commits.
       LOG.warn(
@@ -251,14 +225,15 @@ synchronized Optional<ExecutableWork> completeWorkAndGetNextWorkForKey(
   }

   private synchronized void removeCompletedWorkFromQueue(
-      Queue<ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
-    @Nullable ExecutableWork completedWork = workQueue.peek();
-    if (completedWork == null) {
+      LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey, WorkId workId) {
+    Iterator<Entry<WorkId, ExecutableWork>> completedWorkIterator = workQueue.entrySet().iterator();
+    if (!completedWorkIterator.hasNext()) {
       // Work may have been completed due to clearing of stuck commits.
       LOG.warn("Active key {} without work, expected token {}", shardedKey, workId);
       return;
     }

+    ExecutableWork completedWork = completedWorkIterator.next().getValue();
     if (!completedWork.id().equals(workId)) {
       // Work may have been completed due to clearing of stuck commits.
       LOG.warn(
@@ -271,19 +246,18 @@ private synchronized void removeCompletedWorkFromQueue(
           completedWork.id());
       return;
     }
-
     // We consumed the matching work item.
-    workQueue.remove();
+    completedWorkIterator.remove();
     decrementActiveWorkBudget(completedWork.work());
   }

   @SuppressWarnings("ReferenceEquality")
   private synchronized Optional<ExecutableWork> getNextWork(
-      Queue<ExecutableWork> workQueue, ShardedKey shardedKey) {
-    Optional<ExecutableWork> nextWork = Optional.ofNullable(workQueue.peek());
+      LinkedHashMap<WorkId, ExecutableWork> workQueue, ShardedKey shardedKey) {
+    Optional<ExecutableWork> nextWork = Optional.ofNullable(firstValue(workQueue));
     if (!nextWork.isPresent()) {
-      Preconditions.checkState(workQueue == activeWork.remove(shardedKey));
+      Preconditions.checkState(workQueue == activeWork.remove(shardedKey.shardingKey()));
     }

     return nextWork;
   }
@@ -302,22 +276,26 @@ synchronized void invalidateStuckCommits(
     }
   }

+  private static @Nullable ExecutableWork firstValue(Map<WorkId, ExecutableWork> map) {
+    Iterator<Entry<WorkId, ExecutableWork>> iterator = map.entrySet().iterator();
+    return iterator.hasNext() ? iterator.next().getValue() : null;
+  }
+
   private synchronized ImmutableMap<ShardedKey, WorkId> getStuckCommitsAt(
       Instant stuckCommitDeadline) {
     // Determine the stuck commit keys but complete them outside the loop iterating over
     // activeWork as completeWork may delete the entry from activeWork.
     ImmutableMap.Builder<ShardedKey, WorkId> stuckCommits = ImmutableMap.builder();
-    for (Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
-      ShardedKey shardedKey = entry.getKey();
-      @Nullable ExecutableWork executableWork = entry.getValue().peek();
+    for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
+      @Nullable ExecutableWork executableWork = firstValue(entry.getValue());
       if (executableWork != null) {
         Work work = executableWork.work();
         if (work.isStuckCommittingAt(stuckCommitDeadline)) {
           LOG.error(
               "Detected key {} stuck in COMMITTING state since {}, completing it with error.",
-              shardedKey,
+              work.getShardedKey(),
               work.getStateStartTime());
-          stuckCommits.put(shardedKey, work.id());
+          stuckCommits.put(work.getShardedKey(), work.id());
         }
       }
     }
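The new firstValue helper gives the head of an insertion-ordered map, taking over the role of Deque.peek() in the old code. A quick plain-JDK illustration (not Beam code):

    LinkedHashMap<Integer, String> m = new LinkedHashMap<>();
    // firstValue(m) would be null here, hence the @Nullable return type.
    m.put(2, "b");
    m.put(1, "a");
    // firstValue(m) returns "b": insertion order, not key order, decides the head.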
@@ -353,9 +331,10 @@ synchronized void printActiveWork(PrintWriter writer, Instant now) {
     // Use StringBuilder because we are appending in loop.
     StringBuilder activeWorkStatus = new StringBuilder();
     int commitsPendingCount = 0;
-    for (Map.Entry<ShardedKey, Deque<ExecutableWork>> entry : activeWork.entrySet()) {
-      Queue<ExecutableWork> workQueue = Preconditions.checkNotNull(entry.getValue());
-      Work activeWork = Preconditions.checkNotNull(workQueue.peek()).work();
+    for (Entry<Long, LinkedHashMap<WorkId, ExecutableWork>> entry : activeWork.entrySet()) {
+      LinkedHashMap<WorkId, ExecutableWork> workQueue =
+          Preconditions.checkNotNull(entry.getValue());
+      Work activeWork = Preconditions.checkNotNull(firstValue(workQueue)).work();
       WorkItem workItem = activeWork.getWorkItem();
       if (activeWork.isCommitPending()) {
         if (++commitsPendingCount >= MAX_PRINTABLE_COMMIT_PENDING_KEYS) {
ComputationState.java
@@ -29,9 +29,7 @@
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
 import org.joda.time.Instant;

 /**
@@ -120,7 +118,7 @@ public boolean activateWork(ExecutableWork executableWork) {
     }
   }

-  public void failWork(Multimap<Long, WorkId> failedWork) {
+  public void failWork(ImmutableList<WorkIdWithShardingKey> failedWork) {
     activeWorkState.failWorkForKey(failedWork);
   }
@@ -146,10 +144,6 @@ private void forceExecute(ExecutableWork executableWork) {
     executor.forceExecute(executableWork, executableWork.work().getSerializedWorkItemSize());
   }

-  public ImmutableListMultimap<ShardedKey, RefreshableWork> currentActiveWorkReadOnly() {
-    return activeWorkState.getReadOnlyActiveWork();
-  }
-
   public ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
     return activeWorkState.getRefreshableWork(refreshDeadline);
   }
WorkHeartbeatResponseProcessor.java
@@ -24,8 +24,7 @@
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationHeartbeatResponse;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatResponse;
 import org.apache.beam.sdk.annotations.Internal;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ArrayListMultimap;
-import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Multimap;
+import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;

 /**
  * Processes {@link ComputationHeartbeatResponse}(s). Marks {@link Work} that is invalid from
@@ -34,6 +33,7 @@
 @Internal
 public final class WorkHeartbeatResponseProcessor
     implements Consumer<List<ComputationHeartbeatResponse>> {
+
   /** Fetches a {@link ComputationState} for a computationId. */
   private final Function<String, Optional<ComputationState>> computationStateFetcher;
@@ -46,23 +46,23 @@ public WorkHeartbeatResponseProcessor(
   @Override
   public void accept(List<ComputationHeartbeatResponse> responses) {
     for (ComputationHeartbeatResponse computationHeartbeatResponse : responses) {
-      // Maps sharding key to (work token, cache token) for work that should be marked failed.
-      Multimap<Long, WorkId> failedWork = ArrayListMultimap.create();
+      ImmutableList.Builder<WorkIdWithShardingKey> failedWorkBuilder = ImmutableList.builder();
       for (HeartbeatResponse heartbeatResponse :
           computationHeartbeatResponse.getHeartbeatResponsesList()) {
         if (heartbeatResponse.getFailed()) {
-          failedWork.put(
-              heartbeatResponse.getShardingKey(),
+          WorkId workId =
               WorkId.builder()
                   .setWorkToken(heartbeatResponse.getWorkToken())
                   .setCacheToken(heartbeatResponse.getCacheToken())
-                  .build());
+                  .build();
+          failedWorkBuilder.add(
+              WorkIdWithShardingKey.create(heartbeatResponse.getShardingKey(), workId));
         }
       }

       computationStateFetcher
           .apply(computationHeartbeatResponse.getComputationId())
-          .ifPresent(state -> state.failWork(failedWork));
+          .ifPresent(state -> state.failWork(failedWorkBuilder.build()));
     }
   }
 }
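End to end, a failed heartbeat flows from the proto response into ComputationState.failWork. A hedged sketch of exercising the processor (assumes the Windmill protos follow standard protobuf builder conventions for the getters used above; the ids are made up):

    HeartbeatResponse failed =
        HeartbeatResponse.newBuilder()
            .setShardingKey(5L)
            .setWorkToken(100L)
            .setCacheToken(7L)
            .setFailed(true)
            .build();
    ComputationHeartbeatResponse response =
        ComputationHeartbeatResponse.newBuilder()
            .setComputationId("computation-1") // hypothetical id
            .addHeartbeatResponses(failed)
            .build();
    // processor is a WorkHeartbeatResponseProcessor wired with a computationStateFetcher.
    // This marks work (shardingKey=5, workToken=100, cacheToken=7) failed, if still active.
    processor.accept(Collections.singletonList(response));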
WorkIdWithShardingKey.java (new file)
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.dataflow.worker.streaming;

import com.google.auto.value.AutoValue;

@AutoValue
abstract class WorkIdWithShardingKey {

public static WorkIdWithShardingKey create(long shardingKey, WorkId workId) {
return new AutoValue_WorkIdWithShardingKey(shardingKey, workId);
}

public abstract long shardingKey();

public abstract WorkId workId();
}
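@AutoValue generates the concrete AutoValue_WorkIdWithShardingKey subclass at compile time. Conceptually the generated class is equivalent to the following (a sketch of typical AutoValue output, not the literal generated source):

    final class AutoValue_WorkIdWithShardingKey extends WorkIdWithShardingKey {
      private final long shardingKey;
      private final WorkId workId;

      AutoValue_WorkIdWithShardingKey(long shardingKey, WorkId workId) {
        if (workId == null) {
          throw new NullPointerException("Null workId");
        }
        this.shardingKey = shardingKey;
        this.workId = workId;
      }

      @Override
      public long shardingKey() {
        return shardingKey;
      }

      @Override
      public WorkId workId() {
        return workId;
      }

      // AutoValue also generates equals(), hashCode(), and toString() over both
      // fields, giving the holder the value semantics the failed-work list needs.
    }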
(Diffs for the remaining 2 changed files are not shown.)