From 80250258ddffa84ef98acc7813e40098266367bf Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Tue, 3 Feb 2026 10:51:35 +0100 Subject: [PATCH 1/3] Add Support For Grouped Tiles Add Stream Split and Merge Primitives + Bugfixes / Additional Tests Preliminary Bugfixes and Performance Improvements Bugfix Cache Deletion Add Support for Output Write Partitioning --- .../sysds/hops/ipa/IPAPassInjectOOCTee.java | 51 +++ .../hops/ipa/IPAPassPruneUnreachableHops.java | 124 ++++++++ .../hops/ipa/InterProceduralAnalysis.java | 5 + .../sysds/hops/rewrite/ProgramRewriter.java | 1 - .../controlprogram/caching/CacheableData.java | 3 +- .../cp/FunctionCallCPInstruction.java | 21 +- .../ooc/AggregateUnaryOOCInstruction.java | 2 +- .../ooc/BinaryOOCInstruction.java | 4 +- .../instructions/ooc/CachingStream.java | 279 +++++++++++++--- .../ooc/MapMMChainOOCInstruction.java | 122 ++++--- .../instructions/ooc/OOCInstruction.java | 234 ++++++++++---- .../runtime/instructions/ooc/OOCStream.java | 10 +- .../instructions/ooc/OOCStreamable.java | 4 + .../ooc/SubscribableTaskQueue.java | 15 + .../instructions/ooc/TeeOOCInstruction.java | 24 +- .../apache/sysds/runtime/io/MatrixWriter.java | 4 +- .../runtime/io/WriterBinaryBlockParallel.java | 97 ++++++ .../sysds/runtime/ooc/cache/BlockEntry.java | 16 + .../runtime/ooc/cache/GroupedBlockKey.java | 33 ++ .../runtime/ooc/cache/OOCCacheManager.java | 212 ++++++++++++- .../sysds/runtime/ooc/cache/OOCIOHandler.java | 13 + .../ooc/cache/OOCLRUCacheScheduler.java | 6 +- .../runtime/ooc/cache/OOCMatrixIOHandler.java | 130 ++++++-- .../runtime/ooc/stream/FilteredOOCStream.java | 23 +- .../runtime/ooc/stream/MergedOOCStream.java | 253 +++++++++++++++ .../runtime/ooc/stream/SourceOOCStream.java | 74 ++++- .../ooc/stream/SourceOOCStreamable.java | 125 ++++++++ .../ooc/stream/SplittingOOCStream.java | 220 +++++++++++++ .../runtime/ooc/stream/SubOOCStream.java | 180 +++++++++++ .../sysds/runtime/ooc/stream/TaskContext.java | 9 +- .../org/apache/sysds/utils/Statistics.java | 20 +- .../ooc/BinaryWritePartitioningTest.java | 132 ++++++++ .../test/functions/ooc/CSVReaderTest.java | 6 +- .../sysds/test/functions/ooc/LmCGTest.java | 1 - .../sysds/test/functions/ooc/PNMFTest.java | 104 ++++++ .../ooc/SplitMergeOOCStreamTest.java | 299 ++++++++++++++++++ src/test/scripts/functions/ooc/PNMF.dml | 26 ++ 37 files changed, 2656 insertions(+), 226 deletions(-) create mode 100644 src/main/java/org/apache/sysds/hops/ipa/IPAPassInjectOOCTee.java create mode 100644 src/main/java/org/apache/sysds/hops/ipa/IPAPassPruneUnreachableHops.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/cache/GroupedBlockKey.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStreamable.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/stream/SplittingOOCStream.java create mode 100644 src/main/java/org/apache/sysds/runtime/ooc/stream/SubOOCStream.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/BinaryWritePartitioningTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/SplitMergeOOCStreamTest.java create mode 100644 src/test/scripts/functions/ooc/PNMF.dml diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInjectOOCTee.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInjectOOCTee.java new file mode 100644 index 00000000000..f8f44c8e741 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInjectOOCTee.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.ipa; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.hops.HopsException; +import org.apache.sysds.hops.rewrite.ProgramRewriteStatus; +import org.apache.sysds.hops.rewrite.ProgramRewriter; +import org.apache.sysds.hops.rewrite.RewriteInjectOOCTee; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.LanguageException; + +/** + * Applies OOC tee injection after static/dynamic rewrites in IPA. + */ +public class IPAPassInjectOOCTee extends IPAPass { + @Override + public boolean isApplicable(FunctionCallGraph fgraph) { + return DMLScript.USE_OOC; + } + + @Override + public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) { + try { + ProgramRewriter rewriter = new ProgramRewriter(new RewriteInjectOOCTee()); + ProgramRewriteStatus status = new ProgramRewriteStatus(); + rewriter.rewriteProgramHopDAGs(prog, true, status); + return false; + } + catch(LanguageException ex) { + throw new HopsException(ex); + } + } +} diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassPruneUnreachableHops.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPruneUnreachableHops.java new file mode 100644 index 00000000000..0d8c529d679 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPruneUnreachableHops.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.ipa; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; +import org.apache.sysds.parser.IfStatement; +import org.apache.sysds.parser.IfStatementBlock; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; + +/** + * Prune stale parent links by keeping only parent references reachable from the statement block roots/predicates. + */ +public class IPAPassPruneUnreachableHops extends IPAPass { + @Override + public boolean isApplicable(FunctionCallGraph fgraph) { + return DMLScript.USE_OOC; + } + + @Override + public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) { + pruneStatementBlocks(prog.getStatementBlocks()); + for(FunctionStatementBlock fsb : prog.getFunctionStatementBlocks()) + pruneStatementBlocks(((FunctionStatement) fsb.getStatement(0)).getBody()); + return false; + } + + private static void pruneStatementBlocks(List sbs) { + for(StatementBlock sb : sbs) { + if(sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement) sb.getStatement(0); + pruneHops(wsb.getPredicateHops()); + pruneStatementBlocks(wstmt.getBody()); + } + else if(sb instanceof IfStatementBlock) { + IfStatementBlock isb = (IfStatementBlock) sb; + IfStatement istmt = (IfStatement) sb.getStatement(0); + pruneHops(isb.getPredicateHops()); + pruneStatementBlocks(istmt.getIfBody()); + if(istmt.getElseBody() != null) + pruneStatementBlocks(istmt.getElseBody()); + } + else if(sb instanceof ForStatementBlock) { + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement) sb.getStatement(0); + pruneHops(fsb.getFromHops()); + pruneHops(fsb.getToHops()); + pruneHops(fsb.getIncrementHops()); + pruneStatementBlocks(fstmt.getBody()); + } + else if(sb instanceof FunctionStatementBlock) { + FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0); + pruneStatementBlocks(fstmt.getBody()); + } + else { + pruneHops(sb.getHops()); + } + } + } + + private static void pruneHops(Hop root) { + if(root == null) + return; + Set reachable = new HashSet<>(); + collectReachable(root, reachable); + pruneParents(root, reachable, new HashSet()); + } + + private static void pruneHops(List roots) { + if(roots == null || roots.isEmpty()) + return; + + Set reachable = new HashSet<>(); + for(Hop root : roots) + collectReachable(root, reachable); + + for(Hop root : roots) + pruneParents(root, reachable, new HashSet()); + } + + private static void collectReachable(Hop hop, Set reachable) { + if(hop == null || !reachable.add(hop.getHopID())) + return; + for(Hop in : hop.getInput()) + collectReachable(in, reachable); + } + + private static void pruneParents(Hop hop, Set reachable, Set visited) { + if(hop == null || !visited.add(hop.getHopID())) + return; + hop.getParent().removeIf(p -> !reachable.contains(p.getHopID())); + for(Hop in : hop.getInput()) + pruneParents(in, reachable, visited); + } +} diff --git a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java index dce8fb05424..3344a95d2d5 100644 --- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java +++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.hops.DataOp; import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.FunctionOp.FunctionType; @@ -141,6 +142,10 @@ public InterProceduralAnalysis(DMLProgram dmlp) { //would require an update of the function call graph _passes.add(new IPAPassForwardFunctionCalls()); _passes.add(new IPAPassApplyStaticAndDynamicHopRewrites()); + if (DMLScript.USE_OOC) { + _passes.add(new IPAPassPruneUnreachableHops()); + _passes.add(new IPAPassInjectOOCTee()); + } } public InterProceduralAnalysis(StatementBlock sb) { diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index aa82adcfdc5..98534f5d8c8 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -152,7 +152,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse _sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks() ); _sbRuleSet.add( new RewriteRemoveEmptyForLoops() ); - _sbRuleSet.add( new RewriteInjectOOCTee() ); } /** diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java index f41b0511ee9..e12d2caa5e3 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java @@ -63,6 +63,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; import org.apache.sysds.runtime.meta.MetaData; import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.apache.sysds.runtime.ooc.stream.SourceOOCStreamable; import org.apache.sysds.runtime.util.HDFSTool; import org.apache.sysds.runtime.util.LocalFileUtils; import org.apache.sysds.runtime.util.UtilFunctions; @@ -496,7 +497,7 @@ public synchronized OOCStream getStreamHandle() { } public OOCStreamable getStreamable() { - return _streamHandle; + return _streamHandle == null ? new SourceOOCStreamable(this) : _streamHandle; } /** diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java index e6f553dc6aa..b3f6a6a117c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/FunctionCallCPInstruction.java @@ -37,10 +37,12 @@ import org.apache.sysds.runtime.DMLScriptException; import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock; import org.apache.sysds.runtime.controlprogram.LocalVariableMap; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.lineage.Lineage; import org.apache.sysds.runtime.lineage.LineageCache; @@ -172,6 +174,8 @@ public void processInstruction(ExecutionContext ec) { //set input parameter functionVariables.put(currFormalParam.getName(), value); + if (DMLScript.USE_OOC && value instanceof MatrixObject) + TeeOOCInstruction.incrRef(((MatrixObject) value).getStreamable(), 1); //map lineage to function arguments if( lineage != null ) { @@ -227,7 +231,8 @@ public void processInstruction(ExecutionContext ec) { if( expectRetVars.contains(varName) ) continue; //cleanup unexpected return values to avoid leaks - fn_ec.cleanupDataObject(fn_ec.removeVariable(varName)); + //(including OOC reference tracking for matrix streams) + VariableCPInstruction.processRmvarInstruction(fn_ec, varName); } // Unpin the pinned variables @@ -247,10 +252,12 @@ public void processInstruction(ExecutionContext ec) { // remove existing data bound to output variable name Data exdata = ec.removeVariable(boundVarName); + if (DMLScript.USE_OOC && exdata instanceof MatrixObject && exdata != boundValue) + TeeOOCInstruction.incrRef(((MatrixObject) exdata).getStreamable(), -1); // save old data for cleanup later if (exdata != boundValue && !retVars.hasReferences(exdata)) toBeCleanedUp.add(exdata); - //FIXME: interferes with reuse. Removes broadcasts before materialization + //FIXME: interferes with reuse. Removes broadcasts before materialization //add/replace data in symbol table ec.setVariable(boundVarName, boundValue); @@ -276,11 +283,17 @@ public void processInstruction(ExecutionContext ec) { //update lineage cache with the functions outputs if ((DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() && !fpb.isNondeterministic()) || (LineageCacheConfig.isEstimator() && !fpb.isNondeterministic())) { - LineageCache.putValue(fpb.getOutputParams(), liInputs, + LineageCache.putValue(fpb.getOutputParams(), liInputs, getCacheFunctionName(_functionName, fpb), fn_ec, t1-t0); - //FIXME: send _boundOutputNames instead of fpb.getOutputParams as + //FIXME: send _boundOutputNames instead of fpb.getOutputParams as //those are already replaced by boundoutput names in the lineage map. } + + // cleanup declared outputs that are not bound at callsite + for (int i = numOutputs; i < fpb.getOutputParams().size(); i++) { + String retVarName = fpb.getOutputParams().get(i).getName(); + VariableCPInstruction.processRmvarInstruction(fn_ec, retVarName); + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java index fa884b84d17..db8091621da 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/AggregateUnaryOOCInstruction.java @@ -157,7 +157,7 @@ public void processInstruction( ExecutionContext ec ) { else { OOCStream qLocal = createWritableStream(); - mapOOC(qIn, qLocal, tmp -> (MatrixBlock) ((MatrixBlock) tmp.getValue()) + mapOOC(qIn, qLocal, tmp -> (MatrixBlock) tmp.getValue() .aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes())); MatrixBlock ltmp; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java index 3dfdce26113..e45b8e93bc6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/BinaryOOCInstruction.java @@ -82,7 +82,7 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) { boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows() == 1; if (isColBroadcast && !isRowBroadcast) { - final long maxProcessesPerBroadcast = m1.getNumColumns() / m1.getBlocksize(); + final long maxProcessesPerBroadcast = (m1.getNumColumns() + m1.getBlocksize() - 1) / m1.getBlocksize(); broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> { IndexedMatrixValue tmpOut = new IndexedMatrixValue(); @@ -96,7 +96,7 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) { }, tmp -> tmp.getIndexes().getRowIndex()); } else if (isRowBroadcast && !isColBroadcast) { - final long maxProcessesPerBroadcast = m1.getNumRows() / m1.getBlocksize(); + final long maxProcessesPerBroadcast = (m1.getNumRows() + m1.getBlocksize() - 1) / m1.getBlocksize(); broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> { IndexedMatrixValue tmpOut = new IndexedMatrixValue(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index d3e2b91630f..6aa65b9f723 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -23,12 +23,13 @@ import org.apache.commons.collections4.bidimap.DualHashBidiMap; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheableData; -import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import org.apache.sysds.runtime.meta.DataCharacteristics; import org.apache.sysds.runtime.ooc.cache.BlockKey; +import org.apache.sysds.runtime.ooc.cache.GroupedBlockKey; import org.apache.sysds.runtime.ooc.cache.OOCIOHandler; import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; import org.apache.sysds.runtime.ooc.stream.SourceOOCStream; @@ -37,6 +38,9 @@ import org.apache.sysds.runtime.util.IndexRange; import shaded.parquet.it.unimi.dsi.fastutil.ints.IntArrayList; +import java.util.ArrayList; +import java.util.List; +import java.util.BitSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; @@ -56,6 +60,11 @@ public class CachingStream implements OOCStreamable { private final OOCStream _source; private final IntArrayList _consumptionCounts = new IntArrayList(); private final IntArrayList _consumerConsumptionCounts = new IntArrayList(); + private final IntArrayList _groupIndices = new IntArrayList(); + private final IntArrayList _groupSizes = new IntArrayList(); + private final List _cacheKeys = new ArrayList<>(); + private final BitSet _ownedCacheKeys = new BitSet(); + private int _ownedCacheKeysSize = 0; // stream identifier private final long _streamId; @@ -84,46 +93,137 @@ public CachingStream(OOCStream source, long streamId) { _source = source; _source.setDownstreamMessageRelay(this::messageDownstream); _streamId = streamId; - if (OOCWatchdog.WATCH) { + if(OOCWatchdog.WATCH) { _watchdogId = "CS-" + hashCode(); // Capture a short context to help identify origin OOCWatchdog.registerOpen(_watchdogId, "CachingStream@" + hashCode(), getCtxMsg(), this); } _downstreamRelays = null; source.setSubscriber(tmp -> { - try (tmp) { - final IndexedMatrixValue task = tmp.get(); + try(tmp) { int blk; + int groupSize = 1; Consumer>[] mSubscribers; OOCStream.QueueCallback mCallback = null; - synchronized (this) { + synchronized(this) { mSubscribers = _subscribers; - if(task != LocalTaskQueue.NO_MORE_TASKS) { + if(!tmp.isEos()) { if(!_cacheInProgress) throw new DMLRuntimeException("Stream is closed"); - OOCIOHandler.SourceBlockDescriptor descriptor = null; - if(_source instanceof SourceOOCStream src) { - descriptor = src.getDescriptor(task.getIndexes()); - } - if(descriptor == null) { - if(mSubscribers == null || mSubscribers.length == 0) - OOCCacheManager.put(_streamId, _numBlocks, task); - else - mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); - } - else { - if(mSubscribers == null || mSubscribers.length == 0) - OOCCacheManager.putSourceBacked(_streamId, _numBlocks, task, descriptor); - else - mCallback = OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, - descriptor); - } - if(_index != null) - _index.put(task.getIndexes(), _numBlocks); - blk = _numBlocks; - _numBlocks++; - _consumptionCounts.add(0); + if(tmp instanceof OOCStream.GroupQueueCallback) { + @SuppressWarnings("unchecked") + OOCStream.GroupQueueCallback group = + (OOCStream.GroupQueueCallback) tmp; + groupSize = group.size(); + for(int gi = 0; gi < groupSize; gi++) { + OOCStream.QueueCallback sub = group.getCallback(gi); + try(sub) { + IndexedMatrixValue imv = sub.get(); + if(_index != null) + _index.put(imv.getIndexes(), _numBlocks + gi); + } + } + + BlockKey baseKey; + boolean ownsEntry = true; + if(tmp instanceof OOCCacheManager.CachedGroupCallback cachedGroup) { + baseKey = cachedGroup.getBlockKey(); + ownsEntry = false; + if(mSubscribers != null && mSubscribers.length > 0) + mCallback = tmp.keepOpen(); + } + else { + List values = new java.util.ArrayList<>(groupSize); + long totalSize = 0; + for(int gi = 0; gi < groupSize; gi++) { + OOCStream.QueueCallback sub = group.getCallback(gi); + try(sub) { + IndexedMatrixValue imv = sub.get(); + values.add(imv); + totalSize += ((MatrixBlock) imv.getValue()).getExactSerializedSize(); + } + } + + baseKey = new BlockKey(_streamId, _numBlocks); + if (_source instanceof SourceOOCStream && tmp instanceof SourceOOCStream.SourceGroupCallback sg) { + OOCIOHandler.GroupSourceBlockDescriptor gdesc = sg.getDescriptor(); + if (mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.putRawSourceBacked(baseKey, values, totalSize, gdesc); + else + mCallback = OOCCacheManager.putAndPinRawSourceBacked(baseKey, values, totalSize, gdesc); + } + else { + if(mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.putRaw(baseKey, values, totalSize); + else + mCallback = OOCCacheManager.putAndPinRaw(baseKey, values, totalSize); + } + } + + blk = _numBlocks; + _numBlocks += groupSize; + for(int gi = 0; gi < groupSize; gi++) { + registerCacheKey(blk + gi, + new GroupedBlockKey(baseKey.getStreamId(), (int) baseKey.getSequenceNumber(), gi), + ownsEntry); + _consumptionCounts.add(0); + _groupIndices.add(gi); + _groupSizes.add(groupSize); + } + } + else { + final IndexedMatrixValue task = tmp.get(); + OOCIOHandler.SourceBlockDescriptor descriptor = null; + BlockKey blockKey = null; + boolean ownsEntry = true; + + if(tmp instanceof OOCCacheManager.CachedQueueCallback cachedQueue) { + blockKey = cachedQueue.getBlockKey(); + ownsEntry = false; + if(mSubscribers != null && mSubscribers.length > 0) + mCallback = tmp.keepOpen(); + } + else if(tmp instanceof OOCCacheManager.CachedSubCallback cachedSub) { + BlockKey parent = cachedSub.getParent().getBlockKey(); + blockKey = new GroupedBlockKey(parent.getStreamId(), (int) parent.getSequenceNumber(), + cachedSub.getGroupIndex()); + ownsEntry = false; + if(mSubscribers != null && mSubscribers.length > 0) + mCallback = tmp.keepOpen(); + } + + if(_source instanceof SourceOOCStream src) { + descriptor = src.getDescriptor(task.getIndexes()); + } + + if(blockKey == null) { + ownsEntry = true; + blockKey = new BlockKey(_streamId, _numBlocks); + if(descriptor == null) { + if(mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.put(_streamId, _numBlocks, task); + else + mCallback = OOCCacheManager.putAndPin(_streamId, _numBlocks, task); + } + else { + if(mSubscribers == null || mSubscribers.length == 0) + OOCCacheManager.putSourceBacked(_streamId, _numBlocks, task, descriptor); + else + mCallback = OOCCacheManager.putAndPinSourceBacked(_streamId, _numBlocks, task, + descriptor); + } + } + if(_index != null) + _index.put(task.getIndexes(), _numBlocks); + blk = _numBlocks; + _numBlocks++; + registerCacheKey(blk, blockKey, ownsEntry); + _consumptionCounts.add(0); + _groupIndices.add(-1); + _groupSizes.add(1); + } + notifyAll(); } else { @@ -150,7 +250,12 @@ public CachingStream(OOCStream source, long streamId) { try(localCallback) { mSubscribers[i].accept(localCallback); } - if(onConsumed(blk, i)) + boolean done = false; + for(int gi = 0; gi < groupSize; gi++) { + if(onConsumed(blk + gi, i)) + done = true; + } + if(done) mSubscribers[i].accept(OOCStream.eos(_failure)); } } @@ -219,8 +324,25 @@ private synchronized void tryDeleteBlock(int i) { int cnt = _consumptionCounts.getInt(i); if (cnt > _maxConsumptionCount) throw new DMLRuntimeException("Cannot have more than " + _maxConsumptionCount + " consumptions."); - if (cnt == _maxConsumptionCount) - OOCCacheManager.forget(_streamId, i); + if(!_ownedCacheKeys.get(i)) + return; + if (cnt == _maxConsumptionCount) { + int groupIdx = _groupIndices.getInt(i); + int groupSize = _groupSizes.getInt(i); + if (groupIdx >= 0 && groupSize > 1) { + int baseId = i - groupIdx; + if (i != baseId) + return; + for (int j = 0; j < groupSize; j++) { + if (_consumptionCounts.getInt(baseId + j) < _maxConsumptionCount) + return; + } + OOCCacheManager.getCache().forget(getEntryBlockKey(baseId)); + } + else { + OOCCacheManager.getCache().forget(getEntryBlockKey(i)); + } + } } private synchronized boolean onConsumed(int blockIdx, int consumerIdx) { @@ -249,7 +371,7 @@ public synchronized CompletableFuture { synchronized(this) { if(_index != null) // Ensure index is up to date @@ -279,7 +401,7 @@ public synchronized int findCachedIndex(MatrixIndexes idx) { } public synchronized BlockKey peekCachedBlockKey(MatrixIndexes idx) { - return new BlockKey(_streamId, _index.get(idx)); + return getBlockKey(_index.get(idx)); } public synchronized OOCStream.QueueCallback findCached(MatrixIndexes idx) { @@ -292,7 +414,7 @@ public synchronized OOCStream.QueueCallback findCached(Matri _consumptionCounts.set(mIdx, newCount); try { - return OOCCacheManager.requestBlock(_streamId, mIdx).get(); + return OOCCacheManager.requestBlock(getBlockKey(mIdx)).get(); } catch (InterruptedException | ExecutionException e) { return new OOCStream.SimpleQueueCallback<>(null, new DMLRuntimeException(e)); } finally { @@ -310,7 +432,7 @@ public void findCachedAsync(MatrixIndexes idx, Consumer { + OOCCacheManager.requestBlock(getBlockKey(mIdx)).whenComplete((cb, r) -> { try (cb) { synchronized(CachingStream.this) { int newCount = _consumptionCounts.getInt(mIdx) + 1; @@ -347,7 +469,7 @@ public void peekCachedAsync(MatrixIndexes idx, Consumer callback.accept(cb)); + OOCCacheManager.requestBlock(getBlockKey(mIdx)).whenComplete((cb, r) -> callback.accept(cb)); } /** @@ -359,12 +481,41 @@ public OOCStream.QueueCallback peekCached(MatrixIndexes idx) mIdx = _index.get(idx); } try { - return OOCCacheManager.requestBlock(_streamId, mIdx).get(); + return OOCCacheManager.requestBlock(getBlockKey(mIdx)).get(); } catch (InterruptedException | ExecutionException e) { return new OOCStream.SimpleQueueCallback<>(null, new DMLRuntimeException(e)); } } + private BlockKey getBlockKey(int blockIdx) { + if(_cacheKeys.size() > blockIdx) + return _cacheKeys.get(blockIdx); + int groupIdx = _groupIndices.size() > blockIdx ? _groupIndices.getInt(blockIdx) : -1; + if (groupIdx >= 0) { + int baseId = blockIdx - groupIdx; + return new GroupedBlockKey(_streamId, baseId, groupIdx); + } + return new BlockKey(_streamId, blockIdx); + } + + private BlockKey getEntryBlockKey(int blockIdx) { + BlockKey key = getBlockKey(blockIdx); + if(key instanceof GroupedBlockKey) + return new BlockKey(key.getStreamId(), key.getSequenceNumber()); + return key; + } + + private void registerCacheKey(int blockIdx, BlockKey key, boolean ownsEntry) { + _cacheKeys.add(key); + if(ownsEntry) + _ownedCacheKeys.set(blockIdx); + else + _ownedCacheKeys.clear(blockIdx); + _ownedCacheKeysSize++; + if(_cacheKeys.size() != blockIdx + 1 || _ownedCacheKeysSize != blockIdx + 1) + throw new IllegalStateException("Invalid cache key registration order"); + } + public synchronized void activateIndexing() { if (_index == null) _index = new DualHashBidiMap<>(); @@ -380,6 +531,16 @@ public OOCStream getWriteStream() { return _source.getWriteStream(); } + @Override + public boolean hasStreamCache() { + return true; + } + + @Override + public CachingStream getStreamCache() { + return this; + } + @Override public boolean isProcessed() { return false; @@ -500,19 +661,55 @@ public void setSubscriber(Consumer> for(int i = 0; i < mNumBlocks; i++) { final int idx = i; - OOCCacheManager.requestBlock(_streamId, i).whenComplete((cb, r) -> { + int gIdx; + int gSize; + synchronized(this) { + gIdx = _groupIndices.getInt(idx); + gSize = _groupSizes.getInt(idx); + } + final int groupIdx = gIdx; + final int groupSize = gSize; + if(groupIdx > 0) + continue; // only replay grouped blocks once at the base index + + BlockKey replayKey = (groupSize > 1 && groupIdx == 0) ? new BlockKey(_streamId, idx) : getBlockKey(i); + OOCCacheManager.requestBlock(replayKey).whenComplete((cb, r) -> { if(r != null) { subscriber.accept(OOCStream.eos(DMLRuntimeException.of(r))); return; } try(cb) { synchronized(CachingStream.this) { - if(_index != null) - _index.put(cb.get().getIndexes(), idx); + if(_index != null) { + if(cb instanceof OOCStream.GroupQueueCallback && groupSize > 1) { + @SuppressWarnings("unchecked") + OOCStream.GroupQueueCallback group = + (OOCStream.GroupQueueCallback) cb; + for(int gi = 0; gi < groupSize; gi++) { + OOCStream.QueueCallback sub = group.getCallback(gi); + try(sub) { + _index.put(sub.get().getIndexes(), idx + gi); + } + } + } + else { + _index.put(cb.get().getIndexes(), idx); + } + } } subscriber.accept(cb); - if(onConsumed(idx, consumerIdx)) + boolean done = false; + if(groupSize > 1) { + for(int gi = 0; gi < groupSize; gi++) { + if(onConsumed(idx + gi, consumerIdx)) + done = true; + } + } + else if(onConsumed(idx, consumerIdx)) { + done = true; + } + if(done) subscriber.accept(OOCStream.eos(_failure)); // NO_MORE_TASKS } }); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java index 2fd5585edd3..d4851ee2ed1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/MapMMChainOOCInstruction.java @@ -100,6 +100,9 @@ public void processInstruction(ExecutionContext ec) { OOCStream qU; CompletableFuture uFuture; + CompletableFuture mapXvFuture = null; + boolean directXtFromXv = false; + OOCStream qPartialXtDirect = null; if(!hasV && _type == ChainType.XtXvy) { MatrixObject mw = ec.getMatrixObject(input3); @@ -115,70 +118,97 @@ public void processInstruction(ExecutionContext ec) { qU = qNegW; } else { - OOCStream qPartialXv = createWritableStream(); - OOCStream qXv = createWritableStream(); - OOCStream qInXv = xCache.getReadStream(); + if(numColBlocks == 1 && !_type.isWeighted()) { + directXtFromXv = true; + qPartialXtDirect = createWritableStream(); + mapXvFuture = broadcastJoinOOC(xCache.getReadStream(), qV, qPartialXtDirect, (x, v) -> { + MatrixBlock xBlock = (MatrixBlock) x.getValue(); + MatrixBlock vBlock = (MatrixBlock) v.getValue().getValue(); + MatrixBlock xv = xBlock.aggregateBinaryOperations(xBlock, vBlock, new MatrixBlock(), mmOp); + MatrixBlock partial = multTransposeVector(xBlock, xv); + return new IndexedMatrixValue(new MatrixIndexes(x.getIndexes().getColumnIndex(), 1L), partial); + }, tmp -> tmp.getIndexes().getColumnIndex(), tmp -> tmp.getIndexes().getRowIndex()); + uFuture = mapXvFuture; + qU = null; + } + else { + OOCStream qPartialXv = createWritableStream(); + OOCStream qXv = createWritableStream(); + OOCStream qInXv = xCache.getReadStream(); - CompletableFuture mapXvFuture = broadcastJoinOOC(qInXv, qV, qPartialXv, (x, v) -> { - MatrixBlock xBlock = (MatrixBlock) x.getValue(); - MatrixBlock vBlock = (MatrixBlock) v.getValue().getValue(); - MatrixBlock partial = xBlock.aggregateBinaryOperations(xBlock, vBlock, new MatrixBlock(), mmOp); - return new IndexedMatrixValue(new MatrixIndexes(x.getIndexes().getRowIndex(), 1L), partial); - }, tmp -> tmp.getIndexes().getColumnIndex(), tmp -> tmp.getIndexes().getRowIndex()); + mapXvFuture = broadcastJoinOOC(qInXv, qV, qPartialXv, (x, v) -> { + MatrixBlock xBlock = (MatrixBlock) x.getValue(); + MatrixBlock vBlock = (MatrixBlock) v.getValue().getValue(); + MatrixBlock partial = xBlock.aggregateBinaryOperations(xBlock, vBlock, new MatrixBlock(), mmOp); + return new IndexedMatrixValue(new MatrixIndexes(x.getIndexes().getRowIndex(), 1L), partial); + }, tmp -> tmp.getIndexes().getColumnIndex(), tmp -> tmp.getIndexes().getRowIndex()); - CompletableFuture reduceXvFuture = groupedReduceOOC(qPartialXv, qXv, (left, right) -> { - MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); - left.setValue(mb); - return left; - }, numColBlocks); + CompletableFuture reduceXvFuture = groupedReduceOOC(qPartialXv, qXv, (left, right) -> { + MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + left.setValue(mb); + return left; + }, numColBlocks); - if(_type.isWeighted()) { - MatrixObject mw = ec.getMatrixObject(input3); - OOCStream qW = mw.getStreamHandle(); - OOCStream qWeighted = createWritableStream(); - BinaryOperator weightOp = InstructionUtils.parseBinaryOperator( - _type == ChainType.XtwXv ? Opcodes.MULT.toString() : Opcodes.MINUS.toString()); + if(_type.isWeighted()) { + MatrixObject mw = ec.getMatrixObject(input3); + OOCStream qW = mw.getStreamHandle(); + OOCStream qWeighted = createWritableStream(); + BinaryOperator weightOp = InstructionUtils.parseBinaryOperator( + _type == ChainType.XtwXv ? Opcodes.MULT.toString() : Opcodes.MINUS.toString()); - uFuture = broadcastJoinOOC(qXv, qW, qWeighted, (u, w) -> { - MatrixBlock uBlock = (MatrixBlock) u.getValue(); - MatrixBlock wBlock = (MatrixBlock) w.getValue().getValue(); - MatrixBlock updated = uBlock.binaryOperationsInPlace(weightOp, wBlock); - u.setValue(updated); - return u; - }, tmp -> tmp.getIndexes().getRowIndex(), tmp -> tmp.getIndexes().getRowIndex()); - qU = qWeighted; - } - else { - uFuture = reduceXvFuture; - qU = qXv; + uFuture = broadcastJoinOOC(qXv, qW, qWeighted, (u, w) -> { + MatrixBlock uBlock = (MatrixBlock) u.getValue(); + MatrixBlock wBlock = (MatrixBlock) w.getValue().getValue(); + MatrixBlock updated = uBlock.binaryOperationsInPlace(weightOp, wBlock); + u.setValue(updated); + return u; + }, tmp -> tmp.getIndexes().getRowIndex(), tmp -> tmp.getIndexes().getRowIndex()); + qU = qWeighted; + } + else { + uFuture = reduceXvFuture; + qU = qXv; + } } - mapXvFuture.exceptionally(err -> { - qOut.propagateFailure(DMLRuntimeException.of(err)); - return null; - }); } - OOCStream qInXt = xCache.getReadStream(); - OOCStream qPartialXt = createWritableStream(); - CompletableFuture joinXtFuture = broadcastJoinOOC(qInXt, qU, qPartialXt, (x, u) -> { - MatrixBlock xBlock = (MatrixBlock) x.getValue(); - MatrixBlock uBlock = (MatrixBlock) u.getValue().getValue(); - MatrixBlock partial = multTransposeVector(xBlock, uBlock); - return new IndexedMatrixValue(new MatrixIndexes(x.getIndexes().getColumnIndex(), 1L), partial); - }, tmp -> tmp.getIndexes().getRowIndex(), tmp -> tmp.getIndexes().getRowIndex()); + OOCStream qPartialXtOut; + CompletableFuture joinXtFuture; + if(directXtFromXv) { + joinXtFuture = CompletableFuture.completedFuture(null); + qPartialXtOut = qPartialXtDirect; + } + else { + OOCStream qInXt = xCache.getReadStream(); + OOCStream qPartialXt = createWritableStream(); + joinXtFuture = broadcastJoinOOC(qInXt, qU, qPartialXt, (x, u) -> { + MatrixBlock xBlock = (MatrixBlock) x.getValue(); + MatrixBlock uBlock = (MatrixBlock) u.getValue().getValue(); + MatrixBlock partial = multTransposeVector(xBlock, uBlock); + return new IndexedMatrixValue(new MatrixIndexes(x.getIndexes().getColumnIndex(), 1L), partial); + }, tmp -> tmp.getIndexes().getRowIndex(), tmp -> tmp.getIndexes().getRowIndex()); + qPartialXtOut = qPartialXt; + } - CompletableFuture outFuture = groupedReduceOOC(qPartialXt, qOut, (left, right) -> { + CompletableFuture outFuture = groupedReduceOOC(qPartialXtOut, qOut, (left, right) -> { MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); left.setValue(mb); return left; }, numRowBlocks); + final boolean deleteXCache = createdCache; outFuture.whenComplete((res, err) -> { - if(createdCache) + if(deleteXCache) xCache.scheduleDeletion(); }); + if(mapXvFuture != null) { + mapXvFuture.exceptionally(err -> { + qOut.propagateFailure(DMLRuntimeException.of(err)); + return null; + }); + } uFuture.exceptionally(err -> { qOut.propagateFailure(DMLRuntimeException.of(err)); return null; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index 9ce4c0eb9c4..f7cefe635df 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -35,6 +35,8 @@ import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; import org.apache.sysds.runtime.ooc.stats.OOCEventLog; import org.apache.sysds.runtime.ooc.stream.FilteredOOCStream; +import org.apache.sysds.runtime.ooc.stream.MergedOOCStream; +import org.apache.sysds.runtime.ooc.stream.SplittingOOCStream; import org.apache.sysds.runtime.ooc.stream.StreamContext; import org.apache.sysds.runtime.ooc.stream.TaskContext; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -63,6 +65,7 @@ import java.util.function.Function; public abstract class OOCInstruction extends Instruction { + public static final boolean ALLOW_PIPELINING = true; public static final ExecutorService COMPUTE_EXECUTOR = CommonThreadPool.get(); private static final AtomicInteger COMPUTE_IN_FLIGHT = new AtomicInteger(0); private static final int COMPUTE_BACKPRESSURE_THRESHOLD = 100; @@ -184,6 +187,23 @@ protected OOCStream filteredOOCStream(OOCStream qIn, Function(qIn, predicate); } + protected OOCStream mergeOOCStreams(List> streams) { + return new MergedOOCStream<>(streams); + } + + protected OOCStream mergeOOCStreams(OOCStream... streams) { + return new MergedOOCStream<>(streams); + } + + protected List> splitOOCStream(OOCStream source, Function partitionFunc, + int numPartitions) { + SplittingOOCStream split = new SplittingOOCStream<>(source, partitionFunc, numPartitions); + List> out = new ArrayList<>(numPartitions); + for(int i = 0; i < numPartitions; i++) + out.add(split.getSubStream(i)); + return out; + } + protected CompletableFuture mapOOC(OOCStream qIn, OOCStream qOut, Function mapper) { return mapOptionalOOC(qIn, qOut, tmp -> Optional.of(mapper.apply(tmp))); } @@ -216,7 +236,7 @@ protected CompletableFuture mapOptionalOOC(OOCStream qIn, OOCStr submitOOCTasks(qIn, exec, tmp -> { // Try to run as a predicate to prefer pipelining rather than fan-out - if(ForkJoinTask.getPool() == COMPUTE_EXECUTOR) { + if(ALLOW_PIPELINING && ForkJoinTask.getPool() == COMPUTE_EXECUTOR && TaskContext.canPipe()) { exec.accept(tmp); return false; } @@ -246,8 +266,10 @@ protected CompletableFuture broadcastJoinOOC(OOCStream CompletableFuture broadcastJoinOOC(OOCStream, OOCStream.QueueCallback, BroadcastedElement>> broadcastingQueue = createWritableStream(); AtomicInteger waitCtr = new AtomicInteger(1); Object lock = new Object(); + OOCStream leftReadStream = leftCached ? qIn : leftCache.getReadStream(); + OOCStream rightReadStream = rightCached ? broadcast : rightCache.getReadStream(); - CompletableFuture fut1 = submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), (i, tmp) -> { + CompletableFuture fut1 = submitOOCTasks(List.of(leftReadStream, rightReadStream), (i, tmp) -> { try(tmp) { P key = i == 0 ? onLeft.apply(tmp.get()) : onRight.apply(tmp.get()); @@ -354,15 +378,17 @@ protected CompletableFuture broadcastJoinOOC(OOCStream fut = CompletableFuture.allOf(fut1, fut2); - fut.whenComplete((res, t) -> { + final StreamContext context = _streamContext.copy(); + return fut.thenRun(() -> { availableBroadcastInput.forEach((k, v) -> { rightCache.incrProcessingCount(rightCache.findCachedIndex(v.idx), 1); }); availableBroadcastInput.clear(); qOut.closeInput(); + }).exceptionally(t -> { + context.failAll(DMLRuntimeException.of(t)); + return null; }); - - return fut; } protected CompletableFuture joinManyOOC(OOCStream left, @@ -372,8 +398,10 @@ protected CompletableFuture joinManyOOC(OOCStream CompletableFuture joinManyOOC(OOCStream, OOCStream.QueueCallback, BroadcastedElement, BroadcastedElement>> joinQueue = createWritableStream(); AtomicInteger waitCtr = new AtomicInteger(1); + OOCStream leftReadStream = leftCached ? left : leftCache.getReadStream(); + OOCStream rightReadStream = rightCached ? right : rightCache.getReadStream(); - CompletableFuture fut1 = submitOOCTasks(List.of(leftCache.getReadStream(), rightCache.getReadStream()), + CompletableFuture fut1 = submitOOCTasks(List.of(leftReadStream, rightReadStream), (i, tmp) -> { try(tmp) { boolean leftItem = i == 0; @@ -395,31 +425,36 @@ protected CompletableFuture joinManyOOC(OOCStream matches = leftItem ? tuple._2 : tuple._1; List toInsert = leftItem ? tuple._1 : tuple._2; + int matchesSize; boolean remove; synchronized(tuple) { toInsert.add(b); - - for(BroadcastedElement e : matches) { - waitCtr.incrementAndGet(); - OOCCacheManager.requestManyBlocks( - List.of(leftCache.peekCachedBlockKey(leftItem ? b.idx : e.idx), - rightCache.peekCachedBlockKey(leftItem ? e.idx : b.idx))).thenApply(joined -> { - try { - joinQueue.enqueue( - new Tuple5<>(key, joined.get(0).keepOpen(), joined.get(1).keepOpen(), - leftItem ? b : e, leftItem ? e : b)); - } - finally { - joined.forEach(OOCStream.QueueCallback::close); - } - return null; - }).exceptionally(t -> { - joinQueue.propagateFailure(DMLRuntimeException.of(t)); - return null; - }); - } + matchesSize = matches.size(); + waitCtr.addAndGet(matchesSize); remove = tuple._1.size() == releaseRightCount && tuple._2.size() == releaseLeftCount; } + + // We have the guarantee that matches is append only so we don't need to synchronize for this + for(int mIdx = 0; mIdx < matchesSize; mIdx++) { + BroadcastedElement e = matches.get(mIdx); + OOCCacheManager.requestManyBlocks( + List.of(leftCache.peekCachedBlockKey(leftItem ? b.idx : e.idx), + rightCache.peekCachedBlockKey(leftItem ? e.idx : b.idx))).thenApply(joined -> { + try { + joinQueue.enqueue( + new Tuple5<>(key, joined.get(0).keepOpen(), joined.get(1).keepOpen(), + leftItem ? b : e, leftItem ? e : b)); + } + finally { + joined.forEach(OOCStream.QueueCallback::close); + } + return null; + }).exceptionally(t -> { + joinQueue.propagateFailure(DMLRuntimeException.of(t)); + return null; + }); + } + if(remove) joinMap.remove(key); } @@ -766,7 +801,7 @@ protected CompletableFuture pipeOOC(OOCStream queue, Consumer CompletableFuture pipeOOC(List> queues, BiConsumer> consumer) { return submitOOCTasks(queues, consumer, (i, tmp) -> { // Try to run as a predicate to prefer pipelining rather than fan-out - if(ForkJoinTask.getPool() == COMPUTE_EXECUTOR) { + if(ALLOW_PIPELINING && ForkJoinTask.getPool() == COMPUTE_EXECUTOR && TaskContext.canPipe()) { consumer.accept(i, tmp); return false; } @@ -821,53 +856,114 @@ protected CompletableFuture submitOOCTasks(final List> qu return; } - if(predicate != null && !predicate.apply(k, callback)) { // Can get closed due to cancellation - if(onNotProcessed != null) - onNotProcessed.accept(k, callback); - return; - } + Consumer> process = cb -> { + if(predicate != null && !predicate.apply(k, cb)) { // Can get closed due to cancellation + if(onNotProcessed != null) + onNotProcessed.accept(k, cb); + return; + } - if(localFuture.isDone()) { - if(onNotProcessed != null) - onNotProcessed.accept(k, callback); - return; - } - else { - localTaskCtr.incrementAndGet(); - } + if(localFuture.isDone()) { + if(onNotProcessed != null) + onNotProcessed.accept(k, cb); + return; + } + else { + localTaskCtr.incrementAndGet(); + } - // The item needs to be pinned in memory to be accessible in the executor thread - final OOCStream.QueueCallback pinned = callback.keepOpen(); + // The item needs to be pinned in memory to be accessible in the executor thread + final OOCStream.QueueCallback pinned = cb.keepOpen(); - COMPUTE_IN_FLIGHT.incrementAndGet(); - try { - Runnable oocTask = oocTask(() -> { - long taskStartTime = DMLScript.STATISTICS ? System.nanoTime() : 0; - try(pinned) { - consumer.accept(k, pinned); + COMPUTE_IN_FLIGHT.incrementAndGet(); + try { + Runnable oocTask = oocTask(() -> { + long taskStartTime = DMLScript.STATISTICS ? System.nanoTime() : 0; + try(pinned) { + consumer.accept(k, pinned); - if(localTaskCtr.decrementAndGet() == 0) { - TaskContext.defer(() -> localFuture.complete(null)); + if(localTaskCtr.decrementAndGet() == 0) { + TaskContext.defer(() -> localFuture.complete(null)); + } } - } - finally { - COMPUTE_IN_FLIGHT.decrementAndGet(); - if (DMLScript.STATISTICS) { - _localStatisticsAdder.add(System.nanoTime() - taskStartTime); - if (globalFuture.isDone()) { - Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); - _localStatisticsAdder.reset(); + finally { + COMPUTE_IN_FLIGHT.decrementAndGet(); + if (DMLScript.STATISTICS) { + _localStatisticsAdder.add(System.nanoTime() - taskStartTime); + if (globalFuture.isDone()) { + Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); + _localStatisticsAdder.reset(); + } + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); } - if (DMLScript.OOC_LOG_EVENTS) - OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); + } + }, localFuture, streamContext); + COMPUTE_EXECUTOR.submit(oocTask); + } + catch (Exception e) { + COMPUTE_IN_FLIGHT.decrementAndGet(); + throw e; + } + }; + + if(callback instanceof OOCStream.GroupQueueCallback) { + @SuppressWarnings("unchecked") + OOCStream.GroupQueueCallback group = (OOCStream.GroupQueueCallback) callback; + + if(localFuture.isDone()) { + for(int idx = 0; idx < group.size(); idx++) { + OOCStream.QueueCallback sub = group.getCallback(idx); + try(sub) { + if(onNotProcessed != null) + onNotProcessed.accept(k, sub); } } - }, localFuture, streamContext); - COMPUTE_EXECUTOR.submit(oocTask); + return; + } + + localTaskCtr.incrementAndGet(); + final OOCStream.GroupQueueCallback pinnedGroup = + (OOCStream.GroupQueueCallback) group.keepOpen(); + + COMPUTE_IN_FLIGHT.incrementAndGet(); + try { + Runnable oocTask = oocTask(() -> { + long taskStartTime = DMLScript.STATISTICS ? System.nanoTime() : 0; + try(pinnedGroup) { + for(int idx = 0; idx < pinnedGroup.size(); idx++) { + OOCStream.QueueCallback sub = pinnedGroup.getCallback(idx); + try(sub) { + process.accept(sub); + } + } + + if(localTaskCtr.decrementAndGet() == 0) { + TaskContext.defer(() -> localFuture.complete(null)); + } + } + finally { + COMPUTE_IN_FLIGHT.decrementAndGet(); + if (DMLScript.STATISTICS) { + _localStatisticsAdder.add(System.nanoTime() - taskStartTime); + if (globalFuture.isDone()) { + Statistics.maintainOOCHeavyHitter(getExtendedOpcode(), _localStatisticsAdder.sum()); + _localStatisticsAdder.reset(); + } + if (DMLScript.OOC_LOG_EVENTS) + OOCEventLog.onComputeEvent(_callerId, taskStartTime, System.nanoTime()); + } + } + }, localFuture, streamContext); + COMPUTE_EXECUTOR.submit(oocTask); + } + catch (Exception e) { + COMPUTE_IN_FLIGHT.decrementAndGet(); + throw e; + } } - catch (Exception e) { - COMPUTE_IN_FLIGHT.decrementAndGet(); - throw e; + else { + process.accept(callback); } if(closeRaceWatchdog.get()) // Sanity check diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java index 7ee12e9f025..ce53e5f0949 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStream.java @@ -36,10 +36,6 @@ static QueueCallback eos(DMLRuntimeException e) { void propagateFailure(DMLRuntimeException re); - boolean hasStreamCache(); - - CachingStream getStreamCache(); - /** * Registers a new subscriber that consumes the stream. * While there is no guarantee for any specific order, the closing item LocalTaskQueue.NO_MORE_TASKS @@ -65,6 +61,12 @@ interface QueueCallback extends AutoCloseable { boolean isFailure(); } + interface GroupQueueCallback extends QueueCallback { + int size(); + + QueueCallback getCallback(int idx); + } + class SimpleQueueCallback implements QueueCallback { private final T _result; private DMLRuntimeException _failure; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java index 26fd227e86a..4f212f544b2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCStreamable.java @@ -32,6 +32,10 @@ public interface OOCStreamable { OOCStream getWriteStream(); + boolean hasStreamCache(); + + CachingStream getStreamCache(); + boolean isProcessed(); DataCharacteristics getDataCharacteristics(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java index 058a61c208c..e5c48decdd1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/SubscribableTaskQueue.java @@ -111,6 +111,21 @@ public void enqueue(T t) { onDeliveryFinished(); } + protected boolean tryDeliverCallback(QueueCallback cb, int blockCount) { + Consumer> s = _subscriber; + if (s == null) + return false; + int cnt = _availableCtr.incrementAndGet(); + if (cnt <= 1) { // Then the queue was already closed and we disallow further enqueues + _availableCtr.decrementAndGet(); // Undo increment + throw new DMLRuntimeException("Cannot enqueue into closed SubscribableTaskQueue"); + } + _blockCount.addAndGet(blockCount); + s.accept(cb); + onDeliveryFinished(); + return true; + } + @Override public synchronized void enqueueTask(T t) { enqueue(t); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java index aba36297e7f..493aba06c72 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/TeeOOCInstruction.java @@ -42,10 +42,11 @@ public static void reset() { * Increments the reference counter of a stream by the set amount. */ public static void incrRef(OOCStreamable stream, int incr) { - if (!(stream instanceof CachingStream)) + if (!stream.hasStreamCache()) return; + CachingStream cache = stream.getStreamCache(); - Integer ref = refCtr.compute((CachingStream)stream, (k, v) -> { + Integer ref = refCtr.compute(cache, (k, v) -> { if (v == null) v = 0; v += incr; @@ -53,7 +54,7 @@ public static void incrRef(OOCStreamable stream, int incr) { }); if (ref == null) - ((CachingStream)stream).scheduleDeletion(); + cache.scheduleDeletion(); } protected TeeOOCInstruction(OOCType type, CPOperand in1, CPOperand out, String opcode, String istr) { @@ -69,21 +70,22 @@ public static TeeOOCInstruction parseInstruction(String str) { return new TeeOOCInstruction(OOCType.Tee, in1, out, opcode, str); } - public void processInstruction( ExecutionContext ec ) { + public void processInstruction(ExecutionContext ec) { //get input stream MatrixObject min = ec.getMatrixObject(input1); - OOCStream qIn = min.getStreamHandle(); + OOCStreamable streamable = min.getStreamable(); + CachingStream handle; - CachingStream handle = qIn.hasStreamCache() ? qIn.getStreamCache() : new CachingStream(qIn); - - if (!qIn.hasStreamCache()) { + if(streamable.hasStreamCache()) { + handle = streamable.getStreamCache(); + incrRef(handle, 1); + } + else { // We also set the input stream handle + handle = new CachingStream(min.getStreamHandle()); min.setStreamHandle(handle); incrRef(handle, 2); } - else { - incrRef(handle, 1); - } //get output and create new resettable stream MatrixObject mo = ec.getMatrixObject(output); diff --git a/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java b/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java index 8681a91c7e0..2f3c31527c9 100644 --- a/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java +++ b/src/main/java/org/apache/sysds/runtime/io/MatrixWriter.java @@ -46,8 +46,8 @@ public abstract void writeMatrixToHDFS( MatrixBlock src, String fname, long rlen throws IOException; /** - * Consumes an out-of-core stream of matrix blocks and writes them to a single file. - * This method must be implemented by writers that support OOC streaming output. + * Consumes an out-of-core stream of matrix blocks and writes them to the target output path. + * Implementations may choose single-file or multipart output depending on format and parallelism. * * @param fname The target output filename * @param stream The OOC stream of matrix blocks to consume diff --git a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java index fe671cc226b..c00e58b7fac 100644 --- a/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java +++ b/src/main/java/org/apache/sysds/runtime/io/WriterBinaryBlockParallel.java @@ -23,14 +23,22 @@ import java.util.ArrayList; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.parfor.LocalTaskQueue; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.SplittingOOCStream; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; @@ -83,6 +91,62 @@ protected void writeBinaryBlockMatrixToHDFS( Path path, JobConf job, MatrixBlock } } + @Override + public long writeMatrixFromStream(String fname, OOCStream stream, long rlen, long clen, int blen) + throws IOException { + Path path = new Path(fname); + long nnz = -1; + DataCharacteristics dc = stream.getDataCharacteristics(); + if(dc != null) + nnz = dc.getNonZeros(); + if(nnz < 0 && rlen > 0 && clen > 0) { + if(rlen > Long.MAX_VALUE / clen) + nnz = Long.MAX_VALUE - 1; + else + nnz = rlen * clen; + } + else if(nnz < 0) + nnz = 0; + + int numPartFiles = numPartsFiles(path.getFileSystem(job), rlen, clen, blen, nnz); + int numThreads = OptimizerUtils.getParallelBinaryWriteParallelism(); + numThreads = Math.min(numThreads, numPartFiles); + + // fall back to sequential write if dop is 1 in order to create a single file + if(numThreads <= 1) + return super.writeMatrixFromStream(fname, stream, rlen, clen, blen); + + // Match CP parallel writer partitioning by contiguous row ranges. + final int parallelism = numThreads; + final int blklen = (int) Math.ceil((double) rlen / blen / parallelism) * blen; + SplittingOOCStream split = new SplittingOOCStream<>(stream, iVal -> { + int partition = (int) (((iVal.getIndexes().getRowIndex() - 1) * blen) / (long) blklen); + return Math.max(0, Math.min(partition, parallelism - 1)); + }, parallelism); + + final ExecutorService pool = Executors.newFixedThreadPool(parallelism); + try { + ArrayList tasks = new ArrayList<>(); + for(int i = 0; i < parallelism && i * (long) blklen < rlen; i++) { + Path newPath = new Path(path, IOUtilFunctions.getPartFileName(i)); + tasks.add(new WriteStreamTask(newPath, job, split.getSubStream(i))); + } + + long totalNnz = 0; + for(Future task : pool.invokeAll(tasks)) + totalNnz += task.get(); + return totalNnz; + } + catch(Exception e) { + DMLRuntimeException ex = DMLRuntimeException.of(e); + split.propagateFailure(ex); + throw ex; + } + finally { + pool.shutdown(); + } + } + public static int numPartsFiles(FileSystem fs, long rlen, long clen, long blen, long nZeros) { int numPartFiles = (int) (OptimizerUtils.estimatePartitionedSizeExactSparsity(rlen, clen, blen, nZeros) / InfrastructureAnalyzer.getBlockSize(fs)); @@ -117,4 +181,37 @@ public Object call() throws Exception { return null; } } + + private class WriteStreamTask implements Callable { + private final Path _path; + private final JobConf _job; + private final OOCStream _stream; + + public WriteStreamTask(Path path, JobConf job, OOCStream stream) { + _path = path; + _job = job; + _stream = stream; + } + + @Override + public Long call() throws Exception { + SequenceFile.Writer writer = null; + long totalNnz = 0; + try { + writer = IOUtilFunctions.getSeqWriter(_path, _job, _replication); + IndexedMatrixValue i_val; + while((i_val = _stream.dequeue()) != LocalTaskQueue.NO_MORE_TASKS) { + MatrixBlock mb = (MatrixBlock) i_val.getValue(); + MatrixIndexes ix = i_val.getIndexes(); + writer.append(ix, mb); + totalNnz += mb.getNonZeros(); + } + } + finally { + IOUtilFunctions.closeSilently(writer); + } + IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(_job, _path); + return totalNnz; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java index b5da05a598d..901e043d985 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/BlockEntry.java @@ -21,6 +21,8 @@ import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import java.util.List; + public final class BlockEntry { private final BlockKey _key; private final long _size; @@ -52,11 +54,25 @@ public Object getData() { throw new IllegalStateException("Cannot get the data of an unpinned entry"); } + public int getGroupSize() { + if(_pinCount > 0) + return ((List)_data).size(); + throw new IllegalStateException("Cannot get the data of an unpinned entry"); + } + + public boolean isGrouped() { + if(_pinCount > 0) + return _data instanceof List; + throw new IllegalStateException("Cannot get the data of an unpinned entry"); + } + Object getDataUnsafe() { return _data; } void setDataUnsafe(Object data) { + if(data != null && _data != null) + throw new IllegalStateException("Cannot overwrite data"); _data = data; } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/GroupedBlockKey.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/GroupedBlockKey.java new file mode 100644 index 00000000000..37789e02f20 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/GroupedBlockKey.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.cache; + +public class GroupedBlockKey extends BlockKey { + private final int _groupIndex; + + public GroupedBlockKey(long streamId, int blockId, int groupIndex) { + super(streamId, blockId); + _groupIndex = groupIndex; + } + + public int getGroupIndex() { + return _groupIndex; + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java index ef6824022cc..7f4f6be28ac 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java @@ -34,6 +34,7 @@ import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; public class OOCCacheManager { @@ -135,11 +136,26 @@ public static void putSourceBacked(long streamId, int blockId, IndexedMatrixValu getCache().putSourceBacked(key, value, ((MatrixBlock) value.getValue()).getExactSerializedSize(), descriptor); } + public static void putRawSourceBacked(BlockKey key, Object data, long size, OOCIOHandler.SourceBlockDescriptor descriptor) { + getCache().putSourceBacked(key, data, size, descriptor); + } + public static OOCStream.QueueCallback putAndPin(long streamId, int blockId, IndexedMatrixValue value) { BlockKey key = new BlockKey(streamId, blockId); return new CachedQueueCallback<>(getCache().putAndPin(key, value, ((MatrixBlock)value.getValue()).getExactSerializedSize()), null); } + public static void putRaw(BlockKey key, Object data, long size) { + getCache().put(key, data, size); + } + + public static OOCStream.QueueCallback putAndPinRaw(BlockKey key, Object data, long size) { + BlockEntry entry = getCache().putAndPin(key, data, size); + if (data instanceof List) + return new CachedGroupCallback<>(entry, null); + return new CachedQueueCallback<>(entry, null); + } + public static OOCStream.QueueCallback putAndPinSourceBacked(long streamId, int blockId, IndexedMatrixValue value, OOCIOHandler.SourceBlockDescriptor descriptor) { BlockKey key = new BlockKey(streamId, blockId); @@ -148,35 +164,74 @@ public static OOCStream.QueueCallback putAndPinSourceBacked( descriptor), null); } + public static OOCStream.QueueCallback putAndPinRawSourceBacked(BlockKey key, Object data, long size, + OOCIOHandler.SourceBlockDescriptor descriptor) { + BlockEntry entry = getCache().putAndPinSourceBacked(key, data, size, descriptor); + if (data instanceof List) + return new CachedGroupCallback<>(entry, null); + return new CachedQueueCallback<>(entry, null); + } + public static void prioritize(BlockKey key, int priority) { getCache().prioritize(key, priority); } public static CompletableFuture> requestBlock(long streamId, long blockId) { - BlockKey key = new BlockKey(streamId, blockId); - return getCache().request(key).thenApply(e -> new CachedQueueCallback<>(e, null)); + return requestBlock(new BlockKey(streamId, (int)blockId)); + } + + public static CompletableFuture> requestBlock(BlockKey key) { + return getCache().request(key).thenApply(e -> toCallback(e, key, null)); } public static CompletableFuture>> requestManyBlocks(List keys) { return getCache().request(keys).thenApply( - l -> l.stream().map(e -> (OOCStream.QueueCallback)new CachedQueueCallback(e, null)).toList()); + l -> { + List> out = new java.util.ArrayList<>(l.size()); + for (int i = 0; i < l.size(); i++) + out.add(toCallback(l.get(i), keys.get(i), null)); + return out; + }); } public static List> tryRequestManyBlocks(List keys) { List entries = getCache().tryRequest(keys); if(entries == null) return null; - return entries.stream().map(e -> (OOCStream.QueueCallback)new CachedQueueCallback(e, null)).toList(); + List> out = new java.util.ArrayList<>(entries.size()); + for (int i = 0; i < entries.size(); i++) + out.add(toCallback(entries.get(i), keys.get(i), null)); + return out; } public static CompletableFuture>> requestAnyOf(List keys, int n, List sel) { return getCache().requestAnyOf(keys, n, sel) .thenApply( - l -> l.stream().map(e -> (OOCStream.QueueCallback)new CachedQueueCallback(e, null)).toList()); + l -> { + List> out = new java.util.ArrayList<>(l.size()); + for (int i = 0; i < l.size(); i++) { + BlockKey key = sel.size() == l.size() ? sel.get(i) : keys.get(i); + out.add(toCallback(l.get(i), key, null)); + } + return out; + }); + } + + private static OOCStream.QueueCallback toCallback(BlockEntry entry, BlockKey key, DMLRuntimeException failure) { + if (entry.getData() instanceof java.util.List) { + CachedGroupCallback group = new CachedGroupCallback<>(entry, failure); + if (key instanceof GroupedBlockKey gk) { + OOCStream.QueueCallback sub = group.getCallback(gk.getGroupIndex()); + group.close(); // drop the group-level pin, sub keeps it pinned + return sub; + } + return group; + } + return new CachedQueueCallback<>(entry, failure); } public static boolean canClaimMemory() { - return getCache().isWithinSoftLimits() && OOCInstruction.getComputeInFlight() <= OOCInstruction.getComputeBackpressureThreshold(); + return getCache().isWithinLimits() && OOCInstruction.getComputeInFlight() <= OOCInstruction.getComputeBackpressureThreshold(); } private static void pin(BlockEntry entry) { @@ -190,12 +245,11 @@ private static void unpin(BlockEntry entry) { - static class CachedQueueCallback implements OOCStream.QueueCallback { + public static class CachedQueueCallback implements OOCStream.QueueCallback { private final BlockEntry _result; private final AtomicBoolean _pinned; private T _data; private DMLRuntimeException _failure; - private CompletableFuture _future; @SuppressWarnings("unchecked") CachedQueueCallback(BlockEntry result, DMLRuntimeException failure) { @@ -205,7 +259,6 @@ static class CachedQueueCallback implements OOCStream.QueueCallback { this._pinned = new AtomicBoolean(true); } - @SuppressWarnings("unchecked") @Override public T get() { if(_failure != null) @@ -243,9 +296,146 @@ public void close() { if(_pinned.compareAndSet(true, false)) { _data = null; unpin(_result); - if(_future != null) - _future.complete(null); } } + + public BlockKey getBlockKey() { + return _result.getKey(); + } + } + + public static class CachedSubCallback implements OOCStream.QueueCallback { + private final CachedGroupCallback _parent; + private final AtomicBoolean _pinned; + private T _data; + private final int _groupIndex; + + CachedSubCallback(CachedGroupCallback parent, T data, int groupIndex) { + _parent = parent; + _data = data; + _groupIndex = groupIndex; + _pinned = new AtomicBoolean(true); + } + + @Override + public T get() { + if(_parent.isFailure()) + throw _parent._failure; + return _data; + } + + @Override + public OOCStream.QueueCallback keepOpen() { + _parent.registerQueueCallback(); + return new CachedSubCallback<>(_parent, _data, _groupIndex); + } + + @Override + public void close() { + if(_pinned.compareAndSet(true, false)) { + _data = null; + _parent.close(); + } + } + + @Override + public void fail(DMLRuntimeException failure) { + _parent.fail(failure); + } + + @Override + public boolean isEos() { + return false; + } + + @Override + public boolean isFailure() { + return _parent.isFailure(); + } + + public CachedGroupCallback getParent() { + return _parent; + } + + public int getGroupIndex() { + return _groupIndex; + } + } + + public static class CachedGroupCallback implements OOCStream.GroupQueueCallback { + private final BlockEntry _result; + private final AtomicInteger _pinCounter; + private List _data; + private DMLRuntimeException _failure; + + @SuppressWarnings("unchecked") + CachedGroupCallback(BlockEntry result, DMLRuntimeException failure) { + this._result = result; + this._data = (List)result.getData(); + this._failure = failure; + this._pinCounter = new AtomicInteger(1); + } + + public OOCStream.QueueCallback getCallback(int idx) { + if(_pinCounter.get() <= 0) + throw new IllegalStateException("Cannot open sub-callback on a closed GroupCallback"); + registerQueueCallback(); + return new CachedSubCallback<>(this, _data.get(idx), idx); + } + + public void registerQueueCallback() { + if(_pinCounter.incrementAndGet() <= 1) + throw new IllegalStateException(); + } + + @Override + public T get() { + throw new UnsupportedOperationException(); + } + + @Override + public int size() { + return _data.size(); + } + + public T get(int idx) { + return _data.get(idx); + } + + @Override + public OOCStream.QueueCallback keepOpen() { + if(_pinCounter.get() <= 0) + throw new IllegalStateException("Cannot keep open an already closed callback"); + pin(_result); + return new CachedGroupCallback<>(_result, _failure); + } + + @Override + public void close() { + int cnt = _pinCounter.decrementAndGet(); + if(cnt == 0) { + _data = null; + unpin(_result); + } + } + + @Override + public void fail(DMLRuntimeException failure) { + _failure = failure; + } + + @Override + public boolean isEos() { + return false; + } + + @Override + public boolean isFailure() { + return _failure != null; + } + + public BlockKey getBlockKey() { + return _result.getKey(); + } } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java index 0699597c8b7..699a4f65493 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java @@ -116,4 +116,17 @@ public SourceBlockDescriptor(String path, org.apache.sysds.common.Types.FileForm this.serializedSize = serializedSize; } } + + class GroupSourceBlockDescriptor extends SourceBlockDescriptor { + public final List blocks; + public final int count; + + public GroupSourceBlockDescriptor(String path, org.apache.sysds.common.Types.FileFormat format, + org.apache.sysds.runtime.matrix.data.MatrixIndexes indexes, long offset, int recordLength, + long serializedSize, List blocks) { + super(path, format, indexes, offset, recordLength, serializedSize); + this.blocks = blocks; + this.count = blocks.size(); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java index 00da1681813..813dcd1d804 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCLRUCacheScheduler.java @@ -390,12 +390,15 @@ public boolean isWithinLimits() { @Override public boolean isWithinSoftLimits() { - return _cacheSize < _evictionLimit; + return _cacheSize < (_evictionLimit + _hardLimit) / 2; } @Override public synchronized void shutdown() { this._running = false; + if(!_cache.isEmpty() || !_evictionCache.isEmpty()) { + System.out.println("[WARN] Cache still holds " + _cache.size() + " / " + _evictionCache.size() + " blocks"); + } _cache.clear(); _evictionCache.clear(); _processingReadRequests.clear(); @@ -624,7 +627,6 @@ else if(allReserved && reading && req.isComplete()) { } else { LOG.error("Uncaught CacheError", t); - t.printStackTrace(); } return; } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java index ea508274402..aca99ed0966 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCMatrixIOHandler.java @@ -71,6 +71,9 @@ public class OOCMatrixIOHandler implements OOCIOHandler { private static final int READER_SIZE = 10; private static final long OVERFLOW = 8192 * 1024; private static final long MAX_PARTITION_SIZE = 8192 * 8192; + private static final long GROUP_TARGET_BYTES = 8L * 1024 * 1024; + private static final long GROUP_MAX_BYTES = 16L * 1024 * 1024; + private static final int GROUP_MAX_COUNT = 64; private final String _spillDir; private final ThreadPoolExecutor _writeExec; @@ -364,7 +367,12 @@ private void readSequenceFile(JobConf job, Path path, SourceReadRequest request, AtomicLong bytesRead, long byteLimit, Object budgetLock, ConcurrentLinkedDeque descriptors) throws IOException { MatrixIndexes key = new MatrixIndexes(); - MatrixBlock value = new MatrixBlock(); + List groupValues = new ArrayList<>(); + List groupDescs = new ArrayList<>(); + long groupBytes = 0; + long groupSerialized = 0; + long groupStart = -1; + long groupEnd = -1; try(SequenceFile.Reader reader = new SequenceFile.Reader(job, SequenceFile.Reader.file(path))) { long pos = filePositions.get(fileIdx); @@ -374,6 +382,7 @@ private void readSequenceFile(JobConf job, Path path, SourceReadRequest request, long ioStart = DMLScript.OOC_LOG_EVENTS ? System.nanoTime() : 0; while(!stop.get()) { long recordStart = reader.getPosition(); + MatrixBlock value = new MatrixBlock(); if (!reader.next(key, value)) break; long recordEnd = reader.getPosition(); @@ -392,17 +401,54 @@ else if (bytesRead.get() + blockSize > byteLimit) { } MatrixIndexes outIdx = new MatrixIndexes(key); - MatrixBlock outBlk = new MatrixBlock(value); - IndexedMatrixValue imv = new IndexedMatrixValue(outIdx, outBlk); + IndexedMatrixValue imv = new IndexedMatrixValue(outIdx, value); SourceBlockDescriptor descriptor = new SourceBlockDescriptor(path.toString(), request.format, outIdx, recordStart, (int)(recordEnd - recordStart), blockSize); - if (request.target instanceof SourceOOCStream src) - src.enqueue(imv, descriptor); - else - request.target.enqueue(imv); + boolean small = blockSize <= GROUP_TARGET_BYTES; + boolean contiguous = groupValues.isEmpty() || recordStart == groupEnd; + boolean canAdd = small + && contiguous + && groupValues.size() < GROUP_MAX_COUNT + && (groupBytes + (recordEnd - recordStart)) <= GROUP_MAX_BYTES; - descriptors.add(descriptor); + if (!canAdd && !groupValues.isEmpty()) { + flushSourceGroup(request, groupValues, groupDescs, groupStart, groupEnd, groupSerialized, + descriptors); + groupValues.clear(); + groupDescs.clear(); + groupBytes = 0; + groupSerialized = 0; + groupStart = -1; + groupEnd = -1; + } + + if (small) { + if (groupValues.isEmpty()) + groupStart = recordStart; + groupEnd = recordEnd; + groupValues.add(imv); + groupDescs.add(descriptor); + groupBytes += (recordEnd - recordStart); + groupSerialized += blockSize; + if (groupSerialized >= GROUP_TARGET_BYTES || groupBytes >= GROUP_MAX_BYTES || groupValues.size() >= GROUP_MAX_COUNT) { + flushSourceGroup(request, groupValues, groupDescs, groupStart, groupEnd, groupSerialized, + descriptors); + groupValues.clear(); + groupDescs.clear(); + groupBytes = 0; + groupSerialized = 0; + groupStart = -1; + groupEnd = -1; + } + } + else { + if (request.target instanceof SourceOOCStream src) + src.enqueue(imv, descriptor); + else + request.target.enqueue(imv); + descriptors.add(descriptor); + } filePositions.set(fileIdx, reader.getPosition()); if (DMLScript.OOC_LOG_EVENTS) { @@ -415,11 +461,34 @@ else if (bytesRead.get() + blockSize > byteLimit) { break; // Note that we knowingly go over limit, which could result in READER_SIZE*8MB overshoot } + if (!groupValues.isEmpty()) { + flushSourceGroup(request, groupValues, groupDescs, groupStart, groupEnd, groupSerialized, + descriptors); + } + if (!stop.get()) completed.set(fileIdx, 1); } } + private void flushSourceGroup(SourceReadRequest request, List values, + List descs, long start, long end, long totalSerialized, + ConcurrentLinkedDeque descriptors) { + if (values.isEmpty()) + return; + SourceBlockDescriptor first = descs.get(0); + OOCIOHandler.GroupSourceBlockDescriptor group = + new OOCIOHandler.GroupSourceBlockDescriptor(first.path, first.format, first.indexes, start, + (int)(end - start), totalSerialized, new ArrayList<>(descs)); + if (request.target instanceof SourceOOCStream src) + src.enqueueGroup(new ArrayList<>(values), group); + else { + for (IndexedMatrixValue v : values) + request.target.enqueue(v); + } + descriptors.addAll(group.blocks); + } + private void closeTarget(org.apache.sysds.runtime.instructions.ooc.OOCStream target, boolean close) { if(close) { try { @@ -442,6 +511,7 @@ private void loadFromDisk(BlockEntry block) { if (DMLScript.OOC_STATISTICS) { Statistics.incrementOOCLoadFromDisk(); Statistics.accumulateOOCLoadFromDiskTime(System.nanoTime() - ioStart); + Statistics.accumulateOOCLoadFromDiskBytes(block.getSize()); } return; } @@ -481,6 +551,7 @@ private void loadFromDisk(BlockEntry block) { if (DMLScript.OOC_STATISTICS) { Statistics.incrementOOCLoadFromDisk(); Statistics.accumulateOOCLoadFromDiskTime(ioDuration); + Statistics.accumulateOOCLoadFromDiskBytes(block.getSize()); } } @@ -491,19 +562,39 @@ private void loadFromSource(BlockEntry block, SourceBlockDescriptor src) { JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); Path path = new Path(src.path); - MatrixIndexes ix = new MatrixIndexes(); - MatrixBlock mb = new MatrixBlock(); - - try(SequenceFile.Reader reader = new SequenceFile.Reader(job, SequenceFile.Reader.file(path))) { - reader.seek(src.offset); - if (!reader.next(ix, mb)) - throw new DMLRuntimeException("Failed to read source block at offset " + src.offset + " in " + src.path); - } - catch(IOException e) { - throw new DMLRuntimeException(e); + if (src instanceof OOCIOHandler.GroupSourceBlockDescriptor gsrc) { + List values = new ArrayList<>(gsrc.count); + try(SequenceFile.Reader reader = new SequenceFile.Reader(job, SequenceFile.Reader.file(path))) { + reader.seek(gsrc.offset); + for (int i = 0; i < gsrc.blocks.size(); i++) { + SourceBlockDescriptor d = gsrc.blocks.get(i); + MatrixIndexes ix = new MatrixIndexes(); + MatrixBlock mb = new MatrixBlock(); + if (!reader.next(ix, mb)) + throw new DMLRuntimeException("Failed to read source block at offset " + d.offset + " in " + d.path); + values.add(new IndexedMatrixValue(ix, mb)); + } + } + catch(IOException e) { + throw new DMLRuntimeException(e); + } + block.setDataUnsafe(values); } + else { + MatrixIndexes ix = new MatrixIndexes(); + MatrixBlock mb = new MatrixBlock(); - block.setDataUnsafe(new IndexedMatrixValue(ix, mb)); + try(SequenceFile.Reader reader = new SequenceFile.Reader(job, SequenceFile.Reader.file(path))) { + reader.seek(src.offset); + if (!reader.next(ix, mb)) + throw new DMLRuntimeException("Failed to read source block at offset " + src.offset + " in " + src.path); + } + catch(IOException e) { + throw new DMLRuntimeException(e); + } + + block.setDataUnsafe(new IndexedMatrixValue(ix, mb)); + } } private void evictTask(CloseableQueue>> q) { @@ -542,6 +633,7 @@ private void evictTask(CloseableQueue if(DMLScript.OOC_STATISTICS && wrote > 0) { Statistics.incrementOOCEvictionWrite(); Statistics.accumulateOOCEvictionWriteTime(System.nanoTime() - ioStart); + Statistics.accumulateOOCEvictionWriteBytes(wrote); } byteCtr += wrote; diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/FilteredOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/FilteredOOCStream.java index 0af68edd521..a6e978e030f 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stream/FilteredOOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/FilteredOOCStream.java @@ -79,7 +79,28 @@ public CachingStream getStreamCache() { @Override public void setSubscriber(Consumer> subscriber) { _sourceStream.setSubscriber(cb -> { - if(cb.isFailure() || cb.isEos() || _predicate.apply(cb.get())) + if(cb.isFailure() || cb.isEos()) { + subscriber.accept(cb); + return; + } + + if(cb instanceof OOCStream.GroupQueueCallback) { + @SuppressWarnings("unchecked") + OOCStream.GroupQueueCallback group = (OOCStream.GroupQueueCallback) cb; + for(int i = 0; i < group.size(); i++) { + QueueCallback sub = group.getCallback(i); + boolean pass = sub.isFailure() || sub.isEos(); + if(!pass) + pass = _predicate.apply(sub.get()); + if(pass) + subscriber.accept(sub); + else + sub.close(); + } + return; + } + + if(_predicate.apply(cb.get())) subscriber.accept(cb); }); } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java new file mode 100644 index 00000000000..7d0a27932f1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/MergedOOCStream.java @@ -0,0 +1,253 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class MergedOOCStream implements OOCStream { + private final List> _sources; + private final SubscribableTaskQueue> _taskQueue; + private final AtomicInteger _openSources; + private final AtomicBoolean _failed; + private final CachingStream _sharedCache; + private QueueCallback _last; + + public MergedOOCStream(List> sources) { + if(sources == null || sources.isEmpty()) + throw new IllegalArgumentException("MergedOOCStream requires at least one source stream"); + _sources = sources; + _taskQueue = new SubscribableTaskQueue<>(); + _openSources = new AtomicInteger(sources.size()); + _failed = new AtomicBoolean(false); + _sharedCache = findSharedCache(sources); + + _taskQueue.setUpstreamMessageRelay(msg -> { + for(OOCStream source : _sources) + source.messageUpstream(msg); + }); + + for(OOCStream source : _sources) { + source.setSubscriber(cb -> { + try { + try(cb) { + if(cb.isFailure()) { + DMLRuntimeException failure; + try { + cb.get(); + failure = new DMLRuntimeException("Stream callback indicated failure without cause"); + } + catch(DMLRuntimeException re) { + failure = re; + } + propagateFailure(failure); + return; + } + + if(cb.isEos()) { + if(_failed.get()) + return; + if(_openSources.decrementAndGet() == 0) + _taskQueue.closeInput(); + return; + } + + if(_failed.get()) + return; + + _taskQueue.enqueue(cb.keepOpen()); + } + } + catch(DMLRuntimeException re) { + propagateFailure(re); + } + }); + } + } + + @SafeVarargs + public MergedOOCStream(OOCStream... sources) { + this(Arrays.asList(sources)); + } + + private static CachingStream findSharedCache(List> sources) { + CachingStream shared = null; + for(OOCStream source : sources) { + if(!source.hasStreamCache()) + return null; + CachingStream cache = source.getStreamCache(); + if(cache == null) + return null; + if(shared == null) + shared = cache; + else if(shared != cache) + return null; + } + return shared; + } + + @Override + public void enqueue(T t) { + throw new UnsupportedOperationException(); + } + + @Override + public synchronized T dequeue() { + if(_last != null) + _last.close(); + _last = _taskQueue.dequeue(); + if(_last == null) + return null; + return _last.get(); + } + + @Override + public void closeInput() { + throw new UnsupportedOperationException(); + } + + @Override + public void propagateFailure(DMLRuntimeException re) { + if(!_failed.compareAndSet(false, true)) + return; + _taskQueue.propagateFailure(re); + for(OOCStream source : _sources) + source.propagateFailure(re); + } + + @Override + public boolean hasStreamCache() { + return _sharedCache != null; + } + + @Override + public CachingStream getStreamCache() { + return _sharedCache; + } + + @Override + public void setSubscriber(Consumer> subscriber) { + _taskQueue.setSubscriber(cb -> { + if(cb.isEos()) { + subscriber.accept(OOCStream.eos(null)); + return; + } + if(cb.isFailure()) { + try { + cb.get(); + subscriber.accept(OOCStream.eos(new DMLRuntimeException("Stream callback indicated failure without cause"))); + } + catch(DMLRuntimeException re) { + subscriber.accept(OOCStream.eos(re)); + } + return; + } + subscriber.accept(cb.get()); + }); + } + + @Override + public OOCStream getReadStream() { + return this; + } + + @Override + public OOCStream getWriteStream() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isProcessed() { + return _taskQueue.isProcessed(); + } + + @Override + public DataCharacteristics getDataCharacteristics() { + return _taskQueue.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _taskQueue.getData(); + } + + @Override + public void setData(CacheableData data) { + _taskQueue.setData(data); + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + _taskQueue.messageUpstream(msg); + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + _taskQueue.messageDownstream(msg); + } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + _taskQueue.setDownstreamMessageRelay(relay); + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + _taskQueue.addDownstreamMessageRelay(relay); + } + + @Override + public void clearUpstreamMessageRelays() { + _taskQueue.clearUpstreamMessageRelays(); + } + + @Override + public void clearDownstreamMessageRelays() { + _taskQueue.clearDownstreamMessageRelays(); + } + + @Override + public void setIXTransform(BiFunction transform) { + _taskQueue.setIXTransform(transform); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStream.java index df1b415cb96..553767ef8ce 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStream.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStream.java @@ -21,6 +21,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; import org.apache.sysds.runtime.ooc.cache.OOCCacheManager; @@ -29,12 +30,13 @@ import org.apache.sysds.runtime.matrix.data.MatrixIndexes; import java.util.concurrent.ConcurrentHashMap; +import java.util.List; import java.util.concurrent.locks.LockSupport; public class SourceOOCStream extends SubscribableTaskQueue { private final ConcurrentHashMap _idx; private static final long BACKPRESSURE_PARK_NANOS = 1_000_000L; - private static final long MAX_BACKPRESSURE_PARK_NANOS = 2_000_000_000L; + private static final long MAX_BACKPRESSURE_PARK_NANOS = 200_000_000L; public SourceOOCStream() { this._idx = new ConcurrentHashMap<>(); @@ -49,6 +51,22 @@ public void enqueue(IndexedMatrixValue value, OOCIOHandler.SourceBlockDescriptor super.enqueue(value); } + public void enqueueGroup(List values, OOCIOHandler.GroupSourceBlockDescriptor descriptor) { + if (descriptor == null) + throw new IllegalArgumentException("Group source descriptor must not be null"); + if (values == null || values.isEmpty()) + return; + waitForBackpressure(); + boolean delivered = tryDeliverCallback(new SourceGroupCallback(values, descriptor), values.size()); + if (!delivered) { + // Fallback to individual enqueues if no subscriber yet + for (int i = 0; i < values.size(); i++) { + OOCIOHandler.SourceBlockDescriptor d = descriptor.blocks.get(i); + enqueue(values.get(i), d); + } + } + } + @Override public void enqueue(IndexedMatrixValue val) { throw new UnsupportedOperationException("Use enqueue(value, descriptor) for source streams"); @@ -78,4 +96,58 @@ public void messageUpstream(OOCStreamMessage msg) { return; super.messageUpstream(msg); } + + public static class SourceGroupCallback implements OOCStream.GroupQueueCallback { + private final List _data; + private final OOCIOHandler.GroupSourceBlockDescriptor _descriptor; + private DMLRuntimeException _failure; + + SourceGroupCallback(List data, OOCIOHandler.GroupSourceBlockDescriptor descriptor) { + _data = data; + _descriptor = descriptor; + } + + public OOCIOHandler.GroupSourceBlockDescriptor getDescriptor() { + return _descriptor; + } + + @Override + public int size() { + return _data.size(); + } + + @Override + public OOCStream.QueueCallback getCallback(int idx) { + return new OOCStream.SimpleQueueCallback<>(_data.get(idx), _failure); + } + + @Override + public IndexedMatrixValue get() { + throw new UnsupportedOperationException(); + } + + @Override + public OOCStream.QueueCallback keepOpen() { + return this; + } + + @Override + public void close() { + } + + @Override + public void fail(DMLRuntimeException failure) { + _failure = failure; + } + + @Override + public boolean isEos() { + return false; + } + + @Override + public boolean isFailure() { + return _failure != null; + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStreamable.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStreamable.java new file mode 100644 index 00000000000..4a5f018b2ac --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/SourceOOCStreamable.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream; + +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStreamable; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class SourceOOCStreamable implements OOCStreamable { + private final CacheableData _data; + + public SourceOOCStreamable(CacheableData data) { + _data = data; + } + + @Override + public OOCStream getReadStream() { + return _data.getStreamHandle(); + } + + @Override + public OOCStream getWriteStream() { + return _data.getStreamHandle(); + } + + @Override + public boolean hasStreamCache() { + return false; + } + + @Override + public CachingStream getStreamCache() { + return null; + } + + @Override + public boolean isProcessed() { + return false; + } + + @Override + public DataCharacteristics getDataCharacteristics() { + return _data.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _data; + } + + @Override + public void setData(CacheableData data) { + throw new UnsupportedOperationException(); + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + + } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + + } + + @Override + public void clearUpstreamMessageRelays() { + + } + + @Override + public void clearDownstreamMessageRelays() { + + } + + @Override + public void setIXTransform(BiFunction transform) { + + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/SplittingOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/SplittingOOCStream.java new file mode 100644 index 00000000000..a83ae4d01f3 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/SplittingOOCStream.java @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +public class SplittingOOCStream implements OOCStream { + private OOCStream _sourceStream; + private SubOOCStream[] _subStreams; + + public SplittingOOCStream(OOCStream sourceStream, Function partitionFunc, int numPartitions) { + _sourceStream = sourceStream; + _subStreams = new SubOOCStream[numPartitions]; + for(int i = 0; i < numPartitions; i++) + _subStreams[i] = new SubOOCStream<>(this); + + _sourceStream.setSubscriber(cb -> { + try { + try(cb) { + if(cb.isFailure()) { + DMLRuntimeException failure; + try { + cb.get(); + failure = new DMLRuntimeException("Stream callback indicated failure without cause"); + } + catch(DMLRuntimeException re) { + failure = re; + } + + for(int i = 0; i < numPartitions; i++) { + SubOOCStream current = _subStreams[i]; + if(current != null) + current.propagateFailure(failure); + } + return; + } + + if(cb.isEos()) { + SubOOCStream current; + for(int i = 0; i < numPartitions; i++) { + // This requires no additional locking because we know EOS + // is always triggered after the last non EOS call finished + current = _subStreams[i]; + if(current != null) + current.closeInput(); + } + return; + } + + if(cb instanceof OOCStream.GroupQueueCallback) { + @SuppressWarnings("unchecked") + OOCStream.GroupQueueCallback group = (OOCStream.GroupQueueCallback) cb; + for(int gi = 0; gi < group.size(); gi++) { + OOCStream.QueueCallback sub = group.getCallback(gi); + try(sub) { + int partition = partitionFunc.apply(sub.get()); + if(partition < 0 || partition >= numPartitions) + throw new DMLRuntimeException("Invalid partition index: " + partition + " for " + numPartitions + " partitions"); + _subStreams[partition].enqueue(sub.keepOpen()); + } + } + return; + } + + int partition = partitionFunc.apply(cb.get()); + if(partition < 0 || partition >= numPartitions) + throw new DMLRuntimeException("Invalid partition index: " + partition + " for " + numPartitions + " partitions"); + _subStreams[partition].enqueue(cb.keepOpen()); + } + } + catch(DMLRuntimeException re) { + propagateFailure(re); + } + }); + } + + public OOCStream getSubStream(int idx) { + return _subStreams[idx]; + } + + @Override + public void enqueue(T t) { + throw new UnsupportedOperationException(); + } + + @Override + public T dequeue() { + throw new UnsupportedOperationException(); + } + + @Override + public void closeInput() { + throw new UnsupportedOperationException(); + } + + @Override + public void propagateFailure(DMLRuntimeException re) { + _sourceStream.propagateFailure(re); + for(SubOOCStream subStream : _subStreams) + subStream.propagateFailure(re); + } + + @Override + public boolean hasStreamCache() { + return _sourceStream.hasStreamCache(); + } + + @Override + public CachingStream getStreamCache() { + return _sourceStream.getStreamCache(); + } + + @Override + public void setSubscriber(Consumer> subscriber) { + throw new UnsupportedOperationException(); + } + + @Override + public OOCStream getReadStream() { + throw new UnsupportedOperationException(); + } + + @Override + public OOCStream getWriteStream() { + return _sourceStream.getWriteStream(); + } + + @Override + public boolean isProcessed() { + return _sourceStream.isProcessed(); + } + + @Override + public DataCharacteristics getDataCharacteristics() { + return _sourceStream.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _sourceStream.getData(); + } + + @Override + public void setData(CacheableData data) { + throw new UnsupportedOperationException(); + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + _sourceStream.messageUpstream(msg); + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + for(SubOOCStream sub : _subStreams) + sub.messageDownstream(msg); + } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void clearUpstreamMessageRelays() { + throw new UnsupportedOperationException(); + } + + @Override + public void clearDownstreamMessageRelays() { + throw new UnsupportedOperationException(); + } + + @Override + public void setIXTransform(BiFunction transform) { + throw new UnsupportedOperationException(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/SubOOCStream.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/SubOOCStream.java new file mode 100644 index 00000000000..e5908c18b04 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/SubOOCStream.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.ooc.stream; + +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; +import org.apache.sysds.runtime.meta.DataCharacteristics; +import org.apache.sysds.runtime.ooc.stream.message.OOCStreamMessage; +import org.apache.sysds.runtime.util.IndexRange; + +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class SubOOCStream implements OOCStream { + private OOCStream _sourceStream; + private SubscribableTaskQueue> _taskQueue; + private QueueCallback _last; + + public SubOOCStream(OOCStream sourceStream) { + _sourceStream = sourceStream; + _taskQueue = new SubscribableTaskQueue<>(); + _taskQueue.setUpstreamMessageRelay(_sourceStream::messageUpstream); + } + + public void enqueue(QueueCallback callback) { + _taskQueue.enqueue(callback); + } + + @Override + public void enqueue(T t) { + _taskQueue.enqueue(new SimpleQueueCallback<>(t, null)); + } + + @Override + public synchronized T dequeue() { + if(_last != null) + _last.close(); + _last = _taskQueue.dequeue(); + if(_last != null) + return _last.get(); + return null; + } + + @Override + public void closeInput() { + _taskQueue.closeInput(); + } + + @Override + public void propagateFailure(DMLRuntimeException re) { + _taskQueue.propagateFailure(re); + } + + @Override + public boolean hasStreamCache() { + return _sourceStream.hasStreamCache(); + } + + @Override + public CachingStream getStreamCache() { + return _sourceStream.getStreamCache(); + } + + @Override + public void setSubscriber(Consumer> subscriber) { + _taskQueue.setSubscriber(cb -> { + if(cb.isEos()) { + subscriber.accept(OOCStream.eos(null)); + return; + } + if(cb.isFailure()) { + try { + cb.get(); + subscriber.accept(OOCStream.eos(new DMLRuntimeException("Stream callback indicated failure without cause"))); + } + catch(DMLRuntimeException re) { + subscriber.accept(OOCStream.eos(re)); + } + } + else + subscriber.accept(cb.get()); + }); + } + + @Override + public OOCStream getReadStream() { + return this; + } + + @Override + public OOCStream getWriteStream() { + return _sourceStream.getWriteStream(); + } + + @Override + public boolean isProcessed() { + return _sourceStream.isProcessed(); + } + + @Override + public DataCharacteristics getDataCharacteristics() { + return _taskQueue.getDataCharacteristics(); + } + + @Override + public CacheableData getData() { + return _taskQueue.getData(); + } + + @Override + public void setData(CacheableData data) { + _taskQueue.setData(data); + } + + @Override + public void messageUpstream(OOCStreamMessage msg) { + _taskQueue.messageUpstream(msg); + } + + @Override + public void messageDownstream(OOCStreamMessage msg) { + _taskQueue.messageDownstream(msg); + } + + @Override + public void setUpstreamMessageRelay(Consumer relay) { + // Upstream is handled by source stream + throw new UnsupportedOperationException(); + } + + @Override + public void setDownstreamMessageRelay(Consumer relay) { + _taskQueue.setDownstreamMessageRelay(relay); + } + + @Override + public void addUpstreamMessageRelay(Consumer relay) { + throw new UnsupportedOperationException(); + } + + @Override + public void addDownstreamMessageRelay(Consumer relay) { + _taskQueue.addDownstreamMessageRelay(relay); + } + + @Override + public void clearUpstreamMessageRelays() { + _taskQueue.clearUpstreamMessageRelays(); + } + + @Override + public void clearDownstreamMessageRelays() { + _taskQueue.clearDownstreamMessageRelays(); + } + + @Override + public void setIXTransform(BiFunction transform) { + _taskQueue.setIXTransform(transform); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java index 5b6381d4bec..7681f39180b 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java @@ -49,8 +49,8 @@ public static void defer(Runnable deferred) { if(ctx._deferred == null) ctx._deferred = new ArrayDeque<>(); ctx._deferred.add(deferred); - if(ctx._deferred.size() > 3) - System.out.println("[WARN] Defer size bigger than 3"); + //if(ctx._deferred.size() == 4 || ctx._deferred.size() % 100 == 0) + // System.out.println("[WARN] Defer size bigger than 3 (" + ctx._deferred.size() + ")"); } public static boolean runDeferred() { @@ -62,4 +62,9 @@ public static boolean runDeferred() { deferred.run(); return true; } + + public static boolean canPipe() { + TaskContext ctx = CTX.get(); + return ctx._deferred != null && ctx._deferred.size() < 3; + } } diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java index 9ec94b1025c..5102933911a 100644 --- a/src/main/java/org/apache/sysds/utils/Statistics.java +++ b/src/main/java/org/apache/sysds/utils/Statistics.java @@ -229,8 +229,10 @@ public Object getMeta(String key) { private static final LongAdder oocPutCalls = new LongAdder(); private static final LongAdder oocLoadFromDiskCalls = new LongAdder(); private static final LongAdder oocLoadFromDiskTimeNanos = new LongAdder(); + private static final LongAdder oocLoadFromDiskBytesSize = new LongAdder(); private static final LongAdder oocEvictionWriteCalls = new LongAdder(); private static final LongAdder oocEvictionWriteTimeNanos = new LongAdder(); + private static final LongAdder oocEvictionWriteBytesSize = new LongAdder(); private static final AtomicLong oocStatsStartTime = new AtomicLong(System.nanoTime()); public static long getNoOfExecutedSPInst() { @@ -356,8 +358,10 @@ public static void resetOOCEvictionStats() { oocPutCalls.reset(); oocLoadFromDiskCalls.reset(); oocLoadFromDiskTimeNanos.reset(); + oocLoadFromDiskBytesSize.reset(); oocEvictionWriteCalls.reset(); oocEvictionWriteTimeNanos.reset(); + oocEvictionWriteBytesSize.reset(); oocStatsStartTime.set(System.nanoTime()); } @@ -465,10 +469,18 @@ public static void accumulateOOCLoadFromDiskTime(long nanos) { oocLoadFromDiskTimeNanos.add(nanos); } + public static void accumulateOOCLoadFromDiskBytes(long bytes) { + oocLoadFromDiskBytesSize.add(bytes); + } + public static void accumulateOOCEvictionWriteTime(long nanos) { oocEvictionWriteTimeNanos.add(nanos); } + public static void accumulateOOCEvictionWriteBytes(long bytes) { + oocEvictionWriteBytesSize.add(bytes); + } + public static String displayOOCEvictionStats() { long elapsedNanos = Math.max(1, System.nanoTime() - oocStatsStartTime.get()); double elapsedSeconds = elapsedNanos / 1e9; @@ -483,10 +495,10 @@ public static String displayOOCEvictionStats() { oocGetCalls.longValue(), getThroughput)); sb.append(String.format(Locale.US, " put calls:\t\t%d (%.2f/sec)\n", oocPutCalls.longValue(), putThroughput)); - sb.append(String.format(Locale.US, " loadFromDisk:\t\t%d (time %.3f sec)\n", - oocLoadFromDiskCalls.longValue(), oocLoadFromDiskTimeNanos.longValue() / 1e9)); - sb.append(String.format(Locale.US, " evict writes:\t\t%d (time %.3f sec)\n", - oocEvictionWriteCalls.longValue(), oocEvictionWriteTimeNanos.longValue() / 1e9)); + sb.append(String.format(Locale.US, " loadFromDisk:\t\t%d (time %.3f sec, %.3f GB)\n", + oocLoadFromDiskCalls.longValue(), oocLoadFromDiskTimeNanos.longValue() / 1e9, oocLoadFromDiskBytesSize.longValue() / 1e9)); + sb.append(String.format(Locale.US, " evict writes:\t\t%d (time %.3f sec, %.3f GB)\n", + oocEvictionWriteCalls.longValue(), oocEvictionWriteTimeNanos.longValue() / 1e9, oocEvictionWriteBytesSize.longValue() / 1e9)); return sb.toString(); } diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/BinaryWritePartitioningTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryWritePartitioningTest.java new file mode 100644 index 00000000000..531b839615d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/BinaryWritePartitioningTest.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; + +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecMode; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +public class BinaryWritePartitioningTest extends AutomatedTestBase { + private static final String TEST_NAME = "UnaryWrite"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BinaryWritePartitioningTest.class.getSimpleName() + "/"; + private static final String INPUT_NAME = "X"; + private static final String OUTPUT_NAME = "res"; + private static final int ROWS = 3000; + private static final int COLS = 3000; + private static final int BLEN = 1000; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void testOOCBinaryWriteMultipartWhenParallelEnabled() { + runBinaryWritePartitioningTest(true, true); + } + + @Test + public void testOOCBinaryWriteSingleFileWhenParallelDisabled() { + runBinaryWritePartitioningTest(false, false); + } + + private void runBinaryWritePartitioningTest(boolean parallelBinaryWrite, boolean expectMultipart) { + Types.ExecMode oldPlatform = setExecMode(ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME); + String home = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = home + TEST_NAME + ".dml"; + File configFile = createParallelIOConfig(parallelBinaryWrite); + programArgs = new String[] { + "-config", configFile.getPath(), + "-explain", "-stats", "-ooc", + "-args", input(INPUT_NAME), output(OUTPUT_NAME)}; + + MatrixBlock in = MatrixBlock.randOperations(ROWS, COLS, 1.0, -1, 1, "uniform", 7); + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(FileFormat.BINARY); + writer.writeMatrixToHDFS(in, input(INPUT_NAME), ROWS, COLS, BLEN, in.getNonZeros()); + HDFSTool.writeMetaDataFile(input(INPUT_NAME + ".mtd"), ValueType.FP64, + new MatrixCharacteristics(ROWS, COLS, BLEN, in.getNonZeros()), FileFormat.BINARY); + + runTest(true, false, null, -1); + + int numBinaryFiles = countBinaryFiles(output(OUTPUT_NAME)); + boolean shouldBeMultipart = expectMultipart + && OptimizerUtils.getParallelBinaryWriteParallelism() > 1; + if(shouldBeMultipart) + Assert.assertTrue("Expected multipart binary output but found " + numBinaryFiles + " file(s).", + numBinaryFiles > 1); + else + Assert.assertEquals("Expected single-file binary output.", 1, numBinaryFiles); + + MatrixBlock out = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), FileFormat.BINARY, ROWS, COLS, BLEN, + in.getNonZeros()); + Assert.assertEquals(ROWS, out.getNumRows()); + Assert.assertEquals(COLS, out.getNumColumns()); + } + catch(Exception ex) { + Assert.fail(ex.getMessage()); + } + finally { + resetExecMode(oldPlatform); + } + } + + private File createParallelIOConfig(boolean parallelIO) throws IOException { + File baseConfig = getCurConfigFile(); + String xml = Files.readString(baseConfig.toPath(), StandardCharsets.UTF_8); + String prop = " " + parallelIO + "\n"; + String updated = xml.contains("") ? xml.replace("", prop + "") : xml + "\n\n" + prop + "\n"; + + File out = new File(getCurLocalTempDir(), "SystemDS-config-ooc-pario-" + parallelIO + ".xml"); + Files.writeString(out.toPath(), updated, StandardCharsets.UTF_8); + return out; + } + + private int countBinaryFiles(String path) throws IOException { + Path outPath = new Path(path); + FileSystem fs = IOUtilFunctions.getFileSystem(outPath); + return IOUtilFunctions.getSequenceFilePaths(fs, outPath).length; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/CSVReaderTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/CSVReaderTest.java index 5f5f7fb42f6..0897d4e07c7 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/CSVReaderTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/CSVReaderTest.java @@ -74,12 +74,12 @@ public void testCSVReaderSparseWide() { @Test public void testCSVReaderDenseUltraWide() { - runCSVReaderTest(false, 50, 200000); + runCSVReaderTest(false, 10, 200000); } @Test public void testCSVReaderDenseLarge() { - runCSVReaderTest(false, 750, 50000); + runCSVReaderTest(false, 400, 25000); } @Test @@ -89,7 +89,7 @@ public void testCSVReaderSparseLarge() { @Test public void testCSVReaderDenseLarge2() { - runCSVReaderTest(false, 1200, 25000); + runCSVReaderTest(false, 1200, 10000); } private void runCSVReaderTest(boolean sparse, int rows, int cols) { diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java index 5b348de2ed9..f4c4d364e83 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/LmCGTest.java @@ -65,7 +65,6 @@ public void testLmCGSparse() { runLmCGTest(true); } - // TODO codex resume 019bb84d-bac6-7fd1-bfb8-a149e715e5b5 private void runLmCGTest(boolean sparse) { Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java new file mode 100644 index 00000000000..a25249985d6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import java.io.IOException; + +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; + +public class PNMFTest extends AutomatedTestBase { + private static final String TEST_NAME = "PNMF"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + PNMFTest.class.getSimpleName() + "/"; + + private static final String INPUT_X = "X"; + private static final String OUTPUT_W_OOC = "W"; + private static final String OUTPUT_H_OOC = "H"; + private static final String OUTPUT_W_CP = "W_cp"; + private static final String OUTPUT_H_CP = "H_cp"; + + private static final int ROWS = 1468; + private static final int COLS = 1207; + private static final int RANK = 20; + private static final int MAX_ITER = 10; + private static final int BLOCK_SIZE = 1000; + + private static final double SPARSITY = 0.7; + private static final double EPS = 1e-6; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + //@Test + public void testPNMFOOCVsCP() { + runPNMFTest(); + } + + private void runPNMFTest() { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + getAndLoadTestConfiguration(TEST_NAME); + + String home = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = home + TEST_NAME + ".dml"; + + double[][] xData = getRandomMatrix(ROWS, COLS, 1, 10, SPARSITY, 7); + writeBinaryWithMTD(INPUT_X, DataConverter.convertToMatrixBlock(xData)); + + programArgs = new String[] {"-explain", "-stats", "-seed", "7", "-ooc", "-args", + input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), + output(OUTPUT_W_OOC), output(OUTPUT_H_OOC)}; + runTest(true, false, null, -1); + + programArgs = new String[] {"-explain", "-stats", "-seed", "7", "-args", + input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), + output(OUTPUT_W_CP), output(OUTPUT_H_CP)}; + runTest(true, false, null, -1); + + MatrixBlock wOOC = DataConverter.readMatrixFromHDFS(output(OUTPUT_W_OOC), + Types.FileFormat.BINARY, ROWS, RANK, BLOCK_SIZE); + MatrixBlock hOOC = DataConverter.readMatrixFromHDFS(output(OUTPUT_H_OOC), + Types.FileFormat.BINARY, RANK, COLS, BLOCK_SIZE); + + MatrixBlock wCP = DataConverter.readMatrixFromHDFS(output(OUTPUT_W_CP), + Types.FileFormat.BINARY, ROWS, RANK, BLOCK_SIZE); + MatrixBlock hCP = DataConverter.readMatrixFromHDFS(output(OUTPUT_H_CP), + Types.FileFormat.BINARY, RANK, COLS, BLOCK_SIZE); + + TestUtils.compareMatrices(wOOC, wCP, EPS); + TestUtils.compareMatrices(hOOC, hCP, EPS); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/SplitMergeOOCStreamTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/SplitMergeOOCStreamTest.java new file mode 100644 index 00000000000..f70c1a0d7e1 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/SplitMergeOOCStreamTest.java @@ -0,0 +1,299 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.instructions.ooc.CachingStream; +import org.apache.sysds.runtime.instructions.ooc.OOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.ooc.SubscribableTaskQueue; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiFunction; +import java.util.function.Function; + +public class SplitMergeOOCStreamTest extends AutomatedTestBase { + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + } + + @Test + public void testSplitMergeNonCached() { + PrimitiveHarness h = new PrimitiveHarness(); + SubscribableTaskQueue src = new SubscribableTaskQueue<>(); + List input = createInput(24); + + List> splits = h.split(src, partitionByRow(3), 3); + OOCStream merged = h.merge(splits); + + Assert.assertFalse("Merged stream should not expose cache for non-cached inputs", merged.hasStreamCache()); + feed(src, input); + Map outByKey = h.collect(merged); + assertSameItems(outByKey, input); + } + + @Test + public void testSplitMergeCached() { + PrimitiveHarness h = new PrimitiveHarness(); + SubscribableTaskQueue base = new SubscribableTaskQueue<>(); + CachingStream cache = new CachingStream(base); + List input = createInput(24); + + feed(base, input); + + OOCStream cachedRead = cache.getReadStream(); + List> splits = h.split(cachedRead, partitionByRow(4), 4); + OOCStream merged = h.merge(splits); + + Assert.assertTrue("Merged stream should expose cache when all inputs share one cache", merged.hasStreamCache()); + Assert.assertSame("Merged stream should expose the shared cache", cache, merged.getStreamCache()); + Map outByKey = h.collect(merged); + assertSameItems(outByKey, input); + } + + @Test + public void testSplitMergeCbind1500x3000() { + PrimitiveHarness h = new PrimitiveHarness(); + SubscribableTaskQueue leftSrc = new SubscribableTaskQueue<>(); + SubscribableTaskQueue rightSrc = new SubscribableTaskQueue<>(); + + // Two logical 1500x1500 matrices tiled into 1k x 1k blocks. + List leftInput = createTiled1500Input(11, 12, 21, 22); + List rightInput = createTiled1500Input(111, 112, 121, 122); + + List> leftByRow = h.split(leftSrc, imv -> (int) (imv.getIndexes().getRowIndex() - 1), 2); + List> rightByRow = h.split(rightSrc, imv -> (int) (imv.getIndexes().getRowIndex() - 1), 2); + + feed(leftSrc, leftInput); + feed(rightSrc, rightInput); + + OOCStream row1Out = buildCbindRowPartition(h, leftByRow.get(0), rightByRow.get(0)); + OOCStream row2Out = buildCbindRowPartition(h, leftByRow.get(1), rightByRow.get(1)); + OOCStream merged = h.merge(List.of(row1Out, row2Out)); + MatrixObject outMo = createStreamBackedMatrixObject(merged, 1500, 3000, 1000); + MatrixBlock oocCbind = outMo.acquireReadAndRelease(); + + MatrixBlock leftCp = materializeThroughMatrixObject(leftInput, 1500, 1500, 1000); + MatrixBlock rightCp = materializeThroughMatrixObject(rightInput, 1500, 1500, 1000); + MatrixBlock cpCbind = leftCp.append(rightCp); + TestUtils.compareMatrices(cpCbind, oocCbind, 0.0, "OOC cbind result differs from CP cbind result"); + } + + private static Function partitionByRow(int numPartitions) { + return imv -> (int)Math.floorMod(imv.getIndexes().getRowIndex(), numPartitions); + } + + private static void feed(SubscribableTaskQueue src, List input) { + for(IndexedMatrixValue imv : input) + src.enqueue(imv); + src.closeInput(); + } + + private static List createInput(int n) { + List input = new ArrayList<>(n); + for(int i = 1; i <= n; i++) + input.add(new IndexedMatrixValue(new MatrixIndexes(i, 1), new MatrixBlock(1, 1, (double)i))); + return input; + } + + private static List createTiled1500Input(double v11, double v12, double v21, double v22) { + return List.of( + new IndexedMatrixValue(new MatrixIndexes(1, 1), new MatrixBlock(1000, 1000, v11)), + new IndexedMatrixValue(new MatrixIndexes(1, 2), new MatrixBlock(1000, 500, v12)), + new IndexedMatrixValue(new MatrixIndexes(2, 1), new MatrixBlock(500, 1000, v21)), + new IndexedMatrixValue(new MatrixIndexes(2, 2), new MatrixBlock(500, 500, v22)) + ); + } + + private OOCStream buildCbindRowPartition(PrimitiveHarness h, + OOCStream leftPart, OOCStream rightPart) { + List> leftByCol = h.split(leftPart, imv -> (int) (imv.getIndexes().getColumnIndex() - 1), 2); + List> rightByCol = h.split(rightPart, imv -> (int) (imv.getIndexes().getColumnIndex() - 1), 2); + + OOCStream leftC1 = leftByCol.get(0); + OOCStream leftC2 = leftByCol.get(1); + OOCStream rightC1 = rightByCol.get(0); + OOCStream rightC2 = rightByCol.get(1); + + CachingStream rightC1Cache = h.cache(rightC1); + OOCStream rightC1ForCritical = rightC1Cache.getReadStream(); + OOCStream rightC1ForTail = rightC1Cache.getReadStream(); + + SubscribableTaskQueue outCol1 = new SubscribableTaskQueue<>(); + SubscribableTaskQueue outCol2 = new SubscribableTaskQueue<>(); + SubscribableTaskQueue outCol3 = new SubscribableTaskQueue<>(); + Function rowKey = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), 1); + + CompletableFuture f1 = h.map(leftC1, outCol1, imv -> new IndexedMatrixValue( + new MatrixIndexes(imv.getIndexes().getRowIndex(), 1), imv.getValue())); + CompletableFuture f2 = h.join(leftC2, rightC1ForCritical, outCol2, (left, right) -> { + MatrixBlock lb = (MatrixBlock) left.getValue(); + MatrixBlock rb = (MatrixBlock) right.getValue(); + MatrixBlock critical = cbindBlocks(lb, sliceCols(rb, 0, 500)); + return new IndexedMatrixValue(new MatrixIndexes(left.getIndexes().getRowIndex(), 2), critical); + }, rowKey); + CompletableFuture f3 = h.join(rightC1ForTail, rightC2, outCol3, (r1, r2) -> { + MatrixBlock rb1 = (MatrixBlock) r1.getValue(); + MatrixBlock rb2 = (MatrixBlock) r2.getValue(); + MatrixBlock tail = cbindBlocks(sliceCols(rb1, 500, 1000), rb2); + return new IndexedMatrixValue(new MatrixIndexes(r1.getIndexes().getRowIndex(), 3), tail); + }, rowKey); + + h.await(CompletableFuture.allOf(f1, f2, f3)); + return h.merge(List.of(outCol1, outCol2, outCol3)); + } + + private static MatrixBlock sliceCols(MatrixBlock in, int colStart, int colEndExclusive) { + int rows = in.getNumRows(); + int cols = colEndExclusive - colStart; + MatrixBlock out = new MatrixBlock(rows, cols, false); + for(int r = 0; r < rows; r++) { + for(int c = 0; c < cols; c++) + out.set(r, c, in.get(r, colStart + c)); + } + return out; + } + + private static MatrixBlock cbindBlocks(MatrixBlock left, MatrixBlock right) { + int rows = left.getNumRows(); + if(rows != right.getNumRows()) + throw new IllegalArgumentException("Row mismatch in cbindBlocks"); + int lCols = left.getNumColumns(); + int rCols = right.getNumColumns(); + MatrixBlock out = new MatrixBlock(rows, lCols + rCols, false); + for(int r = 0; r < rows; r++) { + for(int c = 0; c < lCols; c++) + out.set(r, c, left.get(r, c)); + for(int c = 0; c < rCols; c++) + out.set(r, lCols + c, right.get(r, c)); + } + return out; + } + + private static MatrixObject createStreamBackedMatrixObject(OOCStream stream, long rows, + long cols, int blen) { + MatrixCharacteristics mc = new MatrixCharacteristics(rows, cols, blen, -1); + MatrixObject mo = new MatrixObject(ValueType.FP64, null, new MetaDataFormat(mc, FileFormat.BINARY)); + mo.setStreamHandle(stream); + return mo; + } + + private static MatrixBlock materializeThroughMatrixObject(List blocks, int rows, int cols, int blen) { + SubscribableTaskQueue src = new SubscribableTaskQueue<>(); + feed(src, blocks); + return createStreamBackedMatrixObject(src, rows, cols, blen).acquireReadAndRelease(); + } + + private static long key(long row, long col) { + return row * 1_000_000L + col; + } + + private static long key(IndexedMatrixValue imv) { + return key(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex()); + } + + private static void assertSameItems(Map outByKey, List expected) { + Set expectedKeys = new HashSet<>(expected.size()); + for(IndexedMatrixValue imv : expected) + expectedKeys.add(key(imv)); + + Assert.assertEquals("Unexpected number of output blocks", expected.size(), outByKey.size()); + Assert.assertEquals("Output keys differ from input keys", expectedKeys, outByKey.keySet()); + } + + private static class PrimitiveHarness extends OOCInstruction { + PrimitiveHarness() { + super(OOCType.Tee, "split_merge_test", "split_merge_test"); + } + + @Override + public void processInstruction(ExecutionContext ec) {} + + List> split(OOCStream source, Function partitionFunc, int numPartitions) { + return splitOOCStream(source, partitionFunc, numPartitions); + } + + OOCStream merge(List> streams) { + return mergeOOCStreams(streams); + } + + CachingStream cache(OOCStream stream) { + return new CachingStream(stream); + } + + CompletableFuture map(OOCStream in, OOCStream out, + Function mapper) { + return mapOOC(in, out, mapper); + } + + CompletableFuture join(OOCStream left, OOCStream right, + OOCStream out, + BiFunction mapper, + Function keyFn) { + return joinOOC(left, right, out, mapper, keyFn); + } + + Map collect(OOCStream stream) { + Map out = new ConcurrentHashMap<>(); + await(collectToMap(stream, out)); + return out; + } + + void await(CompletableFuture future) { + try { + future.join(); + } + catch(CompletionException ex) { + throw ex.getCause() instanceof RuntimeException ? (RuntimeException) ex.getCause() : ex; + } + } + + private CompletableFuture collectToMap(OOCStream stream, Map out) { + addInStream(stream); + addOutStream(); + return submitOOCTasks(stream, cb -> { + IndexedMatrixValue item = cb.get(); + long k = key(item); + IndexedMatrixValue prev = out.putIfAbsent(k, item); + Assert.assertNull("Duplicate output item for key " + k, prev); + }); + } + } +} diff --git a/src/test/scripts/functions/ooc/PNMF.dml b/src/test/scripts/functions/ooc/PNMF.dml new file mode 100644 index 00000000000..60aecb8963f --- /dev/null +++ b/src/test/scripts/functions/ooc/PNMF.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = read($1); +[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE); + +write(W, $4, format="binary"); +write(H, $5, format="binary"); From 2c6e44032ee8739a8b238162d9e1b9cbb4b3738c Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Thu, 19 Feb 2026 19:13:09 +0100 Subject: [PATCH 2/3] Fix Formatting --- .../instructions/ooc/CachingStream.java | 2 +- .../sysds/runtime/ooc/cache/OOCIOHandler.java | 24 +++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java index 6aa65b9f723..7e1bdac73d2 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/CachingStream.java @@ -134,7 +134,7 @@ public CachingStream(OOCStream source, long streamId) { mCallback = tmp.keepOpen(); } else { - List values = new java.util.ArrayList<>(groupSize); + List values = new ArrayList<>(groupSize); long totalSize = 0; for(int gi = 0; gi < groupSize; gi++) { OOCStream.QueueCallback sub = group.getCallback(gi); diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java index 699a4f65493..0bc5ace1274 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCIOHandler.java @@ -19,6 +19,11 @@ package org.apache.sysds.runtime.ooc.cache; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.ooc.OOCStream; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; + import java.util.concurrent.CompletableFuture; import java.util.List; @@ -58,18 +63,18 @@ interface SourceReadContinuation {} class SourceReadRequest { public final String path; - public final org.apache.sysds.common.Types.FileFormat format; + public final Types.FileFormat format; public final long rows; public final long cols; public final int blen; public final long estNnz; public final long maxBytesInFlight; public final boolean keepOpenOnLimit; - public final org.apache.sysds.runtime.instructions.ooc.OOCStream target; + public final OOCStream target; - public SourceReadRequest(String path, org.apache.sysds.common.Types.FileFormat format, long rows, long cols, + public SourceReadRequest(String path, Types.FileFormat format, long rows, long cols, int blen, long estNnz, long maxBytesInFlight, boolean keepOpenOnLimit, - org.apache.sysds.runtime.instructions.ooc.OOCStream target) { + OOCStream target) { this.path = path; this.format = format; this.rows = rows; @@ -99,14 +104,14 @@ public SourceReadResult(long bytesRead, boolean eof, SourceReadContinuation cont class SourceBlockDescriptor { public final String path; - public final org.apache.sysds.common.Types.FileFormat format; - public final org.apache.sysds.runtime.matrix.data.MatrixIndexes indexes; + public final Types.FileFormat format; + public final MatrixIndexes indexes; public final long offset; public final int recordLength; public final long serializedSize; - public SourceBlockDescriptor(String path, org.apache.sysds.common.Types.FileFormat format, - org.apache.sysds.runtime.matrix.data.MatrixIndexes indexes, long offset, int recordLength, + public SourceBlockDescriptor(String path, Types.FileFormat format, + MatrixIndexes indexes, long offset, int recordLength, long serializedSize) { this.path = path; this.format = format; @@ -121,8 +126,7 @@ class GroupSourceBlockDescriptor extends SourceBlockDescriptor { public final List blocks; public final int count; - public GroupSourceBlockDescriptor(String path, org.apache.sysds.common.Types.FileFormat format, - org.apache.sysds.runtime.matrix.data.MatrixIndexes indexes, long offset, int recordLength, + public GroupSourceBlockDescriptor(String path, Types.FileFormat format, MatrixIndexes indexes, long offset, int recordLength, long serializedSize, List blocks) { super(path, format, indexes, offset, recordLength, serializedSize); this.blocks = blocks; From ad6f8df5d8cdcb898910589bcdd3dff93228a158 Mon Sep 17 00:00:00 2001 From: Jannik Lindemann Date: Fri, 20 Feb 2026 09:11:32 +0100 Subject: [PATCH 3/3] Fix Formatting --- .../apache/sysds/runtime/ooc/cache/OOCCacheManager.java | 9 +++++---- .../org/apache/sysds/runtime/ooc/stream/TaskContext.java | 2 -- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java index 7f4f6be28ac..26e8f010341 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/cache/OOCCacheManager.java @@ -31,6 +31,7 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.util.ArrayList; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; @@ -187,7 +188,7 @@ public static CompletableFuture> req public static CompletableFuture>> requestManyBlocks(List keys) { return getCache().request(keys).thenApply( l -> { - List> out = new java.util.ArrayList<>(l.size()); + List> out = new ArrayList<>(l.size()); for (int i = 0; i < l.size(); i++) out.add(toCallback(l.get(i), keys.get(i), null)); return out; @@ -198,7 +199,7 @@ public static List> tryRequestManyBl List entries = getCache().tryRequest(keys); if(entries == null) return null; - List> out = new java.util.ArrayList<>(entries.size()); + List> out = new ArrayList<>(entries.size()); for (int i = 0; i < entries.size(); i++) out.add(toCallback(entries.get(i), keys.get(i), null)); return out; @@ -208,7 +209,7 @@ public static CompletableFuture return getCache().requestAnyOf(keys, n, sel) .thenApply( l -> { - List> out = new java.util.ArrayList<>(l.size()); + List> out = new ArrayList<>(l.size()); for (int i = 0; i < l.size(); i++) { BlockKey key = sel.size() == l.size() ? sel.get(i) : keys.get(i); out.add(toCallback(l.get(i), key, null)); @@ -218,7 +219,7 @@ public static CompletableFuture } private static OOCStream.QueueCallback toCallback(BlockEntry entry, BlockKey key, DMLRuntimeException failure) { - if (entry.getData() instanceof java.util.List) { + if (entry.getData() instanceof List) { CachedGroupCallback group = new CachedGroupCallback<>(entry, failure); if (key instanceof GroupedBlockKey gk) { OOCStream.QueueCallback sub = group.getCallback(gk.getGroupIndex()); diff --git a/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java b/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java index 7681f39180b..0cbc6d1b9be 100644 --- a/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java +++ b/src/main/java/org/apache/sysds/runtime/ooc/stream/TaskContext.java @@ -49,8 +49,6 @@ public static void defer(Runnable deferred) { if(ctx._deferred == null) ctx._deferred = new ArrayDeque<>(); ctx._deferred.add(deferred); - //if(ctx._deferred.size() == 4 || ctx._deferred.size() % 100 == 0) - // System.out.println("[WARN] Defer size bigger than 3 (" + ctx._deferred.size() + ")"); } public static boolean runDeferred() {