Skip to content

Commit

Permalink
Support sender tracking for reducing computations
Browse files Browse the repository at this point in the history
  • Loading branch information
s1ck committed Oct 16, 2024
1 parent d14088a commit f6f10e0
Show file tree
Hide file tree
Showing 9 changed files with 381 additions and 28 deletions.
30 changes: 28 additions & 2 deletions pregel/src/main/java/org/neo4j/gds/beta/pregel/Messages.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.jetbrains.annotations.NotNull;

import java.util.Iterator;
import java.util.OptionalLong;
import java.util.PrimitiveIterator;

public final class Messages implements Iterable<Double> {
Expand All @@ -33,7 +34,12 @@ public Iterator<Double> iterator() {
}

/**
 * Primitive double iterator over the messages delivered to a single node.
 */
public interface MessageIterator extends PrimitiveIterator.OfDouble {

    /**
     * Returns the sender associated with the current message.
     * <p>
     * This default implementation does not know any sender and therefore
     * always returns an empty optional; implementations that track senders
     * override it.
     *
     * @return an empty optional unless overridden
     */
    default OptionalLong sender() {
        return OptionalLong.empty();
    }

    /**
     * @return {@code true} if the implementation considers itself to hold no messages
     */
    boolean isEmpty();
}

private final MessageIterator iterator;
Expand All @@ -42,12 +48,32 @@ public interface MessageIterator extends PrimitiveIterator.OfDouble {
this.iterator = iterator;
}

/**
 * Returns an iterator that can be used to iterate over the messages.
 *
 * @return the message iterator backing this instance, exposed as a primitive double iterator
 */
@NotNull
public PrimitiveIterator.OfDouble doubleIterator() {
    return this.iterator;
}

/**
 * Indicates whether there are no messages present.
 *
 * @return {@code true} if the backing message iterator reports itself as empty
 */
public boolean isEmpty() {
    return this.iterator.isEmpty();
}

/**
* If the computation defined a {@link org.neo4j.gds.beta.pregel.Reducer}, this method will
* return the sender of the aggregated message. Depending on the reducer implementation, the
* sender is deterministically defined by the reducer, e.g., for Max or Min. In any other case,
* the sender will be one of the node ids that sent messages to that node.
* <p>
* Note that {@link PregelConfig#trackSender()} must return true to enable sender tracking.
* <p>
* The call is delegated to the backing message iterator, whose default implementation
* reports no sender.
*
* @return the sender of an aggregated message, or an empty optional if the backing iterator
* does not provide one (e.g. no reducer is defined or sender tracking is disabled)
*/
public OptionalLong sender() {
return this.iterator.sender();
}
}
6 changes: 6 additions & 0 deletions pregel/src/main/java/org/neo4j/gds/beta/pregel/Messenger.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
*/
package org.neo4j.gds.beta.pregel;

import java.util.OptionalLong;

