Skip to content
This repository was archived by the owner on Apr 22, 2020. It is now read-only.

Parallelise batching of writes for similarity algorithms #814

Open
wants to merge 11 commits into
base: 3.4
Choose a base branch
from
Next Next commit
batching wip
  • Loading branch information
mneedham committed Feb 12, 2019
commit 2777a82426be27e31fe897fca8559960863491b3
4 changes: 4 additions & 0 deletions algo/src/main/java/org/neo4j/graphalgo/impl/DSSResult.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ public Stream<DisjointSetStruct.Result> resultStream(IdMapping idMapping) {
: hugeStruct.resultStream(((HugeIdMapping) idMapping));
}

/**
 * Streams the disjoint-set results keyed by internal node id.
 *
 * Only the {@code struct} backing is consulted here. Other methods of this class
 * fall back to {@code hugeStruct} when {@code struct} is absent; this method has no
 * such fallback, so it fails with a descriptive error instead of a bare
 * NullPointerException in that case.
 *
 * @param idMapping id mapping handed through to the underlying struct
 * @return the stream produced by {@code struct.internalResultStream(idMapping)}
 * @throws UnsupportedOperationException if this result is not backed by {@code struct}
 */
public Stream<DisjointSetStruct.InternalResult> internalResultStream(IdMapping idMapping) {
    if (struct == null) {
        throw new UnsupportedOperationException(
                "internalResultStream is only available for results backed by a DisjointSetStruct");
    }
    return struct.internalResultStream(idMapping);
}

public void forEach(NodeIterator nodes, IntIntPredicate consumer) {
if (struct != null) {
nodes.forEachNode(nodeId -> consumer.apply(nodeId, struct.find(nodeId)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ public Stream<SimilaritySummaryResult> cosine(
Stream<SimilarityResult> stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer);

boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty);
boolean writeParallel = configuration.get("writeParallel", false);

return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
}

private SimilarityComputer<WeightedInput> similarityComputer(Double skipValue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ public Stream<SimilaritySummaryResult> euclidean(
Stream<SimilarityResult> stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer);

boolean write = configuration.isWriteFlag(false); // && similarityCutoff != 0.0;
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty );
boolean writeParallel = configuration.get("writeParallel", false);
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel );
}

Stream<SimilarityResult> generateWeightedStream(ProcedureConfiguration configuration, WeightedInput[] inputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,9 @@ public Stream<SimilaritySummaryResult> jaccard(
similarityCutoff, getTopK(configuration)), getTopN(configuration));

boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty );
boolean writeParallel = configuration.get("writeParallel", false);

return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
}

private SimilarityComputer<CategoricalInput> similarityComputer() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ public Stream<SimilaritySummaryResult> overlap(
Stream<SimilarityResult> stream = topN(similarityStream(inputs, computer, configuration, () -> null, similarityCutoff, getTopK(configuration)), getTopN(configuration));

boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty);
boolean writeParallel = configuration.get("writeParallel", false);

return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
}


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
/**
* Copyright (c) 2017 "Neo4j, Inc." <http://neo4j.com>
* <p>
* This file is part of Neo4j Graph Algorithms <http://github.com/neo4j-contrib/neo4j-graph-algorithms>.
* <p>
* Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* <p>
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* <p>
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.graphalgo.similarity;

import com.carrotsearch.hppc.IntHashSet;
import com.carrotsearch.hppc.IntSet;
import org.neo4j.graphalgo.core.IdMap;
import org.neo4j.graphalgo.core.WeightMap;
import org.neo4j.graphalgo.core.heavyweight.AdjacencyMatrix;
import org.neo4j.graphalgo.core.heavyweight.HeavyGraph;
import org.neo4j.graphalgo.core.utils.*;
import org.neo4j.graphalgo.core.utils.dss.DisjointSetStruct;
import org.neo4j.graphalgo.core.utils.paged.AllocationTracker;
import org.neo4j.graphalgo.impl.DSSResult;
import org.neo4j.graphalgo.impl.GraphUnionFind;
import org.neo4j.graphdb.Direction;
import org.neo4j.internal.kernel.api.exceptions.EntityNotFoundException;
import org.neo4j.internal.kernel.api.exceptions.InvalidTransactionTypeKernelException;
import org.neo4j.internal.kernel.api.exceptions.KernelException;
import org.neo4j.internal.kernel.api.exceptions.explicitindex.AutoIndexingKernelException;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.values.storable.Values;

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Writes similarity results back to the database as relationships, batching the
 * writes per partition of the similarity graph.
 *
 * The similarity pairs are first loaded into an in-memory {@link HeavyGraph}.
 * {@link GraphUnionFind} then assigns every node to a set, and the relationships of
 * each set are written from a parallel stream. Relationships whose target node is
 * not part of the source node's set are collected and written afterwards on the
 * calling thread.
 */
