Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions scripts/builtin/pnmf.dml
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@
# H List of amplitude matrices, one for each repetition.
# ------------------------------------------------------------------------------------

m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE)
m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE, Integer seed=-1)
return (Matrix[Double] W, Matrix[Double] H)
{
#initialize W and H
W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025);
H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025);
W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025, seed=seed);
H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025, seed=seed);

i = 0;
while(i < maxi) {
Expand Down
16 changes: 16 additions & 0 deletions src/main/java/org/apache/sysds/hops/QuaternaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ else if( et == ExecType.SPARK )
constructCPLopsWeightedDivMM(wtype);
else if( et == ExecType.SPARK )
constructSparkLopsWeightedDivMM(wtype);
else if( et == ExecType.OOC )
constructOOCLopsWeightedDivMM(wtype);
else
throw new HopsException("Unsupported quaternaryop-wdivmm exec type: "+et);
break;
Expand Down Expand Up @@ -462,6 +464,20 @@ private void constructSparkLopsWeightedDivMM( WDivMMType wtype )
}
}

private void constructOOCLopsWeightedDivMM(WDivMMType wtype)
{
WeightedDivMM wdiv = new WeightedDivMM(
getInput().get(0).constructLops(),
getInput().get(1).constructLops(),
getInput().get(2).constructLops(),
getInput().get(3).constructLops(),
getDataType(), getValueType(), wtype, ExecType.OOC);

setOutputDimensions(wdiv);
setLineNumbers(wdiv);
setLops(wdiv);
}

private void constructCPLopsWeightedCeMM(WCeMMType wtype)
{
WeightedCrossEntropy wcemm = new WeightedCrossEntropy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction;
import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction;

public class OOCInstructionParser extends InstructionParser {
protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName());
Expand Down Expand Up @@ -111,6 +112,8 @@ else if(parts.length == 4)
return DataGenOOCInstruction.parseInstruction(str);
case Append:
return AppendOOCInstruction.parseInstruction(str);
case Quaternary:
return QuaternaryOOCInstruction.parseInstruction(str);

default:
throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

public abstract class ComputationOOCInstruction extends OOCInstruction {
public CPOperand output;
public CPOperand input1, input2, input3;
public CPOperand input1, input2, input3, input4;

protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, String opcode, String istr) {
super(type, op, opcode, istr);
Expand All @@ -50,6 +50,15 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP
output = out;
}

protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr) {
super(type, op, opcode, istr);
input1 = in1;
input2 = in2;
input3 = in3;
input4 = in4;
output = out;
}

public String getOutputVariableName() {
return output.getName();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction {

public enum OOCType {
Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ,
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append
MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append, Quaternary
}

protected final OOCInstruction.OOCType _ooctype;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;


import org.apache.sysds.common.Opcodes;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;

public abstract class QuaternaryOOCInstruction extends ComputationOOCInstruction {

protected QuaternaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
CPOperand out, String opcode, String istr) {
super(OOCType.Quaternary, op, in1, in2, in3, in4, out, opcode, istr);
}

public static QuaternaryOOCInstruction parseInstruction(String str) {
String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];

if(opcode.contains(Opcodes.WEIGHTEDDIVMM.toString())) {
InstructionUtils.checkNumFields(parts, 6);
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand in3 = new CPOperand(parts[3]);
CPOperand in4 = new CPOperand(parts[4]);
CPOperand out = new CPOperand(parts[5]);
QuaternaryOperator qop = new QuaternaryOperator(WDivMMType.valueOf(parts[6]));
return new WDivMMOOCInstruction(qop, in1, in2, in3, in4, out, opcode, str);
}
throw new DMLRuntimeException("Not implemented yet opcode " + opcode);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.sysds.runtime.instructions.ooc;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;

import java.util.function.Function;

public class WDivMMOOCInstruction extends QuaternaryOOCInstruction {

protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4,
CPOperand out, String opcode, String istr) {
super(op, in1, in2, in3, in4, out, opcode, istr);
}

public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction instr) {
String instrStr = instr.getInstructionString();
String opcode = InstructionUtils.getInstructionPartsWithValueType(instr.getInstructionString())[0];
return new WDivMMOOCInstruction((QuaternaryOperator) instr.getOperator(), instr.input1, instr.input2,
instr.input3, instr.input4, instr.output, opcode, instrStr);
}

@Override
public void processInstruction(ExecutionContext ec) {
QuaternaryOperator qop = ((QuaternaryOperator) _optr);
final WDivMMType wt = qop.wtype3;

CachingStream X = new CachingStream(ec.getMatrixObject(input1).getStreamHandle());
CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle());
CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle());

boolean basic = wt.isBasic();
boolean left = wt.isLeft();
boolean mult = wt.isMult();
boolean minus = wt.isMinus();
boolean four = wt.hasFourInputs();
boolean scalar = wt.hasScalar();