public interface Messenger<ITERATOR extends Messages.MessageIterator> {

void initIteration(int iteration);
Expand All @@ -29,5 +31,9 @@ public interface Messenger<ITERATOR extends Messages.MessageIterator> {

void initMessageIterator(ITERATOR messageIterator, long nodeId, boolean isFirstIteration);

/**
* Returns the sender recorded for the given node.
* <p>
* This default implementation records nothing and always returns an empty optional;
* messengers that track senders override it.
*
* @param nodeId the node whose recorded sender is requested
* @return an empty optional unless overridden
*/
default OptionalLong sender(long nodeId) {
return OptionalLong.empty();
}

void release();
}
13 changes: 11 additions & 2 deletions pregel/src/main/java/org/neo4j/gds/beta/pregel/Pregel.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ public static MemoryEstimation memoryEstimation(
Map<String, ValueType> propertiesMap,
boolean isQueueBased,
boolean isAsync
) {
return memoryEstimation(propertiesMap, isQueueBased, isAsync, false);
}

public static MemoryEstimation memoryEstimation(
Map<String, ValueType> propertiesMap,
boolean isQueueBased,
boolean isAsync,
boolean isTrackingSender
) {
var estimationBuilder = MemoryEstimations.builder(Pregel.class)
.perNode("vote bits", HugeAtomicBitSet::memoryEstimation)
Expand All @@ -123,7 +132,7 @@ public static MemoryEstimation memoryEstimation(
estimationBuilder.add("message queues", SyncQueueMessenger.memoryEstimation());
}
} else {
estimationBuilder.add("message arrays", ReducingMessenger.memoryEstimation());
estimationBuilder.add("message arrays", ReducingMessenger.memoryEstimation(isTrackingSender));
}

return estimationBuilder.build();
Expand Down Expand Up @@ -169,7 +178,7 @@ private Pregel(
var reducer = computation.reducer();

this.messenger = reducer.isPresent()
? new ReducingMessenger(graph, config, reducer.get())
? ReducingMessenger.create(graph, config, reducer.get())
: config.isAsynchronous()
? new AsyncQueueMessenger(graph.nodeCount())
: new SyncQueueMessenger(graph.nodeCount());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,9 @@ default Partitioning partitioning() {
/**
* @return {@code true} if and only if {@link #partitioning()} is {@code Partitioning.AUTO}
*/
default boolean useForkJoin() {
return partitioning() == Partitioning.AUTO;
}

/**
* Indicates whether the sender of a reduced message should be tracked.
* <p>
* Disabled by default; a configuration overrides this method to return {@code true}
* in order to have the reducing messenger created with sender tracking.
*
* @return {@code false} unless overridden
*/
@Configuration.Ignore
default boolean trackSender() {
return false;
}
}
117 changes: 101 additions & 16 deletions pregel/src/main/java/org/neo4j/gds/beta/pregel/ReducingMessenger.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,49 +20,70 @@
package org.neo4j.gds.beta.pregel;

import org.neo4j.gds.api.Graph;
import org.neo4j.gds.collections.ha.HugeLongArray;
import org.neo4j.gds.collections.haa.HugeAtomicDoubleArray;
import org.neo4j.gds.core.concurrency.ParallelUtil;
import org.neo4j.gds.termination.TerminationFlag;
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
import org.neo4j.gds.mem.MemoryEstimation;
import org.neo4j.gds.mem.MemoryEstimations;
import org.neo4j.gds.core.utils.paged.ParallelDoublePageCreator;
import org.neo4j.gds.termination.TerminationFlag;

import java.util.OptionalLong;

/**
* A messenger implementation that is backed by two double arrays used
* to send and receive messages. The messenger can only be applied in
* combination with a {@link Reducer}
* which atomically reduces all incoming messages into a single one.
*/
public class ReducingMessenger implements Messenger<ReducingMessenger.SingleMessageIterator> {
class ReducingMessenger implements Messenger<ReducingMessenger.SingleMessageIterator> {

private final Graph graph;
private final PregelConfig config;
private final Reducer reducer;
final Reducer reducer;

private HugeAtomicDoubleArray sendArray;
private HugeAtomicDoubleArray receiveArray;
HugeAtomicDoubleArray sendArray;
HugeAtomicDoubleArray receiveArray;

/**
 * Creates the messenger variant matching the given configuration.
 *
 * @param graph   the graph whose node count sizes the message arrays
 * @param config  the Pregel configuration; {@code trackSender()} selects the variant
 * @param reducer the reducer used to combine messages sent to the same node
 * @return a {@link WithSender} instance if sender tracking is enabled, a plain
 *         {@link ReducingMessenger} otherwise
 */
static ReducingMessenger create(Graph graph, PregelConfig config, Reducer reducer) {
    return config.trackSender()
        ? new WithSender(graph, config, reducer)
        : new ReducingMessenger(graph, config, reducer);
}

/**
 * Allocates the send and receive arrays, one slot per node.
 * Instances are obtained via {@link #create(Graph, PregelConfig, Reducer)}.
 */
private ReducingMessenger(Graph graph, PregelConfig config, Reducer reducer) {
    // Message presence is detected by comparing against the identity with `!=`;
    // a NaN identity would never compare equal to itself.
    assert !Double.isNaN(reducer.identity()) : "identity element must not be NaN";

    this.graph = graph;
    this.config = config;
    this.reducer = reducer;

    this.receiveArray = HugeAtomicDoubleArray.of(
        graph.nodeCount(),
        ParallelDoublePageCreator.passThrough(config.concurrency())
    );
    this.sendArray = HugeAtomicDoubleArray.of(
        graph.nodeCount(),
        ParallelDoublePageCreator.passThrough(config.concurrency())
    );
}

static MemoryEstimation memoryEstimation() {
return MemoryEstimations.builder(ReducingMessenger.class)
static MemoryEstimation memoryEstimation(boolean withSender) {
var builder = MemoryEstimations.builder(ReducingMessenger.class)
.perNode("send array", HugeAtomicDoubleArray::memoryEstimation)
.perNode("receive array", HugeAtomicDoubleArray::memoryEstimation)
.perNode("receive array", HugeAtomicDoubleArray::memoryEstimation);

if (withSender) {
builder
.perNode("send sender array", HugeLongArray::memoryEstimation)
.perNode("receive sender array", HugeLongArray::memoryEstimation);
}
return builder
.build();
}

@Override
public void initIteration(int iteration) {
// Swap arrays
var tmp = receiveArray;
this.receiveArray = sendArray;
this.sendArray = tmp;
Expand Down Expand Up @@ -96,7 +117,7 @@ public void initMessageIterator(
boolean isInitialIteration
) {
var message = receiveArray.getAndReplace(nodeId, reducer.identity());
messageIterator.init(message, message != reducer.identity());
messageIterator.init(message, message != reducer.identity(), OptionalLong.empty());
}

@Override
Expand All @@ -105,14 +126,73 @@ public void release() {
receiveArray.release();
}

/**
 * Reducing messenger that additionally records, per target node, the id of a
 * source node whose message changed the reduced value.
 * <p>
 * Two sender arrays mirror the send and receive message arrays and are swapped
 * together with them at the start of each iteration.
 */
static class WithSender extends ReducingMessenger {
    // Not final: the two arrays trade places in initIteration.
    private HugeLongArray sendSenderArray;
    private HugeLongArray receiveSenderArray;

    WithSender(Graph graph, PregelConfig config, Reducer reducer) {
        super(graph, config, reducer);
        this.sendSenderArray = HugeLongArray.newArray(graph.nodeCount());
        this.receiveSenderArray = HugeLongArray.newArray(graph.nodeCount());
    }

    @Override
    public void initIteration(int iteration) {
        super.initIteration(iteration);
        // Swap sender arrays so they stay aligned with the message arrays
        var tmp = receiveSenderArray;
        this.receiveSenderArray = sendSenderArray;
        this.sendSenderArray = tmp;
    }

    /**
     * Initializes the iterator with the reduced message for {@code nodeId} and,
     * only if a message is present, the sender recorded for it.
     */
    @Override
    public void initMessageIterator(
        ReducingMessenger.SingleMessageIterator messageIterator,
        long nodeId,
        boolean isInitialIteration
    ) {
        var message = receiveArray.getAndReplace(nodeId, reducer.identity());
        var hasMessage = message != reducer.identity();
        // Sender slots are never reset, so without a message the slot holds either
        // the array's initial value or an id from an earlier iteration. Report
        // "no sender" instead of exposing that value as if it were valid.
        var sender = hasMessage
            ? OptionalLong.of(receiveSenderArray.get(nodeId))
            : OptionalLong.empty();
        messageIterator.init(message, hasMessage, sender);
    }

    /**
     * Reduces {@code message} into the slot of {@code targetNodeId} and records
     * {@code sourceNodeId} as sender whenever the reduced value differs from the
     * value previously stored in that slot.
     */
    @Override
    public void sendTo(long sourceNodeId, long targetNodeId, double message) {
        sendArray.update(
            targetNodeId,
            currentMessage -> {
                var reducedMessage = reducer.reduce(currentMessage, message);
                // NOTE(review): this write happens inside the update function; if
                // HugeAtomicDoubleArray.update re-invokes the function under
                // contention, the recorded sender may stem from an attempt whose
                // value was not the one finally stored - TODO confirm.
                if (Double.compare(reducedMessage, currentMessage) != 0) {
                    sendSenderArray.set(targetNodeId, sourceNodeId);
                }
                return reducedMessage;
            }
        );
    }

    /**
     * Returns the raw content of the receive sender slot for {@code nodeId}.
     * <p>
     * NOTE(review): the slot is returned regardless of whether a message arrived
     * for the node in the current iteration - callers presumably need to check
     * for message presence themselves; verify against call sites.
     */
    @Override
    public OptionalLong sender(long nodeId) {
        return OptionalLong.of(receiveSenderArray.get(nodeId));
    }

    @Override
    public void release() {
        sendSenderArray.release();
        receiveSenderArray.release();
        super.release();
    }
}

static class SingleMessageIterator implements Messages.MessageIterator {

boolean hasNext;
double message;
OptionalLong sender;

void init(double value, boolean hasNext) {
void init(double value, boolean hasNext, OptionalLong sender) {
this.message = value;
this.hasNext = hasNext;
this.sender = sender;
}

@Override
Expand All @@ -130,5 +210,10 @@ public double nextDouble() {
hasNext = false;
return message;
}

/**
* Returns the sender that was passed to the most recent {@code init} call.
*/
@Override
public OptionalLong sender() {
return this.sender;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ public long longNodeValue(String key) {
return nodeValue.longValue(key, nodeId);
}

/**
* Returns the node value for the given node-id and node schema key.
* <p>
* Unlike the single-argument variant, the node to read from is passed explicitly.
*
* @param key the node schema key of the value to read
* @param nodeId the id of the node whose value is read
* @return the long value stored under {@code key} for {@code nodeId}
* @throws IllegalArgumentException if the key does not exist or the value is not a long
*/
public long longNodeValue(String key, long nodeId) {
return nodeValue.longValue(key, nodeId);
}

/**
* Returns the node value for the given node schema key.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.neo4j.gds.beta.pregel.PregelConfig;
import org.neo4j.gds.core.utils.progress.tasks.ProgressTracker;

import java.util.OptionalLong;
import java.util.function.LongConsumer;

public abstract class NodeCentricContext<CONFIG extends PregelConfig> extends PregelContext<CONFIG> {
Expand All @@ -33,8 +34,11 @@ public abstract class NodeCentricContext<CONFIG extends PregelConfig> extends Pr

protected final Graph graph;

private OptionalLong sender = OptionalLong.empty();

long nodeId;


NodeCentricContext(Graph graph, CONFIG config, NodeValue nodeValue, ProgressTracker progressTracker) {
super(config, progressTracker);
this.graph = graph;
Expand Down
Loading

0 comments on commit f6f10e0

Please sign in to comment.