public class ParallelSimilarityExporter extends StatementApi {

    private final Log log;
    private final int propertyId;
    private final int relationshipTypeId;
    private final int nodeCount;

    /**
     * Resolves (creating them if necessary) the tokens for the relationship type
     * and the property key that the exported relationships will use.
     *
     * @param api              database the relationships are written to
     * @param log              log used for progress messages
     * @param relationshipType name of the relationship type to create
     * @param propertyName     name of the property that stores the similarity value
     * @param nodeCount        capacity used for the in-memory id map, adjacency matrix and weight map
     */
    public ParallelSimilarityExporter(GraphDatabaseAPI api,
                                      Log log,
                                      String relationshipType,
                                      String propertyName, int nodeCount) {
        super(api);
        this.log = log;
        propertyId = getOrCreatePropertyId(propertyName);
        relationshipTypeId = getOrCreateRelationshipId(relationshipType);
        this.nodeCount = nodeCount;
    }

    /**
     * Exports the given similarity pairs as relationships.
     *
     * @param similarityPairs pairs to write; the stream is consumed by this call
     * @param batchSize       maximum number of relationships written per transaction, at least 1
     * @throws IllegalArgumentException if {@code batchSize} is smaller than 1
     * @throws RuntimeException         if the parallel write fails or the calling thread is
     *                                  interrupted while waiting for it
     */
    public void export(Stream<SimilarityResult> similarityPairs, long batchSize) {
        // A non-positive batch size would make the batching loop spin forever on empty batches.
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize must be at least 1 but was " + batchSize);
        }

        IdMap idMap = new IdMap(this.nodeCount);
        AdjacencyMatrix adjacencyMatrix = new AdjacencyMatrix(this.nodeCount, false, AllocationTracker.EMPTY);
        WeightMap weightMap = new WeightMap(nodeCount, 0, propertyId);

        // Single-element array so the lambda below can update the counter.
        // NOTE(review): this, and the graph construction, assume the stream is sequential - confirm with callers.
        int[] numberOfRelationships = {0};

        similarityPairs.forEach(pair -> {
            int id1 = idMap.mapOrGet(pair.item1);
            int id2 = idMap.mapOrGet(pair.item2);
            adjacencyMatrix.addOutgoing(id1, id2);
            weightMap.put(RawValues.combineIntInt(id1, id2), pair.similarity);
            numberOfRelationships[0]++;
        });

        idMap.buildMappedIds();
        HeavyGraph graph = new HeavyGraph(idMap, adjacencyMatrix, weightMap, Collections.emptyMap());

        DSSResult dssResult = computePartitions(graph);

        // One list of nodes per set id.
        Collection<List<DisjointSetStruct.InternalResult>> partitions = dssResult.internalResultStream(graph)
                .collect(Collectors.groupingBy(item -> item.setId))
                .values();

        int partitionCount = dssResult.getSetCount();
        log.info("ParallelSimilarityExporter: Relationships to be created: %d, Partitions found: %d", numberOfRelationships[0], partitionCount);

        // Unbounded and thread-safe: producers never block and a partition count of zero is not an error.
        Queue<List<SimilarityResult>> outQueue = new ConcurrentLinkedQueue<>();

        int inQueueBatches;
        ExecutorService executor = Executors.newFixedThreadPool(1);
        try {
            Future<Integer> inQueueBatchCountFuture = executor.submit(() -> {
                AtomicInteger inQueueBatchCount = new AtomicInteger(0);
                partitions.stream().parallel().forEach(partition ->
                        inQueueBatchCount.addAndGet(writePartition(partition, graph, idMap, outQueue, batchSize)));
                return inQueueBatchCount.get();
            });
            inQueueBatches = inQueueBatchCountFuture.get();
        } catch (InterruptedException e) {
            // Restore the flag so callers further up can observe the interruption.
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted while writing similarity relationships in parallel", e);
        } catch (ExecutionException e) {
            // Surface the failure instead of continuing with a partially written result.
            throw new RuntimeException("Parallel write of similarity relationships failed", e.getCause());
        } finally {
            // The executor is only used for this export; release its thread.
            executor.shutdownNow();
        }

