Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/main/java/org/apache/sysds/hops/ipa/IPAPassInjectOOCTee.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.hops.ipa;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.hops.rewrite.RewriteInjectOOCTee;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.LanguageException;

/**
* Applies OOC tee injection after static/dynamic rewrites in IPA.
*/
public class IPAPassInjectOOCTee extends IPAPass {
@Override
public boolean isApplicable(FunctionCallGraph fgraph) {
return DMLScript.USE_OOC;
}

@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
try {
ProgramRewriter rewriter = new ProgramRewriter(new RewriteInjectOOCTee());
ProgramRewriteStatus status = new ProgramRewriteStatus();
rewriter.rewriteProgramHopDAGs(prog, true, status);
return false;
}
catch(LanguageException ex) {
throw new HopsException(ex);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.hops.ipa;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
* Prune stale parent links by keeping only parent references reachable from the statement block roots/predicates.
*/
public class IPAPassPruneUnreachableHops extends IPAPass {
@Override
public boolean isApplicable(FunctionCallGraph fgraph) {
return DMLScript.USE_OOC;
}

@Override
public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
pruneStatementBlocks(prog.getStatementBlocks());
for(FunctionStatementBlock fsb : prog.getFunctionStatementBlocks())
pruneStatementBlocks(((FunctionStatement) fsb.getStatement(0)).getBody());
return false;
}

private static void pruneStatementBlocks(List<StatementBlock> sbs) {
for(StatementBlock sb : sbs) {
if(sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
pruneHops(wsb.getPredicateHops());
pruneStatementBlocks(wstmt.getBody());
}
else if(sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement) sb.getStatement(0);
pruneHops(isb.getPredicateHops());
pruneStatementBlocks(istmt.getIfBody());
if(istmt.getElseBody() != null)
pruneStatementBlocks(istmt.getElseBody());
}
else if(sb instanceof ForStatementBlock) {
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) sb.getStatement(0);
pruneHops(fsb.getFromHops());
pruneHops(fsb.getToHops());
pruneHops(fsb.getIncrementHops());
pruneStatementBlocks(fstmt.getBody());
}
else if(sb instanceof FunctionStatementBlock) {
FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
pruneStatementBlocks(fstmt.getBody());
}
else {
pruneHops(sb.getHops());
}
}
}

private static void pruneHops(Hop root) {
if(root == null)
return;
Set<Long> reachable = new HashSet<>();
collectReachable(root, reachable);
pruneParents(root, reachable, new HashSet<Long>());
}

private static void pruneHops(List<Hop> roots) {
if(roots == null || roots.isEmpty())
return;

Set<Long> reachable = new HashSet<>();
for(Hop root : roots)
collectReachable(root, reachable);

for(Hop root : roots)
pruneParents(root, reachable, new HashSet<Long>());
}

private static void collectReachable(Hop hop, Set<Long> reachable) {
if(hop == null || !reachable.add(hop.getHopID()))
return;
for(Hop in : hop.getInput())
collectReachable(in, reachable);
}