OOCStream<IndexedMatrixValue> mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(),
V.getDataCharacteristics(), false, true);
OOCStream<IndexedMatrixValue> inter;
OOCStream<IndexedMatrixValue> out;

if(basic) {
out = elemMultOOC(X.getReadStream(), mmt);
ec.getMatrixObject(output).setStreamHandle(out);
return;
}
else if(four) {
if(scalar) {
double eps = ec.getScalarInput(input4).getDoubleValue();
inter = elemDivOOC(X.getReadStream(), elemPlusOOC(mmt, eps));
}
else {
CachingStream W = new CachingStream(ec.getMatrixObject(input4).getStreamHandle());
inter = elemMultOOC(X.getReadStream(), elemMinusOOC(mmt, W.getReadStream()));
}
}
else {
if(minus)
inter = maskOOC(X.getReadStream(), elemMinusOOC(mmt, X.getReadStream()));
else {
if(mult)
inter = elemMultOOC(X.getReadStream(), mmt);
else
inter = elemDivOOC(X.getReadStream(), mmt);
}
}

if(left)
out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(),
true, false);
else
out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(),
false, false);

ec.getMatrixObject(output).setStreamHandle(out);
}

private OOCStream<IndexedMatrixValue> matMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2,
DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose) {

int emitLeftThreshold = rightTranspose ? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks();
int emitRightThreshold = leftTranspose ? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks();

OOCStream<IndexedMatrixValue> intermediateStream = createWritableStream();
OOCStream<IndexedMatrixValue> out = createWritableStream();

AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject());
AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg);

joinManyOOC(m1, m2, intermediateStream, (left, right) -> {
MatrixBlock leftBlock = (MatrixBlock) left.getValue();
MatrixBlock rightBlock = (MatrixBlock) right.getValue();
if(leftTranspose)
leftBlock = leftBlock.transpose();
if(rightTranspose)
rightBlock = rightBlock.transpose();

MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, new MatrixBlock(), op);
int lidx = (int) (leftTranspose ? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex());
int ridx = (int) (rightTranspose ? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex());
return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult);
}, tmp -> leftTranspose ? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(),
tmp -> rightTranspose ? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(),
emitLeftThreshold, emitRightThreshold);

BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString());
int emitAggThreshold = leftTranspose ? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks();

groupedReduceOOC(intermediateStream, out, (left, right) -> {
MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue());
left.setValue(mb);
return left;
}, emitAggThreshold);

return out;
}

private OOCStream<IndexedMatrixValue> elemOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2, BinaryOperator bop) {
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());

joinOOC(m1, m2, out, (left, right) -> {
MatrixBlock lb = (MatrixBlock) left.getValue();
MatrixBlock rb = (MatrixBlock) right.getValue();
MatrixBlock combined = lb.binaryOperations(bop, rb);
return new IndexedMatrixValue(
new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
}, key);

return out;
}

private OOCStream<IndexedMatrixValue> elemDivOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString());
return elemOOC(m1, m2, div);
}

private OOCStream<IndexedMatrixValue> elemMultOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString());
return elemOOC(m1, m2, div);
}

private OOCStream<IndexedMatrixValue> elemMinusOOC(OOCStream<IndexedMatrixValue> m1, OOCStream<IndexedMatrixValue> m2) {
BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MINUS.toString());
return elemOOC(m1, m2, div);
}

private OOCStream<IndexedMatrixValue> elemPlusOOC(OOCStream<IndexedMatrixValue> m1, double eps) {
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
mapOOC(m1, out, blk -> {
MatrixBlock res = ((MatrixBlock) blk.getValue())
.scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), eps), null);
return new IndexedMatrixValue(
new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), res);
});
return out;
}

private OOCStream<IndexedMatrixValue> maskOOC(OOCStream<IndexedMatrixValue> mask, OOCStream<IndexedMatrixValue> m1) {
SubscribableTaskQueue<IndexedMatrixValue> out = new SubscribableTaskQueue<>();
Function<IndexedMatrixValue, MatrixIndexes> key = imv ->
new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex());

joinOOC(mask, m1, out, (left, right) -> {
MatrixBlock lb = (MatrixBlock) left.getValue();
MatrixBlock rb = (MatrixBlock) right.getValue();
MatrixBlock combined = mask(lb, rb);
return new IndexedMatrixValue(
new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined);
}, key);

return out;
}

private MatrixBlock mask(MatrixBlock mask, MatrixBlock blk) {
for(int i = 0; i < blk.getNumRows(); i++) {
for(int j = 0; j < blk.getNumColumns(); j++) {
if(mask.get(i,j) ==0) blk.set(i, j, 0);
}
}
return blk;
}
}
Loading
Loading