        int outQueueBatches = writeSequential(outQueue.stream().flatMap(Collection::stream), batchSize);
        log.info("ParallelSimilarityExporter: Batch Size: %d, Batches written - in parallel: %d, sequentially: %d", batchSize, inQueueBatches, outQueueBatches);
    }

    /**
     * Writes the relationships of one partition whose target is also in the partition and
     * queues the remaining ones for the sequential pass.
     *
     * @return number of batches written for this partition
     */
    private int writePartition(List<DisjointSetStruct.InternalResult> partition,
                               HeavyGraph graph,
                               IdMap idMap,
                               Queue<List<SimilarityResult>> outQueue,
                               long batchSize) {
        IntSet nodesInPartition = new IntHashSet();
        for (DisjointSetStruct.InternalResult internalResult : partition) {
            nodesInPartition.add(internalResult.internalNodeId);
        }

        List<SimilarityResult> inPartition = new ArrayList<>();
        List<SimilarityResult> outPartition = new ArrayList<>();

        for (DisjointSetStruct.InternalResult result : partition) {
            int nodeId = result.internalNodeId;
            graph.forEachRelationship(nodeId, Direction.OUTGOING, (sourceNodeId, targetNodeId, relationId, weight) -> {
                SimilarityResult similarityRelationship = new SimilarityResult(idMap.toOriginalNodeId(sourceNodeId), idMap.toOriginalNodeId(targetNodeId), -1, -1, -1, weight);

                if (nodesInPartition.contains(targetNodeId)) {
                    inPartition.add(similarityRelationship);
                } else {
                    outPartition.add(similarityRelationship);
                }

                // NOTE(review): if the graph API interprets the return value as "continue iterating",
                // returning false would stop after the first relationship - confirm against the consumer contract.
                return false;
            });
        }

        int batches = 0;
        if (!inPartition.isEmpty()) {
            batches = writeSequential(inPartition.stream(), batchSize);
        }

        if (!outPartition.isEmpty()) {
            outQueue.add(outPartition);
        }
        return batches;
    }

    /** Runs {@link GraphUnionFind} on the graph and wraps the resulting set structure. */
    private DSSResult computePartitions(HeavyGraph graph) {
        GraphUnionFind algo = new GraphUnionFind(graph);
        DisjointSetStruct struct = algo.compute();
        algo.release();
        return new DSSResult(struct);
    }

    /** Creates a single relationship through one {@code applyInTransaction} call. */
    private void export(SimilarityResult similarityResult) {
        applyInTransaction(statement -> {
            try {
                createRelationship(similarityResult, statement);
            } catch (KernelException e) {
                ExceptionUtil.throwKernelException(e);
            }
            return null;
        });

    }

    /** Creates all relationships of the batch through one {@code applyInTransaction} call. */
    private void export(List<SimilarityResult> similarityResults) {
        applyInTransaction(statement -> {
            for (SimilarityResult similarityResult : similarityResults) {
                try {
                    createRelationship(similarityResult, statement);
                } catch (KernelException e) {
                    ExceptionUtil.throwKernelException(e);
                }
            }
            return null;
        });

    }

    /** Creates the relationship between both items and stores the similarity as its property. */
    private void createRelationship(SimilarityResult similarityResult, KernelTransaction statement) throws EntityNotFoundException, InvalidTransactionTypeKernelException, AutoIndexingKernelException {
        long node1 = similarityResult.item1;
        long node2 = similarityResult.item2;
        long relationshipId = statement.dataWrite().relationshipCreate(node1, relationshipTypeId, node2);

        statement.dataWrite().relationshipSetProperty(
                relationshipId, propertyId, Values.doubleValue(similarityResult.similarity));
    }

    private int getOrCreateRelationshipId(String relationshipType) {
        return applyInTransaction(stmt -> stmt
                .tokenWrite()
                .relationshipTypeGetOrCreateForName(relationshipType));
    }

    private int getOrCreatePropertyId(String propertyName) {
        return applyInTransaction(stmt -> stmt
                .tokenWrite()
                .propertyKeyGetOrCreateForName(propertyName));
    }

    /**
     * Writes the pairs on the current thread, {@code batchSize} relationships per export call.
     *
     * @return number of non-empty batches written
     */
    private int writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
        int batches = 0;
        Iterator<SimilarityResult> iterator = similarityPairs.iterator();
        if (batchSize == 1) {
            while (iterator.hasNext()) {
                export(iterator.next());
                batches++;
            }
        } else {
            int size = Math.toIntExact(batchSize);
            // Checking hasNext() up front avoids an export call for an empty batch.
            while (iterator.hasNext()) {
                export(take(iterator, size));
                batches++;
            }
        }

        return batches;
    }

    /** Removes up to {@code batchSize} elements from the iterator and returns them in order. */
    private static List<SimilarityResult> take(Iterator<SimilarityResult> iterator, int batchSize) {
        List<SimilarityResult> result = new ArrayList<>(batchSize);
        while (iterator.hasNext() && batchSize-- > 0) {
            result.add(iterator.next());
        }
        return result;
    }


}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ public Stream<SimilaritySummaryResult> pearson(
Stream<SimilarityResult> stream = generateWeightedStream(configuration, inputs, similarityCutoff, topN, topK, computer);

boolean write = configuration.isWriteFlag(false) && similarityCutoff > 0.0;
return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty);
boolean writeParallel = configuration.get("writeParallel", false);

