From 995f770331b0522cff4a3925e813f5cc08edd80a Mon Sep 17 00:00:00 2001 From: Jessica Priebe Date: Wed, 25 Mar 2026 16:17:06 +0100 Subject: [PATCH 1/2] add pnmf --- scripts/builtin/pnmf.dml | 6 +- .../org/apache/sysds/hops/QuaternaryOp.java | 16 ++ .../instructions/OOCInstructionParser.java | 3 + .../ooc/ComputationOOCInstruction.java | 11 +- .../instructions/ooc/OOCInstruction.java | 2 +- .../ooc/QuaternaryOOCInstruction.java | 54 ++++++ .../ooc/WDivMMOOCInstruction.java | 161 ++++++++++++++++++ .../sysds/test/functions/ooc/PNMFTest.java | 18 +- src/test/scripts/functions/ooc/PNMF.dml | 6 +- 9 files changed, 264 insertions(+), 13 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java diff --git a/scripts/builtin/pnmf.dml b/scripts/builtin/pnmf.dml index 721ab7232bf..bffc3735926 100644 --- a/scripts/builtin/pnmf.dml +++ b/scripts/builtin/pnmf.dml @@ -42,12 +42,12 @@ # H List of amplitude matrices, one for each repetition. # ------------------------------------------------------------------------------------ -m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE) +m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE, Integer seed=-1) return (Matrix[Double] W, Matrix[Double] H) { #initialize W and H - W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025); - H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025); + W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025, seed=seed); + H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025, seed=seed); i = 0; while(i < maxi) { diff --git a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java index c2be949f377..8fede5f0908 100644 --- a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java +++ b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java @@ -211,6 +211,8 @@ else if( et == ExecType.SPARK ) constructCPLopsWeightedDivMM(wtype); else if( et == ExecType.SPARK ) constructSparkLopsWeightedDivMM(wtype); + else if( et == ExecType.OOC ) + constructOOCLopsWeightedDivMM(wtype); else throw new HopsException("Unsupported quaternaryop-wdivmm exec type: "+et); break; @@ -462,6 +464,20 @@ private void constructSparkLopsWeightedDivMM( WDivMMType wtype ) } } + private void constructOOCLopsWeightedDivMM(WDivMMType wtype) + { + WeightedDivMM wdiv = new WeightedDivMM( + getInput().get(0).constructLops(), + getInput().get(1).constructLops(), + getInput().get(2).constructLops(), + getInput().get(3).constructLops(), + getDataType(), getValueType(), wtype, ExecType.OOC); + + setOutputDimensions(wdiv); + setLineNumbers(wdiv); + setLops(wdiv); + } + private void constructCPLopsWeightedCeMM(WCeMMType wtype) { WeightedCrossEntropy wcemm = new WeightedCrossEntropy( diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index affda5910d6..ae41639687b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -43,6 +43,7 @@ import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -111,6 +112,8 @@ else if(parts.length == 4) return DataGenOOCInstruction.parseInstruction(str); case Append: return AppendOOCInstruction.parseInstruction(str); + case Quaternary: + return QuaternaryOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java index 4dcdffcb0dc..d6686c11560 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java @@ -24,7 +24,7 @@ public abstract class ComputationOOCInstruction extends OOCInstruction { public CPOperand output; - public CPOperand input1, input2, input3; + public CPOperand input1, input2, input3, input4; protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, String opcode, String istr) { super(type, op, opcode, istr); @@ -50,6 +50,15 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP output = out; } + protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr) { + super(type, op, opcode, istr); + input1 = in1; + input2 = in2; + input3 = in3; + input4 = in4; + output = out; + } + public String getOutputVariableName() { return output.getName(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index be9728d87b9..679e7187e5e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction { public enum OOCType { Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ, - MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append + MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append, Quaternary } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java new file mode 100644 index 00000000000..8df1e33c59a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.lops.WeightedDivMM.WDivMMType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator; + +public abstract class QuaternaryOOCInstruction extends ComputationOOCInstruction { + + protected QuaternaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, + CPOperand out, String opcode, String istr) { + super(OOCType.Quaternary, op, in1, in2, in3, in4, out, opcode, istr); + } + + public static QuaternaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + + if(opcode.contains(Opcodes.WEIGHTEDDIVMM.toString())) { + InstructionUtils.checkNumFields(parts, 6); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand out = new CPOperand(parts[5]); + QuaternaryOperator qop = new QuaternaryOperator(WDivMMType.valueOf(parts[6])); + return new WDivMMOOCInstruction(qop, in1, in2, in3, in4, out, opcode, str); + } + throw new DMLRuntimeException("Not implemented yet opcode " + opcode); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java new file mode 100644 index 00000000000..52f288a56f8 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.lops.WeightedDivMM.WDivMMType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +import java.util.function.Function; + + +public class WDivMMOOCInstruction extends QuaternaryOOCInstruction +{ + + protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, + CPOperand out, String opcode, String istr) { + super(op, in1, in2, in3, in4, out, opcode, istr); + } + + public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction instr) { + String instrStr = instr.getInstructionString(); + String opcode = InstructionUtils.getInstructionPartsWithValueType(instr.getInstructionString())[0]; + return new WDivMMOOCInstruction((QuaternaryOperator) instr.getOperator(), instr.input1, instr.input2, + instr.input3, instr.input4, instr.output, opcode, instrStr); + } + + + @Override + public void processInstruction(ExecutionContext ec) { + QuaternaryOperator _qop = ((QuaternaryOperator)_optr); + final WDivMMType wt = _qop.wtype3; + + if(!(wt.hasFourInputs()&&wt.hasScalar()) || wt.isBasic() || wt.isMult() || wt.isMinus()) throw new DMLRuntimeException("Not implemented: only pnmf supported yet"); + + CachingStream X = new CachingStream(ec.getMatrixObject(input1).getStreamHandle()); + CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle()); + CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle()); + + double eps = 0.0; + if(_qop.hasFourInputs()) { + if (input4.getDataType() == DataType.SCALAR) + eps = ec.getScalarInput(input4).getDoubleValue(); + } + + OOCStream mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(), V.getDataCharacteristics(), false, true); + OOCStream plus = elemPlusOOC(mmt, eps); + OOCStream inter = elemDivOOC(X.getReadStream(), plus); + OOCStream out; + + if(wt.isLeft()) + out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(), true, false); + else + out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(), false, false); + + ec.getMatrixObject(output).setStreamHandle(out); + } + + private OOCStream matMultOOC(OOCStream m1, OOCStream m2, DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose){ + + int emitLeftThreshold = rightTranspose? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks(); + int emitRightThreshold = leftTranspose? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks(); + + OOCStream intermediateStream = createWritableStream(); + OOCStream out = createWritableStream(); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + + joinManyOOC(m1, m2, intermediateStream, + (left, right) -> { + MatrixBlock leftBlock = (MatrixBlock) left.getValue(); + MatrixBlock rightBlock = (MatrixBlock) right.getValue(); + if(leftTranspose) leftBlock = leftBlock.transpose(); + if(rightTranspose) rightBlock = rightBlock.transpose(); + + MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, + new MatrixBlock(), op); + int lidx = (int) (leftTranspose? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex()); + int ridx = (int) (rightTranspose? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex()); + return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult); + }, + tmp -> leftTranspose? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(), + tmp -> rightTranspose? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(), + emitLeftThreshold, emitRightThreshold); + + BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + int emitAggThreshold = leftTranspose? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks(); + + groupedReduceOOC(intermediateStream, out, (left, right) -> { + MatrixBlock mb = ((MatrixBlock)left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + left.setValue(mb); + return left; + }, emitAggThreshold); + + return out; + } + + private OOCStream elemDivOOC(OOCStream m1, OOCStream m2){ + SubscribableTaskQueue out = new SubscribableTaskQueue<>(); + BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString()); + Function key = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex()); + + joinOOC(m1, m2, out, (left, right) -> { + MatrixBlock lb = (MatrixBlock) left.getValue(); + MatrixBlock rb = (MatrixBlock) right.getValue(); + MatrixBlock combined = lb.binaryOperations(div, rb); + return new IndexedMatrixValue( + new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined); + }, key); + + return out; + } + + private OOCStream elemPlusOOC(OOCStream m1, double eps){ + SubscribableTaskQueue out = new SubscribableTaskQueue<>(); + mapOOC(m1, out, blk -> new IndexedMatrixValue( + new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), plusDouble((MatrixBlock) blk.getValue(), eps))); + return out; + } + + private MatrixBlock plusDouble(MatrixBlock blk, double eps){ + for(int i=0; i Date: Tue, 5 May 2026 12:53:20 +0200 Subject: [PATCH 2/2] extend to wdivmm --- .../ooc/WDivMMOOCInstruction.java | 163 ++++++++++++------ .../sysds/test/functions/ooc/WDivMMTest.java | 156 +++++++++++++++++ .../ooc/WeightedDivMM4MultMinusLeft.dml | 32 ++++ .../ooc/WeightedDivMM4MultMinusRight.dml | 32 ++++ .../functions/ooc/WeightedDivMMLeft.dml | 30 ++++ .../functions/ooc/WeightedDivMMLeftEps.dml | 32 ++++ .../functions/ooc/WeightedDivMMMultBasic.dml | 30 ++++ .../functions/ooc/WeightedDivMMMultLeft.dml | 30 ++++ .../ooc/WeightedDivMMMultMinusLeft.dml | 30 ++++ .../ooc/WeightedDivMMMultMinusRight.dml | 30 ++++ .../functions/ooc/WeightedDivMMMultRight.dml | 30 ++++ .../functions/ooc/WeightedDivMMRight.dml | 30 ++++ .../functions/ooc/WeightedDivMMRightEps.dml | 32 ++++ 13 files changed, 604 insertions(+), 53 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java create mode 100644 src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMLeft.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMRight.dml create mode 100644 src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java index 52f288a56f8..ec9a7bcd4fe 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java @@ -19,11 +19,8 @@ package org.apache.sysds.runtime.instructions.ooc; - import org.apache.sysds.common.Opcodes; -import org.apache.sysds.common.Types.DataType; import org.apache.sysds.lops.WeightedDivMM.WDivMMType; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.functionobjects.Multiply; import org.apache.sysds.runtime.functionobjects.Plus; @@ -35,14 +32,13 @@ import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator; import org.apache.sysds.runtime.meta.DataCharacteristics; import java.util.function.Function; - -public class WDivMMOOCInstruction extends QuaternaryOOCInstruction -{ +public class WDivMMOOCInstruction extends QuaternaryOOCInstruction { protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr) { @@ -56,41 +52,68 @@ public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction ins instr.input3, instr.input4, instr.output, opcode, instrStr); } - @Override public void processInstruction(ExecutionContext ec) { - QuaternaryOperator _qop = ((QuaternaryOperator)_optr); - final WDivMMType wt = _qop.wtype3; - - if(!(wt.hasFourInputs()&&wt.hasScalar()) || wt.isBasic() || wt.isMult() || wt.isMinus()) throw new DMLRuntimeException("Not implemented: only pnmf supported yet"); + QuaternaryOperator qop = ((QuaternaryOperator) _optr); + final WDivMMType wt = qop.wtype3; CachingStream X = new CachingStream(ec.getMatrixObject(input1).getStreamHandle()); CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle()); CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle()); - double eps = 0.0; - if(_qop.hasFourInputs()) { - if (input4.getDataType() == DataType.SCALAR) - eps = ec.getScalarInput(input4).getDoubleValue(); - } + boolean basic = wt.isBasic(); + boolean left = wt.isLeft(); + boolean mult = wt.isMult(); + boolean minus = wt.isMinus(); + boolean four = wt.hasFourInputs(); + boolean scalar = wt.hasScalar(); - OOCStream mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(), V.getDataCharacteristics(), false, true); - OOCStream plus = elemPlusOOC(mmt, eps); - OOCStream inter = elemDivOOC(X.getReadStream(), plus); + OOCStream mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(), + V.getDataCharacteristics(), false, true); + OOCStream inter; OOCStream out; - if(wt.isLeft()) - out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(), true, false); + if(basic) { + out = elemMultOOC(X.getReadStream(), mmt); + ec.getMatrixObject(output).setStreamHandle(out); + return; + } + else if(four) { + if(scalar) { + double eps = ec.getScalarInput(input4).getDoubleValue(); + inter = elemDivOOC(X.getReadStream(), elemPlusOOC(mmt, eps)); + } + else { + CachingStream W = new CachingStream(ec.getMatrixObject(input4).getStreamHandle()); + inter = elemMultOOC(X.getReadStream(), elemMinusOOC(mmt, W.getReadStream())); + } + } + else { + if(minus) + inter = maskOOC(X.getReadStream(), elemMinusOOC(mmt, X.getReadStream())); + else { + if(mult) + inter = elemMultOOC(X.getReadStream(), mmt); + else + inter = elemDivOOC(X.getReadStream(), mmt); + } + } + + if(left) + out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(), + true, false); else - out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(), false, false); + out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(), + false, false); ec.getMatrixObject(output).setStreamHandle(out); } - private OOCStream matMultOOC(OOCStream m1, OOCStream m2, DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose){ + private OOCStream matMultOOC(OOCStream m1, OOCStream m2, + DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose) { - int emitLeftThreshold = rightTranspose? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks(); - int emitRightThreshold = leftTranspose? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks(); + int emitLeftThreshold = rightTranspose ? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks(); + int emitRightThreshold = leftTranspose ? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks(); OOCStream intermediateStream = createWritableStream(); OOCStream out = createWritableStream(); @@ -98,28 +121,27 @@ private OOCStream matMultOOC(OOCStream m AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); - joinManyOOC(m1, m2, intermediateStream, - (left, right) -> { - MatrixBlock leftBlock = (MatrixBlock) left.getValue(); - MatrixBlock rightBlock = (MatrixBlock) right.getValue(); - if(leftTranspose) leftBlock = leftBlock.transpose(); - if(rightTranspose) rightBlock = rightBlock.transpose(); - - MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, - new MatrixBlock(), op); - int lidx = (int) (leftTranspose? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex()); - int ridx = (int) (rightTranspose? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex()); - return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult); - }, - tmp -> leftTranspose? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(), - tmp -> rightTranspose? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(), + joinManyOOC(m1, m2, intermediateStream, (left, right) -> { + MatrixBlock leftBlock = (MatrixBlock) left.getValue(); + MatrixBlock rightBlock = (MatrixBlock) right.getValue(); + if(leftTranspose) + leftBlock = leftBlock.transpose(); + if(rightTranspose) + rightBlock = rightBlock.transpose(); + + MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, new MatrixBlock(), op); + int lidx = (int) (leftTranspose ? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex()); + int ridx = (int) (rightTranspose ? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex()); + return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult); + }, tmp -> leftTranspose ? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(), + tmp -> rightTranspose ? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(), emitLeftThreshold, emitRightThreshold); BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); - int emitAggThreshold = leftTranspose? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks(); + int emitAggThreshold = leftTranspose ? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks(); groupedReduceOOC(intermediateStream, out, (left, right) -> { - MatrixBlock mb = ((MatrixBlock)left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); left.setValue(mb); return left; }, emitAggThreshold); @@ -127,15 +149,15 @@ private OOCStream matMultOOC(OOCStream m return out; } - private OOCStream elemDivOOC(OOCStream m1, OOCStream m2){ + private OOCStream elemOOC(OOCStream m1, OOCStream m2, BinaryOperator bop) { SubscribableTaskQueue out = new SubscribableTaskQueue<>(); - BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString()); - Function key = imv -> new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex()); + Function key = imv -> + new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex()); joinOOC(m1, m2, out, (left, right) -> { MatrixBlock lb = (MatrixBlock) left.getValue(); MatrixBlock rb = (MatrixBlock) right.getValue(); - MatrixBlock combined = lb.binaryOperations(div, rb); + MatrixBlock combined = lb.binaryOperations(bop, rb); return new IndexedMatrixValue( new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined); }, key); @@ -143,17 +165,52 @@ private OOCStream elemDivOOC(OOCStream m return out; } - private OOCStream elemPlusOOC(OOCStream m1, double eps){ + private OOCStream elemDivOOC(OOCStream m1, OOCStream m2) { + BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString()); + return elemOOC(m1, m2, div); + } + + private OOCStream elemMultOOC(OOCStream m1, OOCStream m2) { + BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString()); + return elemOOC(m1, m2, div); + } + + private OOCStream elemMinusOOC(OOCStream m1, OOCStream m2) { + BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MINUS.toString()); + return elemOOC(m1, m2, div); + } + + private OOCStream elemPlusOOC(OOCStream m1, double eps) { + SubscribableTaskQueue out = new SubscribableTaskQueue<>(); + mapOOC(m1, out, blk -> { + MatrixBlock res = ((MatrixBlock) blk.getValue()) + .scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), eps), null); + return new IndexedMatrixValue( + new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), res); + }); + return out; + } + + private OOCStream maskOOC(OOCStream mask, OOCStream m1) { SubscribableTaskQueue out = new SubscribableTaskQueue<>(); - mapOOC(m1, out, blk -> new IndexedMatrixValue( - new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), plusDouble((MatrixBlock) blk.getValue(), eps))); + Function key = imv -> + new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex()); + + joinOOC(mask, m1, out, (left, right) -> { + MatrixBlock lb = (MatrixBlock) left.getValue(); + MatrixBlock rb = (MatrixBlock) right.getValue(); + MatrixBlock combined = mask(lb, rb); + return new IndexedMatrixValue( + new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined); + }, key); + return out; } - private MatrixBlock plusDouble(MatrixBlock blk, double eps){ - for(int i=0; i data() { + return Arrays.asList(new Object[][] {{TEST_NAME_1}, {TEST_NAME_2}, {TEST_NAME_3}, {TEST_NAME_4}, {TEST_NAME_5}, + {TEST_NAME_6}, {TEST_NAME_7}, {TEST_NAME_8}, {TEST_NAME_9}, {TEST_NAME_10}, {TEST_NAME_11}}); + } + + @Before + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {OUTPUT_NAME})); + } + + @Test + public void testWeightedDivMM() { + runWeightedDivMMTest(TEST_NAME); + } + + private void runWeightedDivMMTest(String TEST_NAME) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + boolean basic = TEST_NAME.equals(TEST_NAME_3); + boolean left = TEST_NAME.equals(TEST_NAME_1) || TEST_NAME.equals(TEST_NAME_4) || + TEST_NAME.equals(TEST_NAME_6) || TEST_NAME.equals(TEST_NAME_8) || TEST_NAME.equals(TEST_NAME_10); + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + double[][] W = getRandomMatrix(rows, cols, 0, 1, 0.7, 7); + double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 713); + double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 812); + + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(W), input(INPUT_NAME_1), rows, + cols, blen, rows * cols); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(U), input(INPUT_NAME_2), rows, + rank, blen, rows * rank); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(V), input(INPUT_NAME_3), cols, + rank, blen, cols * rank); + + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, blen, rows * cols), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, rank, blen, rows * rank), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_3 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(cols, rank, blen, cols * rank), Types.FileFormat.BINARY); + + programArgs = new String[] {"-ooc", "-stats", "-explain", "runtime", "-args", input(INPUT_NAME_1), + input(INPUT_NAME_2), input(INPUT_NAME_3), output(OUTPUT_NAME), Double.toString(div_eps)}; + + runTest(true, false, null, -1); + + Assert.assertTrue("OOC wasn't used for wdivmm", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.WEIGHTEDDIVMM)); + + programArgs = new String[] {"-stats", "-explain", "runtime", "-args", input(INPUT_NAME_1), + input(INPUT_NAME_2), input(INPUT_NAME_3), output(OUTPUT_NAME + "_target"), Double.toString(div_eps)}; + + runTest(true, false, null, -1); + + int rows2 = left ? cols : rows; + int cols2 = basic ? cols : rank; + checkDMLMetaDataFile("R", new MatrixCharacteristics(rows2, cols2)); + + MatrixBlock actual = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows2, cols2, blen); + MatrixBlock expected = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows2, cols2, blen); + TestUtils.compareMatrices(expected, actual, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml new file mode 100644 index 00000000000..42bd4c96a0c --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +X = W/0.7; +while(FALSE){} +R = t(t(U) %*% (W*(U%*%t(V)-X))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml new file mode 100644 index 00000000000..7b393f12310 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +X = W/0.3 +while(FALSE){} +R = (W*(U%*%t(V)-X)) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml new file mode 100644 index 00000000000..48639a176a7 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = t(t(U) %*% (W/(U%*%t(V)))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml new file mode 100644 index 00000000000..dc07670feab --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +x = $5; + +R = t(t(U) %*% (W/(U%*%t(V) + x))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml new file mode 100644 index 00000000000..144e59a773a --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = W*(U%*%t(V)); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml new file mode 100644 index 00000000000..93bc765617f --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = t(t(U) %*% (W*(U%*%t(V)))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml new file mode 100644 index 00000000000..84ac35ad899 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = t(t(U) %*% ((W != 0)*(U%*%t(V)-W))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml new file mode 100644 index 00000000000..59caa4d17b4 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = ((W != 0)*(U%*%t(V)-W)) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml new file mode 100644 index 00000000000..fbb1224d173 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = (W*(U%*%t(V))) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMRight.dml b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml new file mode 100644 index 00000000000..e878a81d14d --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = (W/(U%*%t(V))) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml new file mode 100644 index 00000000000..9ecbaf56630 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +x = $5; + +R = (W/(U%*%t(V) + x)) %*% V; + +write(R, $4, format="binary");