diff --git a/scripts/builtin/pnmf.dml b/scripts/builtin/pnmf.dml index 721ab7232bf..bffc3735926 100644 --- a/scripts/builtin/pnmf.dml +++ b/scripts/builtin/pnmf.dml @@ -42,12 +42,12 @@ # H List of amplitude matrices, one for each repetition. # ------------------------------------------------------------------------------------ -m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE) +m_pnmf = function(Matrix[Double] X, Integer rnk, Double eps = 1e-8, Integer maxi = 10, Boolean verbose=TRUE, Integer seed=-1) return (Matrix[Double] W, Matrix[Double] H) { #initialize W and H - W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025); - H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025); + W = rand(rows=nrow(X), cols=rnk, min=0, max=0.025, seed=seed); + H = rand(rows=rnk, cols=ncol(X), min=0, max=0.025, seed=seed); i = 0; while(i < maxi) { diff --git a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java index c2be949f377..8fede5f0908 100644 --- a/src/main/java/org/apache/sysds/hops/QuaternaryOp.java +++ b/src/main/java/org/apache/sysds/hops/QuaternaryOp.java @@ -211,6 +211,8 @@ else if( et == ExecType.SPARK ) constructCPLopsWeightedDivMM(wtype); else if( et == ExecType.SPARK ) constructSparkLopsWeightedDivMM(wtype); + else if( et == ExecType.OOC ) + constructOOCLopsWeightedDivMM(wtype); else throw new HopsException("Unsupported quaternaryop-wdivmm exec type: "+et); break; @@ -462,6 +464,20 @@ private void constructSparkLopsWeightedDivMM( WDivMMType wtype ) } } + private void constructOOCLopsWeightedDivMM(WDivMMType wtype) + { + WeightedDivMM wdiv = new WeightedDivMM( + getInput().get(0).constructLops(), + getInput().get(1).constructLops(), + getInput().get(2).constructLops(), + getInput().get(3).constructLops(), + getDataType(), getValueType(), wtype, ExecType.OOC); + + setOutputDimensions(wdiv); + setLineNumbers(wdiv); + setLops(wdiv); + } + private void constructCPLopsWeightedCeMM(WCeMMType wtype) { WeightedCrossEntropy wcemm = new WeightedCrossEntropy( diff --git a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java index affda5910d6..ae41639687b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/OOCInstructionParser.java @@ -43,6 +43,7 @@ import org.apache.sysds.runtime.instructions.ooc.ReorgOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; import org.apache.sysds.runtime.instructions.ooc.AppendOOCInstruction; +import org.apache.sysds.runtime.instructions.ooc.QuaternaryOOCInstruction; public class OOCInstructionParser extends InstructionParser { protected static final Log LOG = LogFactory.getLog(OOCInstructionParser.class.getName()); @@ -111,6 +112,8 @@ else if(parts.length == 4) return DataGenOOCInstruction.parseInstruction(str); case Append: return AppendOOCInstruction.parseInstruction(str); + case Quaternary: + return QuaternaryOOCInstruction.parseInstruction(str); default: throw new DMLRuntimeException("Invalid OOC Instruction Type: " + ooctype); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java index 4dcdffcb0dc..d6686c11560 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/ComputationOOCInstruction.java @@ -24,7 +24,7 @@ public abstract class ComputationOOCInstruction extends OOCInstruction { public CPOperand output; - public CPOperand input1, input2, input3; + public CPOperand input1, input2, input3, input4; protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand out, String opcode, String istr) { super(type, op, opcode, istr); @@ -50,6 +50,15 @@ protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CP output = out; } + protected ComputationOOCInstruction(OOCType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr) { + super(type, op, opcode, istr); + input1 = in1; + input2 = in2; + input3 = in3; + input4 = in4; + output = out; + } + public String getOutputVariableName() { return output.getName(); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java index be9728d87b9..679e7187e5e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/OOCInstruction.java @@ -80,7 +80,7 @@ public abstract class OOCInstruction extends Instruction { public enum OOCType { Reblock, Tee, Binary, Ternary, Unary, AggregateUnary, AggregateBinary, AggregateTernary, MAPMM, MMTSJ, - MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append + MAPMMCHAIN, Reorg, CM, Ctable, MatrixIndexing, ParameterizedBuiltin, Rand, Append, Quaternary } protected final OOCInstruction.OOCType _ooctype; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java new file mode 100644 index 00000000000..8df1e33c59a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/QuaternaryOOCInstruction.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.lops.WeightedDivMM.WDivMMType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.matrix.operators.Operator; +import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator; + +public abstract class QuaternaryOOCInstruction extends ComputationOOCInstruction { + + protected QuaternaryOOCInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, + CPOperand out, String opcode, String istr) { + super(OOCType.Quaternary, op, in1, in2, in3, in4, out, opcode, istr); + } + + public static QuaternaryOOCInstruction parseInstruction(String str) { + String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); + String opcode = parts[0]; + + if(opcode.contains(Opcodes.WEIGHTEDDIVMM.toString())) { + InstructionUtils.checkNumFields(parts, 6); + CPOperand in1 = new CPOperand(parts[1]); + CPOperand in2 = new CPOperand(parts[2]); + CPOperand in3 = new CPOperand(parts[3]); + CPOperand in4 = new CPOperand(parts[4]); + CPOperand out = new CPOperand(parts[5]); + QuaternaryOperator qop = new QuaternaryOperator(WDivMMType.valueOf(parts[6])); + return new WDivMMOOCInstruction(qop, in1, in2, in3, in4, out, opcode, str); + } + throw new DMLRuntimeException("Not implemented yet opcode " + opcode); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java new file mode 100644 index 00000000000..ec9a7bcd4fe --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/ooc/WDivMMOOCInstruction.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.ooc; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.lops.WeightedDivMM.WDivMMType; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.functionobjects.Plus; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.RightScalarOperator; +import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +import java.util.function.Function; + +public class WDivMMOOCInstruction extends QuaternaryOOCInstruction { + + protected WDivMMOOCInstruction(QuaternaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, + CPOperand out, String opcode, String istr) { + super(op, in1, in2, in3, in4, out, opcode, istr); + } + + public static WDivMMOOCInstruction parseInstruction(QuaternaryOOCInstruction instr) { + String instrStr = instr.getInstructionString(); + String opcode = InstructionUtils.getInstructionPartsWithValueType(instr.getInstructionString())[0]; + return new WDivMMOOCInstruction((QuaternaryOperator) instr.getOperator(), instr.input1, instr.input2, + instr.input3, instr.input4, instr.output, opcode, instrStr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + QuaternaryOperator qop = ((QuaternaryOperator) _optr); + final WDivMMType wt = qop.wtype3; + + CachingStream X = new CachingStream(ec.getMatrixObject(input1).getStreamHandle()); + CachingStream U = new CachingStream(ec.getMatrixObject(input2).getStreamHandle()); + CachingStream V = new CachingStream(ec.getMatrixObject(input3).getStreamHandle()); + + boolean basic = wt.isBasic(); + boolean left = wt.isLeft(); + boolean mult = wt.isMult(); + boolean minus = wt.isMinus(); + boolean four = wt.hasFourInputs(); + boolean scalar = wt.hasScalar(); + + OOCStream mmt = matMultOOC(U.getReadStream(), V.getReadStream(), U.getDataCharacteristics(), + V.getDataCharacteristics(), false, true); + OOCStream inter; + OOCStream out; + + if(basic) { + out = elemMultOOC(X.getReadStream(), mmt); + ec.getMatrixObject(output).setStreamHandle(out); + return; + } + else if(four) { + if(scalar) { + double eps = ec.getScalarInput(input4).getDoubleValue(); + inter = elemDivOOC(X.getReadStream(), elemPlusOOC(mmt, eps)); + } + else { + CachingStream W = new CachingStream(ec.getMatrixObject(input4).getStreamHandle()); + inter = elemMultOOC(X.getReadStream(), elemMinusOOC(mmt, W.getReadStream())); + } + } + else { + if(minus) + inter = maskOOC(X.getReadStream(), elemMinusOOC(mmt, X.getReadStream())); + else { + if(mult) + inter = elemMultOOC(X.getReadStream(), mmt); + else + inter = elemDivOOC(X.getReadStream(), mmt); + } + } + + if(left) + out = matMultOOC(inter, U.getReadStream(), X.getDataCharacteristics(), U.getDataCharacteristics(), + true, false); + else + out = matMultOOC(inter, V.getReadStream(), X.getDataCharacteristics(), V.getDataCharacteristics(), + false, false); + + ec.getMatrixObject(output).setStreamHandle(out); + } + + private OOCStream matMultOOC(OOCStream m1, OOCStream m2, + DataCharacteristics dc1, DataCharacteristics dc2, boolean leftTranspose, boolean rightTranspose) { + + int emitLeftThreshold = rightTranspose ? (int) dc2.getNumRowBlocks() : (int) dc2.getNumColBlocks(); + int emitRightThreshold = leftTranspose ? (int) dc1.getNumColBlocks() : (int) dc1.getNumRowBlocks(); + + OOCStream intermediateStream = createWritableStream(); + OOCStream out = createWritableStream(); + + AggregateOperator agg = new AggregateOperator(0, Plus.getPlusFnObject()); + AggregateBinaryOperator op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(), agg); + + joinManyOOC(m1, m2, intermediateStream, (left, right) -> { + MatrixBlock leftBlock = (MatrixBlock) left.getValue(); + MatrixBlock rightBlock = (MatrixBlock) right.getValue(); + if(leftTranspose) + leftBlock = leftBlock.transpose(); + if(rightTranspose) + rightBlock = rightBlock.transpose(); + + MatrixBlock partialResult = leftBlock.aggregateBinaryOperations(leftBlock, rightBlock, new MatrixBlock(), op); + int lidx = (int) (leftTranspose ? left.getIndexes().getColumnIndex() : left.getIndexes().getRowIndex()); + int ridx = (int) (rightTranspose ? right.getIndexes().getRowIndex() : right.getIndexes().getColumnIndex()); + return new IndexedMatrixValue(new MatrixIndexes(lidx, ridx), partialResult); + }, tmp -> leftTranspose ? tmp.getIndexes().getRowIndex() : tmp.getIndexes().getColumnIndex(), + tmp -> rightTranspose ? tmp.getIndexes().getColumnIndex() : tmp.getIndexes().getRowIndex(), + emitLeftThreshold, emitRightThreshold); + + BinaryOperator plus = InstructionUtils.parseBinaryOperator(Opcodes.PLUS.toString()); + int emitAggThreshold = leftTranspose ? (int) dc1.getNumRowBlocks() : (int) dc1.getNumColBlocks(); + + groupedReduceOOC(intermediateStream, out, (left, right) -> { + MatrixBlock mb = ((MatrixBlock) left.getValue()).binaryOperationsInPlace(plus, right.getValue()); + left.setValue(mb); + return left; + }, emitAggThreshold); + + return out; + } + + private OOCStream elemOOC(OOCStream m1, OOCStream m2, BinaryOperator bop) { + SubscribableTaskQueue out = new SubscribableTaskQueue<>(); + Function key = imv -> + new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex()); + + joinOOC(m1, m2, out, (left, right) -> { + MatrixBlock lb = (MatrixBlock) left.getValue(); + MatrixBlock rb = (MatrixBlock) right.getValue(); + MatrixBlock combined = lb.binaryOperations(bop, rb); + return new IndexedMatrixValue( + new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined); + }, key); + + return out; + } + + private OOCStream elemDivOOC(OOCStream m1, OOCStream m2) { + BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.DIV.toString()); + return elemOOC(m1, m2, div); + } + + private OOCStream elemMultOOC(OOCStream m1, OOCStream m2) { + BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MULT.toString()); + return elemOOC(m1, m2, div); + } + + private OOCStream elemMinusOOC(OOCStream m1, OOCStream m2) { + BinaryOperator div = InstructionUtils.parseBinaryOperator(Opcodes.MINUS.toString()); + return elemOOC(m1, m2, div); + } + + private OOCStream elemPlusOOC(OOCStream m1, double eps) { + SubscribableTaskQueue out = new SubscribableTaskQueue<>(); + mapOOC(m1, out, blk -> { + MatrixBlock res = ((MatrixBlock) blk.getValue()) + .scalarOperations(new RightScalarOperator(Plus.getPlusFnObject(), eps), null); + return new IndexedMatrixValue( + new MatrixIndexes(blk.getIndexes().getRowIndex(), blk.getIndexes().getColumnIndex()), res); + }); + return out; + } + + private OOCStream maskOOC(OOCStream mask, OOCStream m1) { + SubscribableTaskQueue out = new SubscribableTaskQueue<>(); + Function key = imv -> + new MatrixIndexes(imv.getIndexes().getRowIndex(), imv.getIndexes().getColumnIndex()); + + joinOOC(mask, m1, out, (left, right) -> { + MatrixBlock lb = (MatrixBlock) left.getValue(); + MatrixBlock rb = (MatrixBlock) right.getValue(); + MatrixBlock combined = mask(lb, rb); + return new IndexedMatrixValue( + new MatrixIndexes(left.getIndexes().getRowIndex(), left.getIndexes().getColumnIndex()), combined); + }, key); + + return out; + } + + private MatrixBlock mask(MatrixBlock mask, MatrixBlock blk) { + for(int i = 0; i < blk.getNumRows(); i++) { + for(int j = 0; j < blk.getNumColumns(); j++) { + if(mask.get(i,j) ==0) blk.set(i, j, 0); + } + } + return blk; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java index a25249985d6..d7186f2bbe2 100644 --- a/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java +++ b/src/test/java/org/apache/sysds/test/functions/ooc/PNMFTest.java @@ -21,12 +21,16 @@ import java.io.IOException; +import org.apache.sysds.common.Opcodes; import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DataConverter; import org.apache.sysds.test.AutomatedTestBase; import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Test; public class PNMFTest extends AutomatedTestBase { private static final String TEST_NAME = "PNMF"; @@ -44,6 +48,7 @@ public class PNMFTest extends AutomatedTestBase { private static final int RANK = 20; private static final int MAX_ITER = 10; private static final int BLOCK_SIZE = 1000; + private static final int SEED = 7; private static final double SPARSITY = 0.7; private static final double EPS = 1e-6; @@ -54,7 +59,7 @@ public void setUp() { addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); } - //@Test + @Test public void testPNMFOOCVsCP() { runPNMFTest(); } @@ -71,13 +76,16 @@ private void runPNMFTest() { double[][] xData = getRandomMatrix(ROWS, COLS, 1, 10, SPARSITY, 7); writeBinaryWithMTD(INPUT_X, DataConverter.convertToMatrixBlock(xData)); - programArgs = new String[] {"-explain", "-stats", "-seed", "7", "-ooc", "-args", - input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), + programArgs = new String[] {"-explain", "-stats", "-ooc", "-args", + input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), String.valueOf(SEED), output(OUTPUT_W_OOC), output(OUTPUT_H_OOC)}; runTest(true, false, null, -1); - programArgs = new String[] {"-explain", "-stats", "-seed", "7", "-args", - input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), + Assert.assertTrue("OOC wasn't used for pnmf", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.WEIGHTEDDIVMM)); + + programArgs = new String[] {"-explain", "-stats", "-args", + input(INPUT_X), String.valueOf(RANK), String.valueOf(MAX_ITER), String.valueOf(SEED), output(OUTPUT_W_CP), output(OUTPUT_H_CP)}; runTest(true, false, null, -1); diff --git a/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java b/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java new file mode 100644 index 00000000000..549fdc764d7 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/ooc/WDivMMTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.ooc; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collection; + +import org.apache.sysds.common.Opcodes; +import org.apache.sysds.common.Types; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.io.MatrixWriter; +import org.apache.sysds.runtime.io.MatrixWriterFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.util.DataConverter; +import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +@net.jcip.annotations.NotThreadSafe +public class WDivMMTest extends AutomatedTestBase { + private final static String INPUT_NAME_1 = "W"; + private final static String INPUT_NAME_2 = "U"; + private final static String INPUT_NAME_3 = "V"; + private final static String OUTPUT_NAME = "R"; + private static final String TEST_DIR = "functions/ooc/"; + private static final String TEST_CLASS_DIR = TEST_DIR + WDivMMTest.class.getSimpleName() + "/"; + + private static final int rows = 2201; + private static final int cols = 1103; + private static final int rank = 20; + private static final int blen = 1000; + private static final double eps = 1e-6; + private static final double div_eps = 0.1; + + private final static String TEST_NAME_1 = "WeightedDivMMLeft"; + private final static String TEST_NAME_2 = "WeightedDivMMRight"; + private final static String TEST_NAME_3 = "WeightedDivMMMultBasic"; + private final static String TEST_NAME_4 = "WeightedDivMMMultLeft"; + private final static String TEST_NAME_5 = "WeightedDivMMMultRight"; + private final static String TEST_NAME_6 = "WeightedDivMMMultMinusLeft"; + private final static String TEST_NAME_7 = "WeightedDivMMMultMinusRight"; + private final static String TEST_NAME_8 = "WeightedDivMM4MultMinusLeft"; + private final static String TEST_NAME_9 = "WeightedDivMM4MultMinusRight"; + private final static String TEST_NAME_10 = "WeightedDivMMLeftEps"; + private final static String TEST_NAME_11 = "WeightedDivMMRightEps"; + private String TEST_NAME; + + public WDivMMTest(String testName) { + this.TEST_NAME = testName; + } + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new Object[][] {{TEST_NAME_1}, {TEST_NAME_2}, {TEST_NAME_3}, {TEST_NAME_4}, {TEST_NAME_5}, + {TEST_NAME_6}, {TEST_NAME_7}, {TEST_NAME_8}, {TEST_NAME_9}, {TEST_NAME_10}, {TEST_NAME_11}}); + } + + @Before + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {OUTPUT_NAME})); + } + + @Test + public void testWeightedDivMM() { + runWeightedDivMMTest(TEST_NAME); + } + + private void runWeightedDivMMTest(String TEST_NAME) { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + + try { + boolean basic = TEST_NAME.equals(TEST_NAME_3); + boolean left = TEST_NAME.equals(TEST_NAME_1) || TEST_NAME.equals(TEST_NAME_4) || + TEST_NAME.equals(TEST_NAME_6) || TEST_NAME.equals(TEST_NAME_8) || TEST_NAME.equals(TEST_NAME_10); + + getAndLoadTestConfiguration(TEST_NAME); + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + + double[][] W = getRandomMatrix(rows, cols, 0, 1, 0.7, 7); + double[][] U = getRandomMatrix(rows, rank, 0, 1, 1.0, 713); + double[][] V = getRandomMatrix(cols, rank, 0, 1, 1.0, 812); + + MatrixWriter writer = MatrixWriterFactory.createMatrixWriter(Types.FileFormat.BINARY); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(W), input(INPUT_NAME_1), rows, + cols, blen, rows * cols); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(U), input(INPUT_NAME_2), rows, + rank, blen, rows * rank); + writer.writeMatrixToHDFS(DataConverter.convertToMatrixBlock(V), input(INPUT_NAME_3), cols, + rank, blen, cols * rank); + + HDFSTool.writeMetaDataFile(input(INPUT_NAME_1 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, cols, blen, rows * cols), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_2 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(rows, rank, blen, rows * rank), Types.FileFormat.BINARY); + HDFSTool.writeMetaDataFile(input(INPUT_NAME_3 + ".mtd"), Types.ValueType.FP64, + new MatrixCharacteristics(cols, rank, blen, cols * rank), Types.FileFormat.BINARY); + + programArgs = new String[] {"-ooc", "-stats", "-explain", "runtime", "-args", input(INPUT_NAME_1), + input(INPUT_NAME_2), input(INPUT_NAME_3), output(OUTPUT_NAME), Double.toString(div_eps)}; + + runTest(true, false, null, -1); + + Assert.assertTrue("OOC wasn't used for wdivmm", + heavyHittersContainsString(Instruction.OOC_INST_PREFIX + Opcodes.WEIGHTEDDIVMM)); + + programArgs = new String[] {"-stats", "-explain", "runtime", "-args", input(INPUT_NAME_1), + input(INPUT_NAME_2), input(INPUT_NAME_3), output(OUTPUT_NAME + "_target"), Double.toString(div_eps)}; + + runTest(true, false, null, -1); + + int rows2 = left ? cols : rows; + int cols2 = basic ? cols : rank; + checkDMLMetaDataFile("R", new MatrixCharacteristics(rows2, cols2)); + + MatrixBlock actual = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME), + Types.FileFormat.BINARY, rows2, cols2, blen); + MatrixBlock expected = DataConverter.readMatrixFromHDFS(output(OUTPUT_NAME + "_target"), + Types.FileFormat.BINARY, rows2, cols2, blen); + TestUtils.compareMatrices(expected, actual, eps); + } + catch(IOException e) { + throw new RuntimeException(e); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/ooc/PNMF.dml b/src/test/scripts/functions/ooc/PNMF.dml index 60aecb8963f..bc0fd5b100e 100644 --- a/src/test/scripts/functions/ooc/PNMF.dml +++ b/src/test/scripts/functions/ooc/PNMF.dml @@ -20,7 +20,7 @@ #------------------------------------------------------------- X = read($1); -[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE); +[W, H] = pnmf(X=X, rnk=$2, maxi=$3, verbose=FALSE, seed=$4); -write(W, $4, format="binary"); -write(H, $5, format="binary"); +write(W, $5, format="binary"); +write(H, $6, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml new file mode 100644 index 00000000000..42bd4c96a0c --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusLeft.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +X = W/0.7; +while(FALSE){} +R = t(t(U) %*% (W*(U%*%t(V)-X))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml new file mode 100644 index 00000000000..7b393f12310 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMM4MultMinusRight.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +X = W/0.3 +while(FALSE){} +R = (W*(U%*%t(V)-X)) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml new file mode 100644 index 00000000000..48639a176a7 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMLeft.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = t(t(U) %*% (W/(U%*%t(V)))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml new file mode 100644 index 00000000000..dc07670feab --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMLeftEps.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +x = $5; + +R = t(t(U) %*% (W/(U%*%t(V) + x))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml new file mode 100644 index 00000000000..144e59a773a --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultBasic.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = W*(U%*%t(V)); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml new file mode 100644 index 00000000000..93bc765617f --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultLeft.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = t(t(U) %*% (W*(U%*%t(V)))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml new file mode 100644 index 00000000000..84ac35ad899 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusLeft.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = t(t(U) %*% ((W != 0)*(U%*%t(V)-W))); + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml new file mode 100644 index 00000000000..59caa4d17b4 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultMinusRight.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = ((W != 0)*(U%*%t(V)-W)) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml new file mode 100644 index 00000000000..fbb1224d173 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMMultRight.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = (W*(U%*%t(V))) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMRight.dml b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml new file mode 100644 index 00000000000..e878a81d14d --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMRight.dml @@ -0,0 +1,30 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +R = (W/(U%*%t(V))) %*% V; + +write(R, $4, format="binary"); diff --git a/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml new file mode 100644 index 00000000000..9ecbaf56630 --- /dev/null +++ b/src/test/scripts/functions/ooc/WeightedDivMMRightEps.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + + + +W = read($1); +U = read($2); +V = read($3); + +x = $5; + +R = (W/(U%*%t(V) + x)) %*% V; + +write(R, $4, format="binary");