return writeAndAggregateResults(stream, inputs.length, configuration, write, writeRelationshipType, writeProperty, writeParallel);
}

private SimilarityComputer<WeightedInput> similarityComputer(Double skipValue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.neo4j.internal.kernel.api.exceptions.explicitindex.AutoIndexingKernelException;
import org.neo4j.kernel.api.KernelTransaction;
import org.neo4j.kernel.internal.GraphDatabaseAPI;
import org.neo4j.logging.Log;
import org.neo4j.values.storable.Values;

import java.util.ArrayList;
Expand All @@ -35,19 +36,22 @@

public class SimilarityExporter extends StatementApi {

private final Log log;
private final int propertyId;
private final int relationshipTypeId;

/**
 * Resolves (creating them if necessary) the tokens for the relationship type
 * and the property key that the exported relationships will use.
 *
 * @param api              database the relationships are written to
 * @param log              log used for progress messages
 * @param relationshipType name of the relationship type to create
 * @param propertyName     name of the property that stores the similarity value
 */
public SimilarityExporter(GraphDatabaseAPI api,
                          Log log, String relationshipType,
                          String propertyName) {
    super(api);
    this.log = log;
    propertyId = getOrCreatePropertyId(propertyName);
    relationshipTypeId = getOrCreateRelationshipId(relationshipType);
}

/**
 * Writes all similarity pairs sequentially and logs how many batches were needed.
 *
 * @param similarityPairs pairs to write; the stream is consumed exactly once
 * @param batchSize       batch size handed to {@code writeSequential}
 */
public void export(Stream<SimilarityResult> similarityPairs, long batchSize) {
    // A stream can only be traversed once, so writeSequential must be called a single time.
    int batches = writeSequential(similarityPairs, batchSize);
    log.info("SimilarityExporter: Batch Size: %d, Batches written - sequentially: %d", batchSize, batches);
}

private void export(SimilarityResult similarityResult) {
Expand Down Expand Up @@ -97,17 +101,29 @@ private int getOrCreatePropertyId(String propertyName) {
.propertyKeyGetOrCreateForName(propertyName));
}

private void writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
private int writeSequential(Stream<SimilarityResult> similarityPairs, long batchSize) {
log.info("SimilarityExporter: Writing relationships...");
int[] counter = {0};
if (batchSize == 1) {
similarityPairs.forEach(this::export);
similarityPairs.forEach(similarityResult -> {
export(similarityResult);
counter[0]++;
});
} else {
Iterator<SimilarityResult> iterator = similarityPairs.iterator();
do {
export(take(iterator, Math.toIntExact(batchSize)));
List<SimilarityResult> batch = take(iterator, Math.toIntExact(batchSize));
export(batch);
if(batch.size() > 0) {
counter[0]++;
}
} while (iterator.hasNext());
}

return counter[0];
}


private static List<SimilarityResult> take(Iterator<SimilarityResult> iterator, int batchSize) {
List<SimilarityResult> result = new ArrayList<>(batchSize);
while (iterator.hasNext() && batchSize-- > 0) {
Expand Down
Loading