private static void pruneParents(Hop hop, Set<Long> reachable, Set<Long> visited) {
if(hop == null || !visited.add(hop.getHopID()))
return;
hop.getParent().removeIf(p -> !reachable.contains(p.getHopID()));
for(Hop in : hop.getInput())
pruneParents(in, reachable, visited);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.FunctionOp.FunctionType;
Expand Down Expand Up @@ -141,6 +142,10 @@ public InterProceduralAnalysis(DMLProgram dmlp) {
//would require an update of the function call graph
_passes.add(new IPAPassForwardFunctionCalls());
_passes.add(new IPAPassApplyStaticAndDynamicHopRewrites());
if (DMLScript.USE_OOC) {
_passes.add(new IPAPassPruneUnreachableHops());
_passes.add(new IPAPassInjectOOCTee());
}
}

public InterProceduralAnalysis(StatementBlock sb) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
_dagRuleSet.add( new RewriteConstantFolding() ); //dependency: cse
_sbRuleSet.add( new RewriteRemoveEmptyBasicBlocks() );
_sbRuleSet.add( new RewriteRemoveEmptyForLoops() );
_sbRuleSet.add( new RewriteInjectOOCTee() );
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaData;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.ooc.stream.SourceOOCStreamable;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.LocalFileUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
Expand Down Expand Up @@ -496,7 +497,7 @@ public synchronized OOCStream<IndexedMatrixValue> getStreamHandle() {
}

public OOCStreamable<IndexedMatrixValue> getStreamable() {
return _streamHandle;
return _streamHandle == null ? new SourceOOCStreamable(this) : _streamHandle;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.Lineage;
import org.apache.sysds.runtime.lineage.LineageCache;
Expand Down Expand Up @@ -172,6 +174,8 @@ public void processInstruction(ExecutionContext ec) {

//set input parameter
functionVariables.put(currFormalParam.getName(), value);
if (DMLScript.USE_OOC && value instanceof MatrixObject)
TeeOOCInstruction.incrRef(((MatrixObject) value).getStreamable(), 1);

//map lineage to function arguments
if( lineage != null ) {
Expand Down Expand Up @@ -227,7 +231,8 @@ public void processInstruction(ExecutionContext ec) {
if( expectRetVars.contains(varName) )
continue;
//cleanup unexpected return values to avoid leaks
fn_ec.cleanupDataObject(fn_ec.removeVariable(varName));
//(including OOC reference tracking for matrix streams)
VariableCPInstruction.processRmvarInstruction(fn_ec, varName);
}

// Unpin the pinned variables
Expand All @@ -247,10 +252,12 @@ public void processInstruction(ExecutionContext ec) {

// remove existing data bound to output variable name
Data exdata = ec.removeVariable(boundVarName);
if (DMLScript.USE_OOC && exdata instanceof MatrixObject && exdata != boundValue)
TeeOOCInstruction.incrRef(((MatrixObject) exdata).getStreamable(), -1);
// save old data for cleanup later
if (exdata != boundValue && !retVars.hasReferences(exdata))
toBeCleanedUp.add(exdata);
//FIXME: interferes with reuse. Removes broadcasts before materialization
//FIXME: interferes with reuse. Removes broadcasts before materialization

//add/replace data in symbol table
ec.setVariable(boundVarName, boundValue);
Expand All @@ -276,11 +283,17 @@ public void processInstruction(ExecutionContext ec) {
//update lineage cache with the functions outputs
if ((DMLScript.LINEAGE && LineageCacheConfig.isMultiLevelReuse() && !fpb.isNondeterministic())
|| (LineageCacheConfig.isEstimator() && !fpb.isNondeterministic())) {
LineageCache.putValue(fpb.getOutputParams(), liInputs,
LineageCache.putValue(fpb.getOutputParams(), liInputs,
getCacheFunctionName(_functionName, fpb), fn_ec, t1-t0);
//FIXME: send _boundOutputNames instead of fpb.getOutputParams as
//FIXME: send _boundOutputNames instead of fpb.getOutputParams as
//those are already replaced by boundoutput names in the lineage map.
}

// cleanup declared outputs that are not bound at callsite
for (int i = numOutputs; i < fpb.getOutputParams().size(); i++) {
String retVarName = fpb.getOutputParams().get(i).getName();
VariableCPInstruction.processRmvarInstruction(fn_ec, retVarName);
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ public void processInstruction( ExecutionContext ec ) {
else {
OOCStream<MatrixBlock> qLocal = createWritableStream();

mapOOC(qIn, qLocal, tmp -> (MatrixBlock) ((MatrixBlock) tmp.getValue())
mapOOC(qIn, qLocal, tmp -> (MatrixBlock) tmp.getValue()
.aggregateUnaryOperations(aggun, new MatrixBlock(), blen, tmp.getIndexes()));

MatrixBlock ltmp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
boolean isRowBroadcast = m1.getNumRows() > 1 && m2.getNumRows() == 1;

if (isColBroadcast && !isRowBroadcast) {
final long maxProcessesPerBroadcast = m1.getNumColumns() / m1.getBlocksize();
final long maxProcessesPerBroadcast = (m1.getNumColumns() + m1.getBlocksize() - 1) / m1.getBlocksize();

broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
Expand All @@ -96,7 +96,7 @@ protected void processMatrixMatrixInstruction(ExecutionContext ec) {
}, tmp -> tmp.getIndexes().getRowIndex());
}
else if (isRowBroadcast && !isColBroadcast) {
final long maxProcessesPerBroadcast = m1.getNumRows() / m1.getBlocksize();
final long maxProcessesPerBroadcast = (m1.getNumRows() + m1.getBlocksize() - 1) / m1.getBlocksize();

broadcastJoinOOC(qIn1, qIn2, qOut, (tmp1, b) -> {
IndexedMatrixValue tmpOut = new IndexedMatrixValue();
Expand Down
Loading
Loading