diff --git a/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/ComplexityUtils.java b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/ComplexityUtils.java new file mode 100644 index 000000000..42d2eda68 --- /dev/null +++ b/wayang-commons/wayang-basic/src/main/java/org/apache/wayang/basic/util/ComplexityUtils.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.basic.util; + +import org.apache.wayang.basic.operators.CoGroupOperator; +import org.apache.wayang.basic.operators.FilterOperator; +import org.apache.wayang.basic.operators.FlatMapOperator; +import org.apache.wayang.basic.operators.GlobalReduceOperator; +import org.apache.wayang.basic.operators.GroupByOperator; +import org.apache.wayang.basic.operators.JoinOperator; +import org.apache.wayang.basic.operators.LoopOperator; +import org.apache.wayang.basic.operators.MapOperator; +import org.apache.wayang.basic.operators.MapPartitionsOperator; +import org.apache.wayang.basic.operators.MaterializedGroupByOperator; +import org.apache.wayang.basic.operators.ReduceByOperator; +import org.apache.wayang.basic.operators.ReduceOperator; +import org.apache.wayang.basic.operators.SortOperator; +import org.apache.wayang.core.optimizer.ComplexityClass; +import org.apache.wayang.core.plan.wayangplan.Operator; + +public class ComplexityUtils { + /** + * Infer complexity class from a given operator's descriptors. + * @param operator + * @return {@link ComplexityClass#LOGARITHMIC}, {@link ComplexityClass#LINEAR}, {@link ComplexityClass#QUADRATIC} or {@link ComplexityClass#SUPERQUADRATIC}. {@link ComplexityClass#LINEAR} on default + */ + public static ComplexityClass inferFromOperator(final Operator operator) { + if (operator instanceof final ReduceByOperator reduceBy) { + return reduceBy.getReduceDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final ReduceOperator reduce) { + return reduce.getReduceDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final GlobalReduceOperator globalReduce) { + return globalReduce.getReduceDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final CoGroupOperator coGroup) { + return coGroup.getKeyDescriptor0().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final GroupByOperator groupBy) { + return groupBy.getKeyDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final MaterializedGroupByOperator matGroupBy) { + return matGroupBy.getKeyDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final SortOperator sort) { + return sort.getKeyDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final JoinOperator join) { + return join.getKeyDescriptor0().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final MapOperator map) { + return map.getFunctionDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final FlatMapOperator flatMap) { + return flatMap.getFunctionDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final MapPartitionsOperator mapPartitions) { + return mapPartitions.getFunctionDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final FilterOperator filter) { + return filter.getPredicateDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } else if (operator instanceof final LoopOperator loop) { + return loop.getCriterionDescriptor().getComplexityClass().orElse(ComplexityClass.LINEAR); + } + + return ComplexityClass.LINEAR; + } +} diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/function/FunctionDescriptor.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/function/FunctionDescriptor.java index 3c4c36408..ee7e09a12 100644 --- a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/function/FunctionDescriptor.java +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/function/FunctionDescriptor.java @@ -18,6 +18,7 @@ package org.apache.wayang.core.function; +import org.apache.wayang.core.optimizer.ComplexityClass; import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval; import org.apache.wayang.core.optimizer.costs.LoadEstimator; import org.apache.wayang.core.optimizer.costs.LoadProfileEstimator; @@ -32,10 +33,13 @@ */ public abstract class FunctionDescriptor implements Serializable { - public FunctionDescriptor() {} + public FunctionDescriptor() { + } private LoadProfileEstimator loadProfileEstimator; + private ComplexityClass complexityClass = null; + public FunctionDescriptor(LoadProfileEstimator loadProfileEstimator) { this.setLoadProfileEstimator(loadProfileEstimator); } @@ -48,6 +52,14 @@ public Optional getLoadProfileEstimator() { return Optional.ofNullable(this.loadProfileEstimator); } + public Optional getComplexityClass(){ + return Optional.ofNullable(complexityClass); + } + + public void setComplexityClass(final ComplexityClass complexityClass){ + this.complexityClass = complexityClass; + } + /** * Utility method to retrieve the selectivity of a {@link FunctionDescriptor} * diff --git a/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/ComplexityClass.java b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/ComplexityClass.java new file mode 100644 index 000000000..d7f1d7702 --- /dev/null +++ b/wayang-commons/wayang-core/src/main/java/org/apache/wayang/core/optimizer/ComplexityClass.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.core.optimizer; + +public enum ComplexityClass { + LINEAR, + LOGARITHMIC, + QUADRATIC, + SUPERQUADRATIC +} \ No newline at end of file diff --git a/wayang-plugins/pom.xml b/wayang-plugins/pom.xml index 1d11b2da0..5dce55053 100644 --- a/wayang-plugins/pom.xml +++ b/wayang-plugins/pom.xml @@ -39,6 +39,6 @@ wayang-iejoin wayang-spatial + wayang-ml - diff --git a/wayang-plugins/wayang-ml/pom.xml b/wayang-plugins/wayang-ml/pom.xml new file mode 100644 index 000000000..205c8d29a --- /dev/null +++ b/wayang-plugins/wayang-ml/pom.xml @@ -0,0 +1,157 @@ + + + + + 4.0.0 + + + org.apache.wayang + wayang-plugins + 1.1.2-SNAPSHOT + + + wayang-ml + 1.1.2-SNAPSHOT + + + org.apache.wayang.extensions.ml + + + + + org.apache.wayang + wayang-api-sql + 1.1.2-SNAPSHOT + + + com.microsoft.onnxruntime + onnxruntime + 1.21.1 + + + + org.apache.wayang + wayang-core + 1.1.2-SNAPSHOT + + + org.apache.wayang + wayang-basic + 1.1.2-SNAPSHOT + + + org.apache.wayang + wayang-java + 1.1.2-SNAPSHOT + + + org.apache.wayang + wayang-spark + 1.1.2-SNAPSHOT + + + org.apache.wayang + wayang-flink + 1.1.2-SNAPSHOT + + + org.apache.flink + flink-java + ${flink.version} + + + org.apache.wayang + wayang-giraph + 1.1.2-SNAPSHOT + + + org.apache.wayang + wayang-generic-jdbc + 1.1.2-SNAPSHOT + + + org.reflections + reflections + 0.10.2 + + + org.apache.wayang + wayang-benchmark + 1.1.2-SNAPSHOT + + + org.apache.wayang + wayang-api-python + 1.1.2-SNAPSHOT + + + org.apache.commons + commons-dbcp2 + 2.7.0 + + + org.apache.spark + spark-core_2.12 + ${spark.version} + + + org.apache.spark + spark-graphx_2.12 + ${spark.version} + + + org.apache.spark + spark-mllib_2.12 + ${spark.version} + + + com.google.protobuf + protobuf-java + 3.16.3 + + + org.apache.calcite + calcite-core + ${calcite.version} + + + org.apache.calcite + calcite-linq4j + ${calcite.version} + + + org.apache.calcite + calcite-file + ${calcite.version} + + + + + + src/main/resources + + + + diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MLContext.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MLContext.java new file mode 100644 index 000000000..534f873e5 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MLContext.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml; + +import java.util.Optional; + +import org.apache.logging.log4j.Level; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.core.util.ReflectionUtils; +import org.apache.wayang.ml.encoding.OneHotMappings; +import org.apache.wayang.ml.encoding.TreeEncoder; +import org.apache.wayang.ml.util.Logging; + +/** + * This is the entry point for users to work with Wayang ML. + */ +public class MLContext extends WayangContext { + public MLContext() { + super(); + } + + public MLContext(final Configuration configuration) { + super(configuration); + } + + /** + * Execute a plan. + * + * @param wayangPlan the plan to execute + * @param udfJars JARs that declare the code for the UDFs + * @see ReflectionUtils#getDeclaringJar(Class) + */ + @Override + public void execute(final WayangPlan wayangPlan, final String... udfJars) { + this.setLogLevel(Level.ERROR); + final Job wayangJob = this.createJob("", wayangPlan, udfJars); + + final Configuration config = this.getConfiguration(); + final Configuration jobConfig = wayangJob.getConfiguration(); + + wayangJob.execute(); + + if (config.getBooleanProperty("wayang.ml.experience.enabled")) { + final Optional originalOption = config.getOptionalStringProperty("wayang.ml.experience.original"); + + final OneHotMappings mappings = new OneHotMappings(); + final TreeEncoder encoder = new TreeEncoder(mappings); + final String original = originalOption.orElse(encoder.encode(wayangPlan, wayangJob.getOptimizationContext(), false).toString()); + + final Optional choicesOption = config + .getOptionalStringProperty("wayang.ml.experience.with-platforms"); + final String withChoices = choicesOption + .orElse(jobConfig.getStringProperty("wayang.ml.experience.with-platforms")); + + final long execTime = jobConfig.getLongProperty("wayang.ml.experience.exec-time"); + + this.logExperience(original, withChoices, execTime); + } + } + + private void logExperience(final String original, final String withChoices, final long execTime) { + if (!this.getConfiguration().getBooleanProperty("wayang.ml.experience.enabled")) { + return; + } + + final String content = String.format("%s:%s:%d", original, withChoices, execTime); + Logging.writeToFile(content, this.getConfiguration().getStringProperty("wayang.ml.experience.file")); + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MachineLearning.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MachineLearning.java new file mode 100644 index 000000000..80ebc1f67 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/MachineLearning.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml; + +import java.util.Collection; +import java.util.Collections; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.mapping.Mapping; +import org.apache.wayang.core.optimizer.channels.ChannelConversion; +import org.apache.wayang.core.platform.Platform; +import org.apache.wayang.core.plugin.Plugin; +import org.apache.wayang.java.platform.JavaPlatform; +import org.apache.wayang.spark.platform.SparkPlatform; + +/** + * Provides {@link Plugin}s that enable usage of the xxxx. + */ +public class MachineLearning { + + /** + * Enables use with the {@link JavaPlatform} and {@link SparkPlatform}. + */ + private static final Plugin PLUGIN = new Plugin() { + + @Override + public Collection getRequiredPlatforms() { + return Collections.emptyList(); + } + + @Override + public Collection getMappings() { + return Collections.emptyList(); + } + + @Override + public Collection getChannelConversions() { + return Collections.emptyList(); + } + + @Override + public void setProperties(final Configuration configuration) { + } + }; + + /** + * Retrieve a {@link Plugin} to use xxx on the + * + * @return the {@link Plugin} + */ + public static Plugin plugin() { + return PLUGIN; + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/WordCount.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/WordCount.java new file mode 100644 index 000000000..57a83ea54 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/benchmarks/WordCount.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.benchmarks; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +import org.apache.wayang.api.utils.Parameters; +import org.apache.wayang.basic.data.Tuple2; +import org.apache.wayang.basic.operators.FilterOperator; +import org.apache.wayang.basic.operators.FlatMapOperator; +import org.apache.wayang.basic.operators.LocalCallbackSink; +import org.apache.wayang.basic.operators.MapOperator; +import org.apache.wayang.basic.operators.ReduceByOperator; +import org.apache.wayang.basic.operators.TextFileSource; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.function.FlatMapDescriptor; +import org.apache.wayang.core.function.ReduceDescriptor; +import org.apache.wayang.core.function.TransformationDescriptor; +import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.core.plugin.Plugin; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.types.DataUnitType; +import org.apache.wayang.core.util.ReflectionUtils; +import org.apache.wayang.java.platform.JavaPlatform; +import org.apache.wayang.ml.MLContext; +import org.apache.wayang.ml.costs.DefaultPointwiseCost; + +import scala.collection.JavaConversions; + +/** + * Example Apache Wayang (incubating) App that does a word count -- the Hello + * World of Map/Reduce-like systems. + */ +public class WordCount { + + /** + * Creates the {@link WayangPlan} for the word count app. + * + * @param inputFileUrl the file whose words should be counted + */ + public static WayangPlan createWayangPlan(final String inputFileUrl, + final Collection> collector) throws URISyntaxException, IOException { + // Assignment mode: none. + + final TextFileSource textFileSource = new TextFileSource(inputFileUrl); + textFileSource.setName("Load file"); + + // for each line (input) output an iterator of the words + final FlatMapOperator flatMapOperator = new FlatMapOperator<>( + new FlatMapDescriptor<>(line -> Arrays.asList(line.split("\\W+")), String.class, String.class, + new ProbabilisticDoubleInterval(100, 10000, 0.8))); + flatMapOperator.setName("Split words"); + + final FilterOperator filterOperator = new FilterOperator<>(str -> !str.isEmpty(), String.class); + filterOperator.setName("Filter empty words"); + + // for each word transform it to lowercase and output a key-value pair (word, 1) + final MapOperator> mapOperator = new MapOperator<>( + new TransformationDescriptor<>(word -> new Tuple2<>(word.toLowerCase(), 1), + DataUnitType.createBasic(String.class), DataUnitType.createBasicUnchecked(Tuple2.class)), + DataSetType.createDefault(String.class), DataSetType.createDefaultUnchecked(Tuple2.class)); + mapOperator.setName("To lower case, add counter"); + + // groupby the key (word) and add up the values (frequency) + final ReduceByOperator, String> reduceByOperator = new ReduceByOperator<>( + new TransformationDescriptor<>(pair -> pair.field0, DataUnitType.createBasicUnchecked(Tuple2.class), + DataUnitType.createBasic(String.class)), + new ReduceDescriptor<>(((a, b) -> { + a.field1 += b.field1; + return a; + }), DataUnitType.createGroupedUnchecked(Tuple2.class), DataUnitType.createBasicUnchecked(Tuple2.class)), + DataSetType.createDefaultUnchecked(Tuple2.class)); + reduceByOperator.setName("Add counters"); + + // write results to a sink + final LocalCallbackSink> sink = LocalCallbackSink.createCollectingSink(collector, + DataSetType.createDefaultUnchecked(Tuple2.class)); + sink.setName("Collect result"); + + // Build Rheem plan by connecting operators + textFileSource.connectTo(0, flatMapOperator, 0); + flatMapOperator.connectTo(0, filterOperator, 0); + filterOperator.connectTo(0, mapOperator, 0); + mapOperator.connectTo(0, reduceByOperator, 0); + reduceByOperator.connectTo(0, sink, 0); + + return new WayangPlan(sink); + } + + public static void main(final String[] args) throws IOException, URISyntaxException { + try { + if (args.length == 0) { + System.err.print("Usage: [,]* "); + System.exit(1); + } + + final List> collector = new LinkedList<>(); + final WayangPlan wayangPlan = createWayangPlan(args[1], collector); + + final Configuration config = new Configuration(); + + config.setProperty("wayang.ml.model.file", + "/var/www/html/wayang-plugins/wayang-ml/src/main/resources/cost.onnx"); + + config.setProperty("wayang.core.log.enabled", "false"); + + config.setCostModel(DefaultPointwiseCost.FACTORY.makeCost()); + final MLContext wayangContext = new MLContext(config); + + final List plugins = JavaConversions.seqAsJavaList(Parameters.loadPlugins(args[0])); + plugins.stream().forEach(plug -> wayangContext.register(plug)); + + wayangContext.execute(wayangPlan, ReflectionUtils.getDeclaringJar(WordCount.class), + ReflectionUtils.getDeclaringJar(JavaPlatform.class)); + + collector.sort((t1, t2) -> Integer.compare(t2.field1, t1.field1)); + System.out.printf("Found %d words:\n", collector.size()); + } catch (final Exception e) { + System.err.println("App failed."); + e.printStackTrace(); + System.exit(4); + } + } + +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/costs/BenchmarkCost.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/costs/BenchmarkCost.java new file mode 100644 index 000000000..cfdd12505 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/costs/BenchmarkCost.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.costs; + +import java.io.BufferedWriter; +import java.io.FileWriter; +import java.util.Collection; +import java.util.Set; + +import org.apache.wayang.commons.util.profiledb.instrumentation.StopWatch; +import org.apache.wayang.commons.util.profiledb.model.Experiment; +import org.apache.wayang.commons.util.profiledb.model.Subject; +import org.apache.wayang.commons.util.profiledb.model.measurement.TimeMeasurement; + +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.optimizer.costs.DefaultEstimatableCost; +import org.apache.wayang.core.optimizer.costs.EstimatableCost; +import org.apache.wayang.core.optimizer.costs.EstimatableCostFactory; +import org.apache.wayang.core.optimizer.enumeration.PlanImplementation; +import org.apache.wayang.core.plan.executionplan.Channel; +import org.apache.wayang.core.plan.executionplan.ExecutionPlan; +import org.apache.wayang.core.plan.executionplan.ExecutionStage; + +public class BenchmarkCost extends DefaultEstimatableCost { + public static class Factory implements EstimatableCostFactory { + public EstimatableCost makeCost() { + return new BenchmarkCost(); + } + } + + public EstimatableCostFactory getFactory() { + return new Factory(); + } + + public PlanImplementation pickBestExecutionPlan( + final Collection executionPlans, + final ExecutionPlan existingPlan, + final Set openChannels, + final Set executedStages) { + // Measure time needed for a decision: picking the better plan. + try { + final BufferedWriter writer = new BufferedWriter(new FileWriter("/var/www/html/data/decisions.txt", true)); + final Experiment experiment = new Experiment("wayang-ml", new Subject("Wayang", "0.1")); + final StopWatch stopWatch = new StopWatch(experiment); + final TimeMeasurement decisionRound = stopWatch.getOrCreateRound("Decision"); + final TimeMeasurement currentDecisionRound = decisionRound.start(String.format("Decision %d", 1)); + final PlanImplementation bestPlanImplementation = executionPlans.stream() + .reduce((p1, p2) -> { + final double t1 = p1.getSquashedCostEstimate(); + final double t2 = p2.getSquashedCostEstimate(); + return t1 < t2 ? p1 : p2; + }) + .orElseThrow(() -> new WayangException("Could not find an execution plan.")); + decisionRound.stop(); + writer.write(String.format("Decision: %s",currentDecisionRound.getMillis())); + writer.newLine(); + writer.close(); + + return bestPlanImplementation; + } catch (final Exception e) { + System.out.println("Couldnt write to File error: " + e); + } + + return null; + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/costs/DefaultPointwiseCost.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/costs/DefaultPointwiseCost.java new file mode 100644 index 000000000..0693b0eca --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/costs/DefaultPointwiseCost.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.costs; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.optimizer.costs.DefaultEstimatableCost; +import org.apache.wayang.core.optimizer.costs.EstimatableCost; +import org.apache.wayang.core.optimizer.costs.EstimatableCostFactory; +import org.apache.wayang.core.optimizer.enumeration.PlanImplementation; +import org.apache.wayang.core.plan.executionplan.Channel; +import org.apache.wayang.core.plan.executionplan.ExecutionPlan; +import org.apache.wayang.core.plan.executionplan.ExecutionStage; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.ml.encoding.OneHotMappings; +import org.apache.wayang.ml.encoding.OrtMLModel; +import org.apache.wayang.ml.encoding.OrtTensorEncoder; +import org.apache.wayang.ml.encoding.TreeEncoder; +import org.apache.wayang.ml.encoding.TreeNode; + +import ai.onnxruntime.OrtException; + +/** + * Default {@link EstimatableCost} for pointwise/cost-based ML models.
+ * Takes config {@code wayang.ml.experience.enabled} + */ +public class DefaultPointwiseCost extends DefaultEstimatableCost { + private static final Logger logger = LogManager.getLogger(DefaultPointwiseCost.class); + + public static class Factory implements EstimatableCostFactory { + private final TreeEncoder encoder = new TreeEncoder(new OneHotMappings()); + + @Override + public EstimatableCost makeCost() { + return new DefaultPointwiseCost(encoder); + } + } + + private final TreeEncoder encoder; + + public DefaultPointwiseCost(TreeEncoder encoder) { + this.encoder = encoder; + } + + @Override + public PlanImplementation pickBestExecutionPlan(final Collection executionPlans, + final ExecutionPlan existingPlan, final Set openChannels, + final Set executedStages) { + + final Map planCostMapping = executionPlans.stream() + .collect(Collectors.toMap(Function.identity(), this::getCost)); + + final PlanImplementation bestPlanImplementation = executionPlans.stream() + .min(Comparator.comparingDouble(planCostMapping::get)) + .orElseThrow(() -> new WayangException("Could not find an execution plan.")); + + final Configuration config = bestPlanImplementation.getOptimizationContext().getConfiguration(); + + if (config.getOptionalBooleanProperty("wayang.ml.experience.enabled").orElse(false)) { + final TreeNode encodedPlan = encoder.encode(bestPlanImplementation); + config.setProperty("wayang.ml.experience.with-platforms", encodedPlan.toString()); + } + + return bestPlanImplementation; + } + + /** + * Estimates the runtime cost for a given plan. + * + * @param plan + * @return + */ + public Double getCost(final PlanImplementation plan) { + try { + final Configuration config = plan.getOptimizationContext().getConfiguration(); + final OrtMLModel model = OrtMLModel.getInstance(config); + final TreeNode encodedOne = encoder.encode(plan); + final Tuple, ArrayList> tuple1 = OrtTensorEncoder.encode(encodedOne); + final double cost = Math.exp(model.runModel(tuple1)) - 1; + + return cost; + } catch (final OrtException e) { + logger.warn("Failed to estimate ML cost for plan {" + plan + "}. Falling back to MAX_VALUE.", e); + return Double.MAX_VALUE; + } + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotEncoder.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotEncoder.java new file mode 100644 index 000000000..95db04021 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotEncoder.java @@ -0,0 +1,446 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.builder.HashCodeBuilder; +import org.apache.wayang.basic.util.ComplexityUtils; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.cardinality.CardinalityEstimate; +import org.apache.wayang.core.optimizer.enumeration.PlanImplementation; +import org.apache.wayang.core.plan.executionplan.ExecutionTask; +import org.apache.wayang.core.plan.wayangplan.BinaryToUnaryOperator; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.InputSlot; +import org.apache.wayang.core.plan.wayangplan.Operator; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.plan.wayangplan.UnaryToUnaryOperator; +import org.apache.wayang.core.platform.Junction; +import org.apache.wayang.core.util.Canonicalizer; +import org.apache.wayang.core.util.json.WayangJsonObj; +import org.apache.wayang.ml.util.CardinalitySampler; +import org.apache.wayang.ml.util.SampledCardinality; + +public class OneHotEncoder { + + protected OneHotEncoder() { + } + + public static final int PADDING_SIZE = 1; + + public static long[] encode(final PlanImplementation plan) { + final OneHotVector result = new OneHotVector(); + + if (plan.getOperators() == null) { + return result.getEntries(); + } + + encodeTopologies(plan, result); + encodeOperators(plan, result); + encodeDataMovement(plan, result); + encodeDataset(plan, result); + + return result.getEntries(); + } + + public static void encodeOperators(final PlanImplementation plan, final OneHotVector vector) { + /* + * Format: ---- BEGIN OPERATOR ITERATION ---- 0 - total # instances 1 - # + * instances in Java 2 - # instances in Spark 3 - # instances in Pipeline 4 - # + * instances in Junction 5 - # instances in Replicator 6 - # instances in Loop 7 + * - sum of UDF complexities 8 - sum of input cardinalities 9 - sum of output + * cardinalities + */ + final Canonicalizer operators = plan.getOperators(); + final HashMap platformMappings = OneHotMappings.getPlatformsMapping(); + final int platformsCount = platformMappings.size(); + + final List> distinctOperators = operators.stream().map(operator -> operator.getClass().getSuperclass()) + .distinct().collect(Collectors.toList()); + + for (final Class operator : distinctOperators) { + // build the features + final long encodedOperator[] = new long[OneHotVector.OPERATOR_SIZE]; + final List executionOperators = operators.stream() + .filter(op -> operator == op.getClass().getSuperclass()).toList(); + + encodedOperator[0] = (long) executionOperators.size(); + + final List operatorSamples = CardinalitySampler.samples.stream() + .filter(sample -> sample.getOperator().get("class").equals(operator.getName())).toList(); + + final long inputCardinality = operatorSamples.stream().mapToLong(sample -> { + long card = 0; + for (final Object input : sample.getInputs()) { + card += ((WayangJsonObj) input).getLong("upperBound"); + } + + return card; + }).sum(); + final long outputCardinality = operatorSamples.stream() + .mapToLong(sample -> sample.getOutput().getLong("cardinality")).sum(); + + for (final ExecutionOperator executionOperator : executionOperators) { + final Integer platformPosition = platformMappings + .get(executionOperator.getPlatform().getClass().getName()); + + if (platformPosition == null) { + continue; + } + + encodedOperator[platformPosition] += 1; + + if (executionOperator instanceof UnaryToUnaryOperator) { + encodedOperator[platformsCount + 1] += 1; + } + + if (executionOperator instanceof BinaryToUnaryOperator) { + encodedOperator[platformsCount + 2] += 1; + } + + if (executionOperator.isLoopSubplan() || executionOperator.isLoopHead()) { + encodedOperator[platformsCount + 3] += 1; + } + + encodedOperator[platformPosition + 4] += ComplexityUtils.inferFromOperator(executionOperator).ordinal(); + } + + encodedOperator[platformsCount + 5] += inputCardinality; + encodedOperator[platformsCount + 6] += outputCardinality; + + vector.addOperator(encodedOperator, operator.getName()); + } + } + + /* + * Format: ---- BEGIN OPERATOR ITERATION ---- 0 - # instances in Java 1 - # + * instances in Spark 2 - sum of input cardinalities 3 - sum of output + * cardinalities + */ + public static void encodeDataMovement(final PlanImplementation plan, final OneHotVector vector) { + final OptimizationContext optimizationContext = plan.getOptimizationContext(); + final HashMap platformMappings = OneHotMappings.getPlatformsMapping(); + final int platformsCount = platformMappings.size(); + + final List conversionTasks = plan.getJunctions().values().stream() + .map(Junction::getConversionTasks).flatMap(Collection::stream).toList(); + + final List> distinctOperators = conversionTasks.stream().map(task -> task.getOperator().getClass()) + .distinct().collect(Collectors.toList()); + + for (final Class operator : distinctOperators) { + final long encodedOperator[] = new long[OneHotVector.CONVERSION_SIZE]; + final List executionOperators = conversionTasks.stream().map(ExecutionTask::getOperator) + .filter(op -> operator == op.getClass()).toList(); + + for (final ExecutionOperator executionOperator : executionOperators) { + final Integer platformPosition = platformMappings + .get(executionOperator.getPlatform().getClass().getName()); + + if (platformPosition == null) { + continue; + } + + encodedOperator[platformPosition] += 1; + + final OptimizationContext.OperatorContext operatorContext = optimizationContext + .getOperatorContext(executionOperator); + + if (operatorContext == null) { + continue; + } + + final List operatorSamples = CardinalitySampler.samples.stream().filter( + sample -> sample.getOperator().get("class").equals(executionOperator.getClass().getName())) + .toList(); + + final long inputCardinality = operatorSamples.stream().mapToLong(sample -> { + long card = 0; + for (final Object input : sample.getInputs()) { + card += ((WayangJsonObj) input).getLong("upperBound"); + } + + return card; + }).sum(); + final long outputCardinality = operatorSamples.stream() + .mapToLong(sample -> sample.getOutput().getLong("cardinality")).sum(); + + encodedOperator[platformsCount] = inputCardinality; + encodedOperator[platformsCount + 1] = outputCardinality; + } + + vector.addDataMovement(encodedOperator, operator.getName()); + } + } + + public static void encodeTopologies(final PlanImplementation plan, final OneHotVector vector) { + final long[] topologies = new long[OneHotVector.TOPOLOGIES_LENGTH]; + + final long replicatorCount = plan.getOperators().stream() + .filter((operator) -> operator.getAllOutputs().length > 1).count(); + topologies[0] = replicatorCount; + topologies[1] = getPipelineCount(plan); + final long junctionCounter = plan.getOperators().stream() + .filter((operator) -> operator.getAllInputs().length > 1).count(); + topologies[2] = junctionCounter; + topologies[3] = (long) plan.getLoopImplementations().size(); + + vector.setTopologies(topologies); + } + + /* + * Format: ---- BEGIN OPERATOR ITERATION ---- 0 - operator hashCode as long 1 - + * sum of UDF complexities 2 - sum of input cardinalities 3 - sum of output + * cardinalities (4 ... end) - one hot marking type of operator + */ + public static long[] encodeOperator(final Operator operator, final OptimizationContext optimizationContext, + final boolean encodeIds) { + final List operatorSamples = CardinalitySampler.samples.stream() + .filter(sample -> sample.getOperator().get("class").equals(operator.getClass().getName())).toList(); + + long inputCardinality = 0; + long outputCardinality = 0; + + if (operatorSamples.size() == 0) { + final OptimizationContext.OperatorContext operatorContext = optimizationContext + .getOperatorContext(operator); + + if (operatorContext != null) { + for (final InputSlot input : operator.getAllInputs()) { + final CardinalityEstimate card = operatorContext.getInputCardinality(input.getIndex()); + if (card != null) { + inputCardinality += card.getLowerEstimate(); + } + } + + for (final OutputSlot output : operator.getAllOutputs()) { + final CardinalityEstimate card = operatorContext.getOutputCardinality(output.getIndex()); + if (card != null) { + outputCardinality += card.getLowerEstimate(); + } + } + } + } else { + inputCardinality = operatorSamples.stream().mapToLong(sample -> { + long card = 0; + for (final Object input : sample.getInputs()) { + card += ((WayangJsonObj) input).getLong("upperBound"); + } + + return card; + }).sum(); + outputCardinality = operatorSamples.stream().mapToLong(sample -> sample.getOutput().getLong("cardinality")) + .sum(); + } + + final HashMap operatorMappings = OneHotMappings.getOperatorMapping(); + final HashMap platformMappings = OneHotMappings.getPlatformsMapping(); + + final int operatorsCount = operatorMappings.size(); + final int platformsCount = platformMappings.size(); + + final long[] result = new long[PADDING_SIZE + operatorsCount + platformsCount + 3]; + + if (encodeIds) { + result[0] = (long) new HashCodeBuilder(17, 37).append(operator.toString()).append(operator.getName()) + .append(operator.getAllInputs().length).append(operator.getAllOutputs().length).toHashCode(); + } + + result[PADDING_SIZE + operatorsCount + platformsCount] = ComplexityUtils.inferFromOperator(operator).ordinal(); + result[PADDING_SIZE + operatorsCount + platformsCount + 1] = inputCardinality; + result[PADDING_SIZE + operatorsCount + platformsCount + 2] = outputCardinality; + + final Integer operatorPosition = operatorMappings.get(operator.getClass().getName()); + result[1 + operatorPosition] = 1; + + return result; + } + + public static long[] encodeOperator(final ExecutionOperator operator, final OptimizationContext optimizationContext, + final boolean encodeIds) { + final List operatorSamples = CardinalitySampler.samples.stream() + .filter(sample -> sample.getOperator().get("class").equals(operator.getClass().getName())).toList(); + + long inputCardinality = 0; + long outputCardinality = 0; + + if (operatorSamples.size() == 0) { + final OptimizationContext.OperatorContext operatorContext = optimizationContext + .getOperatorContext(operator); + + if (operatorContext != null) { + for (final InputSlot input : operator.getAllInputs()) { + inputCardinality += operatorContext.getInputCardinality(input.getIndex()).getLowerEstimate(); + } + + for (final OutputSlot output : operator.getAllOutputs()) { + outputCardinality += operatorContext.getOutputCardinality(output.getIndex()).getLowerEstimate(); + } + } + } else { + inputCardinality = operatorSamples.stream().mapToLong(sample -> { + long card = 0; + for (final Object input : sample.getInputs()) { + card += ((WayangJsonObj) input).getLong("upperBound"); + } + + return card; + }).sum(); + outputCardinality = operatorSamples.stream().mapToLong(sample -> sample.getOutput().getLong("cardinality")) + .sum(); + } + + final HashMap operatorMappings = OneHotMappings.getOperatorMapping(); + final HashMap platformMappings = OneHotMappings.getPlatformsMapping(); + + final int operatorsCount = operatorMappings.size(); + final int platformsCount = platformMappings.size(); + + // Schema is: [ID, operator_1, ..., operator_N, platform_1, ..., platform_P, + // udf, in_c, out_c] + final long[] result = new long[PADDING_SIZE + operatorsCount + platformsCount + 3]; + + if (encodeIds) { + result[0] = (long) new HashCodeBuilder(17, 37).append(operator.toString()).append(operator.getName()) + .append(operator.getAllInputs().length).append(operator.getAllOutputs().length).toHashCode(); + } + + result[PADDING_SIZE + operatorsCount + platformsCount] = ComplexityUtils.inferFromOperator(operator).ordinal(); + result[PADDING_SIZE + operatorsCount + platformsCount + 1] = inputCardinality; + result[PADDING_SIZE + operatorsCount + platformsCount + 2] = outputCardinality; + + Integer operatorPosition = operatorMappings.get(operator.getClass().getSuperclass().getName()); + + // Try to find a higher matching parent in the mappings + if (operatorPosition == null) { + operatorPosition = operatorMappings.get(operator.getClass().getSuperclass().getSuperclass().getName()); + } + + assert operatorPosition != null : operator.getClass().getSuperclass().getName() + " was not found in mappings"; + + result[PADDING_SIZE + operatorPosition] = 1; + + final Integer platformPosition = platformMappings.get(operator.getPlatform().getClass().getName()); + result[PADDING_SIZE + operatorsCount + platformPosition] = 1; + + return result; + } + + public static long[] encodeNullOperator() { + final HashMap operatorMappings = OneHotMappings.getOperatorMapping(); + final HashMap platformMappings = OneHotMappings.getPlatformsMapping(); + + final int operatorsCount = operatorMappings.size(); + final int platformsCount = platformMappings.size(); + final long[] result = new long[PADDING_SIZE + operatorsCount + platformsCount + 3]; + + return result; + } + + private static long getPipelineCount(final PlanImplementation plan) { + long pipelineCount = 0; + final HashMap visited = new HashMap<>(); + final List startOperators = plan.getStartOperators(); + + // traverse operators starting from each startOperator until + // a junction target is found. Mark all as visited and increment + // pipeline counter until no more visitable operators are existant. + for (final ExecutionOperator startOperator : startOperators) { + pipelineCount += traverse(plan, startOperator, visited, 0, 0); + } + + return pipelineCount; + } + + private static long traverse(final PlanImplementation plan, final Operator current, + final HashMap visited, final int steps, long pipelineCount) { + + if (visited.containsKey(current)) { + return pipelineCount; + } + + visited.put(current, Integer.valueOf(1)); + final OutputSlot[] outputs = current.getAllOutputs(); + + if (outputs.length == 0) { + if (steps > 0) { + pipelineCount++; + } + + return pipelineCount; + } + + // check if this junction output + if (current.getAllInputs().length > 1) { + if (steps > 1) { + pipelineCount++; + } + + for (int i = 0; i < outputs.length; i++) { + final Junction junction = plan.getJunction(outputs[i]); + + if (junction.getNumTargets() == 0) { + return pipelineCount; + } + + for (final InputSlot input : junction.getTargetInputs()) { + final Operator next = input.getOwner(); + pipelineCount += traverse(plan, next, visited, 0, pipelineCount); + } + } + } + + // check if this is replicator input + if (current.getAllOutputs().length > 1) { + if (steps > 1) { + pipelineCount++; + } + + for (int i = 0; i < outputs.length; i++) { + final Junction junction = plan.getJunction(outputs[i]); + + if (junction.getNumTargets() == 0) { + return pipelineCount; + } + + for (final InputSlot input : junction.getTargetInputs()) { + final Operator next = input.getOwner(); + pipelineCount += traverse(plan, next, visited, 0, pipelineCount); + } + } + } + + final Junction junction = plan.getJunction(outputs[0]); + + if (junction.getNumTargets() == 0) { + return pipelineCount; + } + final Operator next = junction.getTargetInput(0).getOwner(); + + return traverse(plan, next, visited, steps + 1, pipelineCount); + } + + private static void encodeDataset(final PlanImplementation plan, final OneHotVector vector) { + vector.setDataset(100l); + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotMappings.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotMappings.java new file mode 100644 index 000000000..03cc90a08 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotMappings.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Optional; + +import org.apache.commons.lang3.builder.HashCodeBuilder; +import org.apache.wayang.core.plan.wayangplan.Operator; +import org.apache.wayang.core.platform.Platform; +import org.apache.wayang.ml.util.Operators; +import org.apache.wayang.ml.util.Platforms; + +public class OneHotMappings { + private static final int PADDING_SIZE = 1; + + private static final HashMap operatorMapping = createOperatorMapping(); + private static final HashMap platformsMapping = createPlatformMapping(); + + public static Optional getOperatorPlatformFromEncoding(final long[] encoded) { + final int platformsCount = platformsMapping.size(); + final int operatorsCount = operatorMapping.size(); + + if (platformsCount > encoded.length) { + return Optional.empty(); + } + + int platformIndex = -1; + final int offset = PADDING_SIZE + operatorsCount; + + for (int i = offset; i < platformsCount + offset && platformIndex == -1; i++) { + if (encoded[i] == 1) { + platformIndex = i; + } + } + + if (platformIndex == -1) { + return Optional.empty(); + } + + for (final Object entry : platformsMapping.keySet()) { + if (platformsMapping.get(entry).equals(platformIndex - offset)) { + return Platforms.getPlatforms().stream().filter(pl -> pl.getName().equals(entry)) + .map(cl -> Platform.load(cl.getName())).findAny(); + } + } + + return Optional.empty(); + } + + public static HashMap getOperatorMapping() { + return operatorMapping; + } + + public static HashMap getPlatformsMapping() { + return platformsMapping; + } + + private static HashMap createOperatorMapping() { + final HashMap mappings = new HashMap<>(); + + Operators.getOperators().stream() + .filter(operator -> operator.getName().contains("org.apache.wayang.basic.operators") + || operator.getName().contains("org.apache.wayang.core.plan.wayangplan")) + .distinct().sorted(Comparator.comparing(Class::getName)) + .forEachOrdered(entry -> mappings.put(entry.getName(), mappings.size())); + + return mappings; + } + + private static HashMap createPlatformMapping() { + final HashMap mappings = new HashMap<>(); + + Platforms.getPlatforms().stream().sorted(Comparator.comparing(Class::getName)) + .forEachOrdered(entry -> mappings.put(entry.getName(), mappings.size())); + + return mappings; + } + + private final HashSet originalOperators = new HashSet<>(); + + public OneHotMappings() { + } + + public void addOriginalOperator(final Operator operator) { + originalOperators.add(operator); + } + + public HashSet getOriginalOperators() { + return originalOperators; + } + + public Optional getOperatorFromEncoding(final long[] encoded) { + final long hashCode = encoded[0]; + + final Optional original = originalOperators.stream() + .filter(op -> + (long) new HashCodeBuilder(17, 37).append(op.toString()).append(op.getName()) + .append(op.getAllInputs().length).append(op.getAllOutputs().length).toHashCode() == hashCode) + .findAny(); + + if (original.isPresent()) { + return original; + } + + return Optional.empty(); + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotVector.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotVector.java new file mode 100644 index 000000000..362471d90 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OneHotVector.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.HashMap; + +public class OneHotVector { + public static final int TOPOLOGIES_LENGTH = 4; + + // Size of the encoding data for one operator + public static final int OPERATOR_SIZE = OneHotMappings.getPlatformsMapping().size() + 8; + public static final int CONVERSION_SIZE = OneHotMappings.getPlatformsMapping().size() + 3; + public static final int OPERATORS_LENGTH = OneHotMappings.getOperatorMapping().size() * OPERATOR_SIZE; + public static final int CONVERSIONS_LENGTH = OneHotMappings.getOperatorMapping().size() * CONVERSION_SIZE; + + public static final int LENGTH = TOPOLOGIES_LENGTH + OPERATORS_LENGTH + CONVERSIONS_LENGTH + 1; + + private static int getPosition(final String operator) { + final HashMap operatorMapping = OneHotMappings.getOperatorMapping(); + + return !operatorMapping.containsKey(operator) ? -1 : operatorMapping.get(operator); + } + + private final long[] entries = new long[OneHotVector.LENGTH]; + + public OneHotVector(){ + } + + public void addOperator(final long[] encodedOperator, final String operator) { + assert encodedOperator.length == OPERATOR_SIZE : + "Invalid encoded operator size: expected " + OPERATOR_SIZE + + " but got " + encodedOperator.length + + " for operator [" + operator + "]."; + + final int position = getPosition(operator); + + // position of operator couldnt be found + if (position == -1) { + return; + //throw new WayangException("Could not find position of operator, potentially illegal operator, got operator: " + operator); + } + + for (int i = 0; i < encodedOperator.length; i++) { + this.entries[TOPOLOGIES_LENGTH + i + (position * OPERATOR_SIZE)] = encodedOperator[i]; + } + } + + public void addDataMovement(final long[] encodedConversion, final String operator) { + assert encodedConversion.length == CONVERSION_SIZE : "amount of encoded operators was not equal to the operator size defined in one hot. Got: " + encodedConversion.length + ", expected: " + CONVERSION_SIZE; + final int position = getPosition(operator); + + // position of operator couldnt be found + if (position == -1) { + return; + //throw new WayangException("Could not find position of operator, potentially illegal operator, got operator: " + operator); + } + + for (int i = 0; i < encodedConversion.length; i++) { + this.entries[TOPOLOGIES_LENGTH + OPERATORS_LENGTH + i + (position * CONVERSION_SIZE)] = encodedConversion[i]; + } + } + + public void setTopologies(final long[] topologies) { + assert topologies.length == TOPOLOGIES_LENGTH : "amount of encoded operators was not equal to the operator size defined in one hot."; + + for (int i = 0; i < TOPOLOGIES_LENGTH; i++) { + this.entries[i] = topologies[i]; + } + } + + public long getDataset() { + return this.entries[LENGTH - 1]; + } + + public void setDataset(final Long dataset) { + this.entries[LENGTH - 1] = dataset; + } + + public long[] getEntries() { + return this.entries; + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtMLModel.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtMLModel.java new file mode 100644 index 000000000..f721f3c07 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtMLModel.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.ml.util.Logging; + +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import ai.onnxruntime.OrtSession.Result; + +public class OrtMLModel { + private static OrtMLModel INSTANCE; + + public static OrtMLModel getInstance(final Configuration configuration) throws OrtException { + if (INSTANCE == null) { + INSTANCE = new OrtMLModel(configuration); + } + + return INSTANCE; + } + + private OrtSession session; + private OrtEnvironment env; + + private final Configuration configuration; + private final Map inputMap = new HashMap<>(); + + private final Set requestedOutputs = new HashSet<>(); + + private OrtMLModel(final Configuration configuration) throws OrtException { + this.configuration = configuration; + this.loadModel(configuration.getStringProperty("wayang.ml.model.file")); + } + + /** + * placeholder + * + * @param encoded + * @return + */ + public double runModel(final long[] encoded) { + return 0; + } + + public static void printTupleDeep(Tuple, ArrayList> input) { + System.out.println("=== VALUES (field0) ==="); + for (int k = 0; k < input.field0.size(); k++) { + System.out.println("Tree " + k + ":"); + + long[][] arr = input.field0.get(k); + for (int i = 0; i < arr.length; i++) { + System.out.print(" Node " + i + ": "); + for (int j = 0; j < arr[i].length; j++) { + System.out.print(arr[i][j] + " "); + } + System.out.println(); + } + } + + System.out.println("\n=== INDEXES (field1) ==="); + for (int k = 0; k < input.field1.size(); k++) { + System.out.println("Tree " + k + ":"); + + long[][] arr = input.field1.get(k); + for (int i = 0; i < arr.length; i++) { + System.out.print(" Idx " + i + ": "); + for (int j = 0; j < arr[i].length; j++) { + System.out.print(arr[i][j] + " "); + } + System.out.println(); + } + } + } + + /** + * Close the session after running, {@link #closeSession()} + * + * @param encodedVector + * @return NaN on error, and a predicted cost on any other value. + * @throws OrtException + */ + public double runModel(final Tuple, ArrayList> input1) throws OrtException { + final int batchSize = input1.getField0().size(); + final long[][] firstValues = input1.getField0().get(0); + final int featureSize = firstValues.length; + final int sequenceLength = firstValues[0].length; + + final Instant start = Instant.now(); + final float[][][] inputValueStructure = new float[batchSize][featureSize][sequenceLength]; + final long[][][] inputIndexStructure = new long[batchSize][featureSize][sequenceLength]; + + for (int i = 0; i < input1.field0.get(0).length; i++) { + for (int j = 0; j < input1.field0.get(0)[i].length; j++) { + inputValueStructure[0][i][j] = Long.valueOf(input1.field0.get(0)[i][j]).floatValue(); + } + } + + for (int i = 0; i < input1.field1.get(0).length; i++) { + inputIndexStructure[0][i] = input1.field1.get(0)[i]; + } + + final OnnxTensor tensorValues = OnnxTensor.createTensor(env, inputValueStructure); + final OnnxTensor tensorIndexes = OnnxTensor.createTensor(env, inputIndexStructure); + + this.inputMap.put("input1", tensorValues); + this.inputMap.put("input2", tensorIndexes); + + this.requestedOutputs.add("output"); + + final BiFunction unwrapFunc = (r, s) -> { + try { + return ((float[]) r.get(s).get().getValue())[0]; + } catch (final OrtException e) { + this.inputMap.clear(); + this.requestedOutputs.clear(); + + return Float.NaN; + } + }; + + final double costPrediction; + + try (Result r = session.run(inputMap, requestedOutputs)) { + costPrediction = unwrapFunc.apply(r, "output"); + final Instant end = Instant.now(); + final long execTime = Duration.between(start, end).toMillis(); + + Logging.writeToFile(String.format("%d", execTime), + this.configuration.getStringProperty("wayang.ml.optimizations.file")); + } catch (final Exception e) { + e.printStackTrace(); + return 0; + } finally { + this.inputMap.clear(); + this.requestedOutputs.clear(); + } + + return costPrediction; + } + + /** + * Closes the OrtModel resource, relinquishing any underlying resources. + * + * @throws OrtException + */ + public void closeSession() throws OrtException { + this.session.close(); + this.env.close(); + } + + private void loadModel(final String filePath) throws OrtException { + if (this.env == null) { + this.env = OrtEnvironment.getEnvironment("org.apache.wayang.ml"); + this.env.setTelemetry(false); + } + + if (this.session == null) { + final OrtSession.SessionOptions options = new OrtSession.SessionOptions(); + + options.setInterOpNumThreads(16); + options.setIntraOpNumThreads(16); + options.setDeterministicCompute(true); + + this.session = env.createSession(filePath, options); + } + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtTensorDecoder.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtTensorDecoder.java new file mode 100644 index 000000000..c3c875c2d --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtTensorDecoder.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; + +import org.apache.wayang.core.util.Tuple; + +import com.google.common.primitives.Longs; + +public class OrtTensorDecoder { + + /** + * Decodes the output from a tree based NN model + * + * @param mlOutput takes the out put from @ + */ + public static TreeNode decode(final Tuple, ArrayList> mlOutput) { + final HashMap nodeToIDMap = new HashMap<>(); + final long[][] platformChoices = mlOutput.field0.get(0); + final long[][] indexedTree = mlOutput.field1.get(0); + final long[] flatIndexTree = Arrays.stream(indexedTree).reduce(Longs::concat).orElseThrow(); + + for (int j = 0; j < flatIndexTree.length; j += 3) { + final long curID = flatIndexTree[j]; + final long[] value = platformChoices[(int) curID]; + + final TreeNode curTreeNode = nodeToIDMap.containsKey(curID) ? nodeToIDMap.get(curID) + : new TreeNode(value, null, null); + + curTreeNode.encoded = value; + + if (flatIndexTree.length > j + 1) { + final long lID = flatIndexTree[j + 1]; + TreeNode left; + + final long[] lValues = platformChoices[(int) lID]; + + if (nodeToIDMap.containsKey(lID)) { + left = nodeToIDMap.get(lID); + } else { + left = new TreeNode(lValues, null, null); + } + + left.encoded = lValues; + + nodeToIDMap.put(lID, left); + + curTreeNode.left = left; + + if (flatIndexTree.length > j + 2) { + final long rID = flatIndexTree[j + 2]; + TreeNode right; + + final long[] rValues = platformChoices[(int) rID]; + + if (nodeToIDMap.containsKey(rID)) { + right = nodeToIDMap.get(rID); + } else { + right = new TreeNode(rValues, null, null); + } + + right.encoded = rValues; + + nodeToIDMap.put(rID, right); + + curTreeNode.right = right; + } + } + + // put values back into map so we can look them up in next loop + nodeToIDMap.put(curID, curTreeNode); + } + + return nodeToIDMap.get(1L); + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtTensorEncoder.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtTensorEncoder.java new file mode 100644 index 000000000..badeb545b --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/OrtTensorEncoder.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import javax.annotation.Nonnull; + +import org.apache.wayang.core.util.Tuple; + +public class OrtTensorEncoder { + public static ArrayList transpose(final ArrayList flatTrees) { + return flatTrees.stream().map(tree -> IntStream.range(0, tree[0].length) // transpose matrix + .mapToObj(i -> Arrays.stream(tree).mapToLong(row -> row[i]).toArray()).toArray(long[][]::new)) + .collect(Collectors.toCollection(ArrayList::new)); + } + + /** + * Encodes a single tree + * + * @param node root tree node + * @return a flat struture of (trees, indexes) + */ + public static Tuple, ArrayList> encode(final @Nonnull TreeNode node) { + return OrtTensorEncoder.prepareTrees(List.of(node)); + } + + /** + * This method prepares the trees for creation of the OnnxTensor + * + * @param trees + * @return returns a tuple of (flatTrees, indexes) + */ + public static Tuple, ArrayList> prepareTrees(final List trees) { + final List flatTrees = trees.stream().map(OrtTensorEncoder::flatten).toList(); + + final ArrayList paddedTrees = padAndCombine(flatTrees); + + final ArrayList transposedTrees = transpose(paddedTrees); + + final ArrayList indexes = trees.stream().map(OrtTensorEncoder::treeConvIndexes) + .collect(Collectors.toCollection(ArrayList::new)); + + final ArrayList paddedIndexes = padAndCombine(indexes); + + return new Tuple<>(transposedTrees, paddedIndexes); + } + + /** + * Create indexes that, when used as indexes into the output of `flatten`, + * create an array such that a stride-3 1D convolution is the same as a tree + * convolution. + * + * @param root + * @return + */ + public static long[][] treeConvIndexes(final TreeNode root) { + final TreeNode indexTree = preorderIndexes(root, 1); + + final ArrayList acc = new ArrayList<>(); // in place of a generator + treeConvIndexesStep(indexTree, acc); // mutates acc + + final long[] flatAcc = acc.stream().flatMapToLong(Arrays::stream).toArray(); + + return Arrays.stream(flatAcc).mapToObj(v -> new long[] { v }).toArray(long[][]::new); + } + + public static void treeConvIndexesStep(final TreeNode root, final ArrayList acc) { + if (root == null) { + return; + } + + if (root.isLeaf()) { + acc.add(new long[] { root.encoded[0], 0, 0 }); + return; + } + + final long ID = root.encoded[0]; + final long lID = root.getLeft() != null ? root.getLeft().encoded[0] : 0; + final long rID = root.getRight() != null ? root.getRight().encoded[0] : 0; + + acc.add(new long[] { ID, lID, rID }); + treeConvIndexesStep(root.getLeft(), acc); + treeConvIndexesStep(root.getRight(), acc); + } + + /** + * transforms a tree into a tree of preorder indexes + * + * @return + * @param idx needs to default to one. + */ + public static TreeNode preorderIndexes(final TreeNode root, final long idx) { + if (root == null) { + return null; + } + + if (root.isNullOperator()) { + return new TreeNode(new long[] { idx }, null, null); + } + + if (root.isLeaf()) { + return new TreeNode(new long[] { idx }, new TreeNode(new long[] { 0 }, null, null), + new TreeNode(new long[] { 0 }, null, null)); + } + + final TreeNode leftSubTree = root.getLeft() != null ? preorderIndexes(root.getLeft(), idx + 1) : null; + + final TreeNode rightSubTree = root.getRight() != null + ? preorderIndexes(root.getRight(), rightMost(leftSubTree) + 1) + : null; + + return new TreeNode(new long[] { idx }, leftSubTree, rightSubTree); + } + + public static long rightMost(final TreeNode root) { + if (root == null) + return 0; + + if (root.isLeaf()) { + return root.encoded[0]; + } + + if (root.getRight() == null && root.getLeft() != null) { + return rightMost(root.getLeft()); + } + + if (root.getRight().encoded[0] == 0 && root.getLeft().encoded[0] == 0) { + return root.encoded[0]; + } + + if (root.getRight().encoded[0] == 0) { + return rightMost(root.getLeft()); + } + + return rightMost(root.getRight()); + } + + /** + * @param flatTrees + * @return + */ + public static ArrayList padAndCombine(final List flatTrees) { + assert flatTrees.size() >= 1; + + final ArrayList vecs = new ArrayList<>(); + + if (flatTrees.get(0).length == 0) { + return vecs; + } + + final int secondDim = flatTrees.get(0)[0].length; + final int maxFirstDim = flatTrees.stream().mapToInt(a -> a.length).max().orElseThrow(); + + for (final long[][] tree : flatTrees) { + final long[][] padding = new long[maxFirstDim][secondDim]; + + for (int i = 0; i < tree.length; i++) { + System.arraycopy(tree[i], 0, padding[i], 0, tree[i].length); + } + + vecs.add(padding); + } + + return vecs; + } + + /** + * @param root + * @return + */ + public static long[][] flatten(final TreeNode root) { + if (root == null) { + return new long[0][0]; + } + + final ArrayList acc = new ArrayList<>(); + flattenStep(root, acc); + + acc.add(0, new long[acc.get(0).length]); + + return acc.toArray(long[][]::new); + } + + public static void flattenStep(final TreeNode v, final ArrayList acc) { + if (v == null) { + return; + } + + final long[] values = v.isNullOperator() ? OneHotEncoder.encodeNullOperator() + : Arrays.copyOf(v.encoded, v.encoded.length); + + values[0] = 0; + acc.add(values); + + if (v.isLeaf()) { + return; + } + + flattenStep(v.getLeft(), acc); + flattenStep(v.getRight(), acc); + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeDecoder.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeDecoder.java new file mode 100644 index 000000000..9a069b48e --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeDecoder.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.Arrays; +import java.util.Optional; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.plan.wayangplan.Operator; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.core.platform.Platform; + +public class TreeDecoder { + private static final Logger logger = LogManager.getLogger(TreeDecoder.class); + private final OneHotMappings mappings; + + public TreeDecoder(final TreeEncoder encoder) { + this.mappings = encoder.getMappings(); + } + + public WayangPlan decode(final String encoded) { + final TreeNode node = TreeNode.fromString(encoded); + + updateOperatorPlatforms(node); + + final Operator sink = mappings.getOperatorFromEncoding(node.encoded) + .orElseThrow(() -> new WayangException("Couldnt recover sink operator during decoding")); + + final Operator definitiveSink = sink; + + if (definitiveSink.isSink()) { + return new WayangPlan(definitiveSink); + } else { + throw new WayangException("Recovered sink operator is not a sink"); + } + } + + public WayangPlan decode(final TreeNode node) { + updateOperatorPlatforms(node); + + final Operator sink = mappings.getOperatorFromEncoding(node.encoded) + .orElseThrow(() -> new WayangException("Couldnt recover sink operator during decoding")); + + final Operator definitiveSink = sink; + + if (definitiveSink.isSink()) { + return new WayangPlan(definitiveSink); + } else { + throw new WayangException("Recovered sink operator is not a sink"); + } + } + + private void updateOperatorPlatforms(final TreeNode node) { + if (node.isNullOperator()) { + return; + } + + final Optional operator = mappings.getOperatorFromEncoding(node.encoded); + + if (operator.isPresent()) { + final Platform platform = OneHotMappings.getOperatorPlatformFromEncoding(node.encoded) + .orElseThrow(() -> new WayangException( + String.format("Couldnt recover platform for operator: %s with encoding %s", operator.get(), + Arrays.toString(node.encoded)))); + + operator.get().addTargetPlatform(platform); + } else { + logger.info("Operator couldn't be recovered, potentially conversion operator: {}", node); + } + + if (node.left != null) { + updateOperatorPlatforms(TreeNode.class.cast(node.left)); + } + + if (node.right != null) { + updateOperatorPlatforms(TreeNode.class.cast(node.right)); + } + } +} + \ No newline at end of file diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeEncoder.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeEncoder.java new file mode 100644 index 000000000..cff4430f2 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeEncoder.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; + +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.optimizer.OptimizationContext; +import org.apache.wayang.core.optimizer.enumeration.PlanImplementation; +import org.apache.wayang.core.plan.executionplan.ExecutionTask; +import org.apache.wayang.core.plan.wayangplan.ExecutionOperator; +import org.apache.wayang.core.plan.wayangplan.InputSlot; +import org.apache.wayang.core.plan.wayangplan.Operator; +import org.apache.wayang.core.plan.wayangplan.OperatorAlternative; +import org.apache.wayang.core.plan.wayangplan.OutputSlot; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.core.platform.Junction; + +public class TreeEncoder { + private final OneHotMappings mappings; + + public TreeEncoder(final OneHotMappings mappings) { + this.mappings = mappings; + } + + public OneHotMappings getMappings() { + return this.mappings; + } + + public TreeNode encode(final PlanImplementation plan) { + final List result = new ArrayList(); + + final HashMap> tree = new HashMap<>(); + final List sinks = plan.getOperators().stream().filter(Operator::isSink).toList(); + + final Map, Junction> junctions = plan.getJunctions(); + + // TODO: convert to config + final boolean encodeIds = false; + + for (final Operator sink : sinks) { + final TreeNode sinkNode = traversePIOperator(sink, plan.getOptimizationContext(), encodeIds, junctions, + tree); + result.add(sinkNode); + } + + if (result.size() == 0) { + return null; + } + + final TreeNode resultNode = result.get(0); + resultNode.rebalance(); + + return resultNode; + } + + public TreeNode encode(final WayangPlan plan, final OptimizationContext optimizationContext, + final boolean encodeIds) { + final List result = new ArrayList(); + plan.prune(); + + final HashMap> tree = new HashMap<>(); + final Collection sinks = plan.getSinks(); + + for (final Operator sink : sinks) { + final TreeNode sinkNode = traverse(sink, tree, optimizationContext, encodeIds); + result.add(sinkNode); + } + + if (result.size() == 0) { + return null; + } + + assert result.size() == 1 : "result size was not 1"; + + final TreeNode resultNode = result.get(0); + + // rebalance to make it a guaranteed binary tree + resultNode.rebalance(); + + return resultNode; + } + + private TreeNode traversePIOperator(final Operator current, final OptimizationContext optimizationContext, + final boolean encodeIds, final Map, Junction> junctions, + final HashMap> visited) { + if (visited.containsKey(current)) { + return null; + } + + final TreeNode currentNode = new TreeNode(); + + if (current.isAlternative()) { + final Operator original = ((OperatorAlternative) current).getAlternatives().get(0).getContainedOperators() + .stream().findFirst() + .orElseThrow(() -> new WayangException("Operator could not be retrieved from Alternatives")); + mappings.addOriginalOperator(original); + + currentNode.encoded = OneHotEncoder.encodeOperator(original, optimizationContext, encodeIds); + } else { + mappings.addOriginalOperator(current); + + if (current.isExecutionOperator()) { + currentNode.encoded = OneHotEncoder.encodeOperator((ExecutionOperator) current, optimizationContext, + encodeIds); + } else { + currentNode.encoded = OneHotEncoder.encodeOperator(current, optimizationContext, encodeIds); + } + } + + final Collection currentJunctions = junctions.values().stream().filter(junction -> { + for (final InputSlot input : current.getAllInputs()) { + if (junction.getTargetInputs().contains(input)) { + return true; + } + } + + return false; + }).toList(); + + final Collection inputs = currentJunctions.stream().map(Junction::getSourceOperator) + .toList(); + + for (final Operator input : inputs) { + TreeNode next; + final Collection conversions = currentJunctions.stream() + .filter(junction -> junction.getSourceOperator() == input) + .flatMap(junction -> junction.getConversionTasks().stream()).toList(); + + // fit conversions in between current and its inputs + if (conversions.size() > 0) { + final Queue conversionQueue = new LinkedList<>(); + conversionQueue.addAll(conversions); + + next = traverseWithNext(conversionQueue, junctions, visited, input, optimizationContext, encodeIds); + } else { + next = traversePIOperator(input, optimizationContext, encodeIds, junctions, visited); + } + + if (currentNode.left == null) { + currentNode.left = next; + } else { + currentNode.right = next; + } + } + + return currentNode; + } + + private TreeNode traverseWithNext(final Queue conversions, + final Map, Junction> junctions, final HashMap> visited, + final Operator next, final OptimizationContext optimizationContext, final boolean encodeIds) { + if (visited.containsKey(next)) { + return null; + } + + if (conversions.isEmpty()) { + return traversePIOperator(next, optimizationContext, encodeIds, junctions, visited); + } + + final ExecutionTask currentTask = conversions.poll(); + final ExecutionOperator current = currentTask.getOperator(); + final TreeNode currentNode = new TreeNode(); + + if (current.isAlternative()) { + final Operator original = ((OperatorAlternative) current).getAlternatives().get(0).getContainedOperators() + .stream().findFirst() + .orElseThrow(() -> new WayangException("Operator could not be retrieved from Alternatives")); + mappings.addOriginalOperator(original); + + currentNode.encoded = OneHotEncoder.encodeOperator(original, optimizationContext, encodeIds); + currentNode.operator = original; + } else { + mappings.addOriginalOperator(current); + currentNode.operator = current; + + if (current.isExecutionOperator()) { + currentNode.encoded = OneHotEncoder.encodeOperator((ExecutionOperator) current, optimizationContext, + encodeIds); + } else { + currentNode.encoded = OneHotEncoder.encodeOperator(current, optimizationContext, encodeIds); + } + } + + final TreeNode nextNode = traverseWithNext(conversions, junctions, visited, next, optimizationContext, + encodeIds); + + if (currentNode.left == null) { + currentNode.left = nextNode; + } else { + currentNode.right = nextNode; + } + + return currentNode; + } + + private TreeNode traverse(final Operator current, final HashMap> visited, + final OptimizationContext optimizationContext, final boolean encodeIds) { + if (visited.containsKey(current)) { + return null; + } + + final TreeNode currentNode = new TreeNode(); + + if (current.isAlternative()) { + final Operator original = ((OperatorAlternative) current).getAlternatives().get(0).getContainedOperators() + .stream().findFirst() + .orElseThrow(() -> new WayangException("Operator could not be retrieved from Alternatives")); + mappings.addOriginalOperator(original); + + currentNode.encoded = OneHotEncoder.encodeOperator(original, optimizationContext, encodeIds); + currentNode.operator = original; + } else { + mappings.addOriginalOperator(current); + currentNode.operator = current; + + if (current.isExecutionOperator()) { + currentNode.encoded = OneHotEncoder.encodeOperator((ExecutionOperator) current, optimizationContext, + encodeIds); + } else { + currentNode.encoded = OneHotEncoder.encodeOperator(current, optimizationContext, encodeIds); + } + } + + // Add for later reconstruction in TreeDecoder + final List inputs = Arrays.stream(current.getAllInputs()).filter(input -> input.getOccupant() != null) + .map(input -> input.getOccupant().getOwner()).toList(); + + for (final Operator input : inputs) { + final TreeNode next = traverse(input, visited, optimizationContext, encodeIds); + + if (currentNode.getLeft() == null) { + currentNode.left = next; + } else { + currentNode.right = next; + } + } + + return currentNode; + } + +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeNode.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeNode.java new file mode 100644 index 000000000..4152734fb --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/encoding/TreeNode.java @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.encoding; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.wayang.core.plan.wayangplan.Operator; + +public class TreeNode { + private static final int PADDING_SIZE = 1; + + private static final Pattern pattern = Pattern.compile( + "\\(\\((?[+,-]?\\d+(?:,\\s*\\d+)*)\\),(?(?\\s*\\(.+\\)),(?\\s*\\(.+\\))|\\)*)", + Pattern.CASE_INSENSITIVE); + + public static TreeNode fromString(final String encoded) { + final TreeNode result = new TreeNode(); + final Matcher matcher = pattern.matcher(encoded); + + if (!matcher.find()) { + return null; + } + + final String value = matcher.group("value"); + final String left = matcher.group("left"); + final String right = matcher.group("right"); + final long[] encodedLongs = Arrays.stream(value.split(",")) + .map(String::trim) + .mapToLong(Long::parseLong) + .toArray(); + + // ignore if no platform choices given + if (Arrays.stream(encodedLongs).sum() == 0L) { + return null; + } + + result.encoded = encodedLongs; + + if (left != null) { + result.left = TreeNode.fromString(left); + } + + if (right != null) { + result.right = TreeNode.fromString(right); + } + + return result; + } + + public static TreeNode create() { + return new TreeNode(); + } + + public long[] encoded; + + public TreeNode left; + + public TreeNode right; + + public Operator operator; + + public TreeNode() { + this.operator = null; + this.encoded = OneHotEncoder.encodeNullOperator(); + this.left = null; + this.right = null; + } + + public TreeNode(final long[] encoded, final TreeNode left, final TreeNode right) { + this.operator = null; + this.encoded = encoded; + this.left = left; + this.right = right; + } + + public TreeNode(final Operator operator, final long[] encoded, final TreeNode left, final TreeNode right) { + this.operator = operator; + this.encoded = encoded; + this.left = left; + this.right = right; + } + + /* + * Utility function to rebalance the tree to a guaranteed BinaryTree + * + * @return void + */ + public void rebalance() { + if (this.isLeaf()) { + this.left = TreeNode.create(); + this.right = TreeNode.create(); + return; + } + + if (this.left != null) { + this.left.rebalance(); + } + + if (this.right != null) { + this.right.rebalance(); + } + + if (this.left == null && this.right != null) { + this.left = TreeNode.create(); + } + + if (this.left != null && this.right == null) { + this.right = TreeNode.create(); + } + } + + public TreeNode getLeft() { + return this.left; + } + + public TreeNode getRight() { + return this.right; + } + + public String display() { + return Long.toString(this.encoded[0]); + } + + public String toStringEncoding() { + final String encodedString = Arrays.toString(encoded).replace("[", "(").replace("]", ")").replaceAll("\\s+", + ""); + + if (this.getLeft() == null && this.getRight() == null) { + return '(' + encodedString + ",)"; + } + + String leftString = ""; + + if (this.getLeft() != null) { + final TreeNode castLeft = this.getLeft(); + + if (castLeft.isNullOperator()) { + leftString = Arrays.toString(OneHotEncoder.encodeNullOperator()).replace("[", "((").replace("]", "),)") + .replaceAll("\\s+", ""); + } else { + leftString = castLeft.toStringEncoding(); + } + } + + String rightString = ""; + + if (this.getRight() != null) { + final TreeNode castRight = this.getRight(); + + if (castRight.isNullOperator()) { + rightString = Arrays.toString(OneHotEncoder.encodeNullOperator()).replace("[", "((").replace("]", "),)") + .replaceAll("\\s+", ""); + } else { + rightString = castRight.toStringEncoding(); + } + } + + return "(" + encodedString + "," + leftString + "," + rightString + ")"; + } + + public TreeNode withIdsFrom(final TreeNode node) { + this.encoded[0] = node.encoded[0]; + + if (this.getLeft() != null && node.getLeft() != null) { + this.left = this.getLeft().withIdsFrom(node.getLeft()); + } + + if (this.getRight() != null && node.getRight() != null) { + this.right = this.getRight().withIdsFrom(node.getRight()); + } + + return this; + } + + public TreeNode withPlatformChoicesFrom(final TreeNode node) { + if (this.isNullOperator()) { + return this; + } + + if (this.encoded == OneHotEncoder.encodeNullOperator()) { + return this; + } + + if (node.encoded == null) { + assert this.encoded != null; + return this; + } + final HashMap platformMappings = OneHotMappings.getPlatformsMapping(); + final HashMap operatorMappings = OneHotMappings.getOperatorMapping(); + final int operatorsCount = operatorMappings.size(); + final int platformsCount = platformMappings.size(); + + if (this.encoded.length > 0) { + // Check if this already encodes a platform specific operator + final long[] platformChoices = Arrays.copyOfRange(this.encoded, PADDING_SIZE + operatorsCount, + PADDING_SIZE + operatorsCount + platformsCount); + + if (ArrayUtils.indexOf(platformChoices, 1) != -1) { + return this; + } + } + + int platformPosition = -1; + platformPosition = ArrayUtils.indexOf(node.encoded, 1); + String platform = ""; + + assert platformPosition >= 0; + + for (final Map.Entry pair : platformMappings.entrySet()) { + if (pair.getValue() == platformPosition) { + platform = pair.getKey(); + } + } + + assert platform != ""; + + this.encoded[PADDING_SIZE + operatorsCount + platformPosition] = 1; + + if (this.getLeft() != null && node.getLeft() != null) { + this.left = this.getLeft().withPlatformChoicesFrom(node.getLeft()); + } + + if (this.getRight() != null && node.getRight() != null) { + this.right = this.getRight().withPlatformChoicesFrom(node.getRight()); + } + + return this; + } + + public void softmax() { + if (this.encoded == null || this.encoded == OneHotEncoder.encodeNullOperator()) { + return; + } + + // All set to 1, aka null operator + if (Arrays.stream(this.encoded).sum() == 9) { + return; + } + + final long maxValue = Arrays.stream(this.encoded).max().getAsLong(); + final long[] values = Arrays.stream(this.encoded).map(value -> value == maxValue ? 1 : 0).toArray(); + + this.encoded = values; + + if (this.getLeft() != null) { + this.getLeft().softmax(); + } + + if (this.getRight() != null) { + this.getRight().softmax(); + } + } + + public boolean isNullOperator() { + return this.operator == null && Arrays.equals(this.encoded, OneHotEncoder.encodeNullOperator()); + } + + /* + * Utility function for tree traversal without return value. Can be used for + * mutation. + * + * @param Consumer> func + * + * @return void + */ + public void traverse(final Consumer func) { + func.accept(this); + + if (this.isLeaf()) { + return; + } + + if (this.left != null) { + this.left.traverse(func); + } + + if (this.right != null) { + this.right.traverse(func); + } + } + + public boolean isLeaf() { + return this.left == null && this.right == null; + } + + public TreeNode getNode(final int index) { + final List nodes = new ArrayList<>(); + + nodes.add(new TreeNode()); + + this.traverse(nodes::add); + + return nodes.get(index); + } + + public int getNumberOfNodes() { + final List nodes = new ArrayList<>(); + + // Add null operator + nodes.add(new TreeNode()); + + this.traverse(nodes::add); + + return nodes.size(); + } + + public int size() { + int size = 1; + + if (this.isLeaf()) { + return 1; + } + + if (left != null) { + size += this.left.size(); + } + + if (right != null) { + size += this.right.size(); + } + + return size; + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/CardinalitySampler.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/CardinalitySampler.java new file mode 100644 index 000000000..e662c0962 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/CardinalitySampler.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.util; + +import java.io.File; +import java.io.FileWriter; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.exception.WayangException; +import org.apache.wayang.core.util.JsonSerializables; +import org.apache.wayang.core.util.json.WayangJsonObj; + +public class CardinalitySampler { + + public static List samples = new ArrayList<>(); + + public static void configureWriteToFile( + final Configuration config, + final String filePath){ + config.setProperty("wayang.core.log.enabled", "true"); + config.setProperty("wayang.core.log.cardinalities", filePath); + config.setProperty("wayang.core.optimizer.instrumentation", "org.apache.wayang.core.profiling.FullInstrumentationStrategy"); + + // clear previous measurements from file + try { + final File f = new File(filePath); + if(f.exists() && !f.isDirectory()) { + new FileWriter(filePath, false).close(); + } + } catch (final Exception e) { + e.printStackTrace(); + } + } + + public static void readFromFile(final String filePath) { + try { + final SampledCardinality.Serializer serializer = new SampledCardinality.Serializer(); + samples = Files.lines(Path.of(filePath), Charset.forName("UTF-8")) + .map(line -> { + try { + return JsonSerializables.deserialize(new WayangJsonObj(line), serializer, SampledCardinality.class); + } catch (final Exception e) { + System.out.println("Exception: " + e); + throw new WayangException(String.format("Could not parse \"%s\".", new WayangJsonObj(line).getNode()), e); + } + }).collect(Collectors.toList()); + } catch(final Exception e) { + e.printStackTrace(); + } + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/EnumerationStrategy.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/EnumerationStrategy.java new file mode 100644 index 000000000..2264c15e6 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/EnumerationStrategy.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.util; + +public enum EnumerationStrategy { + NONE, + PAIRWISE, + LISTWISE; +} + diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Logging.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Logging.java new file mode 100644 index 000000000..3ec5845d0 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Logging.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.util; + +import java.io.BufferedWriter; +import java.io.FileWriter; + +public class Logging { + public static void writeToFile(final String content, final String path) { + try { + final FileWriter fw = new FileWriter( + path, + true + ); + final BufferedWriter writer = new BufferedWriter(fw); + + writer.write(content); + writer.newLine(); + writer.flush(); + writer.close(); + } catch(final Exception e) { + e.printStackTrace(); + } + } +} + diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Operators.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Operators.java new file mode 100644 index 000000000..1b2c3c4b4 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Operators.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.util; + +import org.apache.wayang.core.plan.wayangplan.OperatorBase; +import org.reflections.Reflections; + +import java.util.stream.Collectors; +import java.util.Set; +import java.util.Comparator; + +public class Operators { + public static Set> getOperators() { + final Reflections reflections = new Reflections("org.apache.wayang.basic.operators"); + final Set> basics = reflections.getSubTypesOf(OperatorBase.class); + + final Reflections coreReflections = new Reflections("org.apache.wayang.core.plan.wayangplan"); + final Set> core = coreReflections.getSubTypesOf(OperatorBase.class); + + basics.addAll(core); + return basics; + } + + public static Set> getPlatformOperators(final String namespace) { + final Reflections reflections = new Reflections(namespace + ".operators"); + return reflections.getSubTypesOf(OperatorBase.class) + .stream() + .filter(operator -> operator.getName().contains(namespace)) + .distinct() + .sorted(Comparator.comparing(Class::getName)) + .collect(Collectors.toSet()); + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Platforms.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Platforms.java new file mode 100644 index 000000000..659e8e17e --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/Platforms.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.util; + +import org.apache.wayang.core.platform.Platform; +import org.apache.wayang.jdbc.platform.JdbcPlatformTemplate; +import org.apache.wayang.sqlite3.platform.Sqlite3Platform; +import org.apache.wayang.giraph.platform.GiraphPlatform; +import org.apache.wayang.genericjdbc.platform.GenericJdbcPlatform; +import org.apache.wayang.tensorflow.platform.TensorflowPlatform; +import org.reflections.Reflections; + +import java.util.Set; +import java.util.HashSet; + +public class Platforms { + public static Set> getPlatforms() { + final Reflections reflections = new Reflections("org.apache.wayang"); + final Set> platforms = reflections.getSubTypesOf(Platform.class); + + final Set> disallowedPlatforms = new HashSet<>(); + disallowedPlatforms.add(JdbcPlatformTemplate.class); + disallowedPlatforms.add(Sqlite3Platform.class); + disallowedPlatforms.add(GiraphPlatform.class); + disallowedPlatforms.add(GenericJdbcPlatform.class); + disallowedPlatforms.add(TensorflowPlatform.class); + + platforms.removeAll(disallowedPlatforms); + + return platforms; + } + + public static String getNamespace(final String platformName) { + final String[] exploded = platformName.split("\\."); + final StringBuilder strBuilder = new StringBuilder(); + for (int i = 0; i < 4; i++) { + strBuilder.append(exploded[i]); + if (i != 3) { + strBuilder.append("."); + } + } + + return strBuilder.toString(); + } +} + diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/SampledCardinality.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/SampledCardinality.java new file mode 100644 index 000000000..3a6411091 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/util/SampledCardinality.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.util; + +import org.apache.wayang.core.util.JsonSerializer; +import org.apache.wayang.core.util.json.WayangJsonArray; +import org.apache.wayang.core.util.json.WayangJsonObj; + +public class SampledCardinality { + public static class Serializer implements JsonSerializer { + + @Override + public WayangJsonObj serialize(final SampledCardinality sample) { + return new WayangJsonObj(); + } + + @Override + public SampledCardinality deserialize(final WayangJsonObj json, final Class cls) { + final WayangJsonArray inputs = json.getJSONArray("inputs"); + final WayangJsonObj operator = json.getJSONObject("operator"); + final WayangJsonObj output = json.getJSONObject("output"); + + return new SampledCardinality( + inputs, operator, output + ); + } + } + + private final WayangJsonArray inputs; + + private final WayangJsonObj operator; + + private final WayangJsonObj output; + + public SampledCardinality( + final WayangJsonArray inputs, + final WayangJsonObj operator, + final WayangJsonObj output + ) { + this.inputs = inputs; + this.operator = operator; + this.output = output; + } + + public WayangJsonArray getInputs() { + return this.inputs; + } + public WayangJsonObj getOperator() { + return this.operator; + } + + public WayangJsonObj getOutput() { + return this.output; + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/BitmaskValidationRule.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/BitmaskValidationRule.java new file mode 100644 index 000000000..694d6ad43 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/BitmaskValidationRule.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.validation; + +import java.util.Set; + +import org.apache.wayang.ml.encoding.TreeNode; + +/** + * ValidationRule to forbid certain platforms when input has not been on + * Postgres before + */ +public class BitmaskValidationRule implements ValidationRule { + /* + * Index of disallowed platform choices + */ + private final Set disallowed = Set.of(0, 1); + + public BitmaskValidationRule() { + } + + public void validate(final Float[][] choices, final long[][][] indexes, final TreeNode tree) { + // Start at 1, 0th platform choice is for null operators + for (int i = 1; i < choices.length; i++) { + for (final Integer disallowedId : disallowed) { + choices[i][disallowedId] = -Float.MAX_VALUE; + } + } + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/OperatorValidationRule.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/OperatorValidationRule.java new file mode 100644 index 000000000..f1ebb598d --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/OperatorValidationRule.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.validation; + +import org.apache.wayang.basic.operators.TextFileSource; +import org.apache.wayang.ml.encoding.TreeNode; +import org.apache.wayang.postgres.operators.PostgresTableSource; + +/** + * ValidationRule to forbid certain platforms when an operator doesn't exist for + * that platform + */ +public class OperatorValidationRule implements ValidationRule { + + private final int postgresIndex = 3; + + public OperatorValidationRule() { + } + + public void validate(final Float[][] choices, final long[][][] indexes, final TreeNode tree) { + // Start at 1, 0th platform choice is for null operators + for (int i = 1; i < tree.getNumberOfNodes(); i++) { + final TreeNode node = (TreeNode) tree.getNode(i); + + if (node != null && !node.isNullOperator()) { + + // Prevent TextFileSources from being in postgres + if (node.operator instanceof TextFileSource) { + choices[i][postgresIndex] = -Float.MAX_VALUE; + } + + // Prevent TextFileSources from being outside of postgres + if (node.operator instanceof PostgresTableSource) { + choices[i][postgresIndex] = Float.MAX_VALUE; + } + } + } + } + +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/PlatformChoiceValidator.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/PlatformChoiceValidator.java new file mode 100644 index 000000000..014fa40a3 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/PlatformChoiceValidator.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.validation; + +import java.util.Arrays; +import java.util.Comparator; + +import org.apache.wayang.ml.encoding.TreeNode; +/** + * Class used for enforcing validation rules on given platform choices + */ +public class PlatformChoiceValidator { + + public static long[][] validate( + final float[][][] tensor, + final long[][][] indexes, + final TreeNode tree, + final ValidationRule... rules + ) { + final Float[][] transposed = transpose(tensor); + + for (final ValidationRule rule : rules) { + rule.validate(transposed, indexes, tree); + } + + return getPlatformChoices(transposed); + } + + public static Float[][] transpose(final float[][][] tensor) { + final int cols = tensor[0][0].length; + final int rows = tensor[0].length; + final Float[][] transposed = new Float[cols][rows]; + + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + transposed[j][i] = tensor[0][i][j]; + } + } + + return transposed; + } + + public static long[][] getPlatformChoices(final Float[][] transposed) { + return Arrays.stream(transposed) + .map(row -> { + final Float max = Arrays.stream(row).max(Comparator.naturalOrder()).orElse(-Float.MAX_VALUE); + final long[] result = Arrays.stream(row) + .mapToLong(v -> v.equals(max) ? 1L : 0L) + .toArray(); + + return result; + }) + .toArray(long[][]::new); + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/PostgresSourceValidationRule.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/PostgresSourceValidationRule.java new file mode 100644 index 000000000..9b94cda3c --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/PostgresSourceValidationRule.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.validation; + +import org.apache.wayang.core.util.Tuple; +import org.apache.wayang.ml.encoding.TreeNode; + +import com.google.common.primitives.Longs; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.Optional; + +/** + * ValidationRule to forbid going to Postgres when input has not been on + * Postgres before + */ +public class PostgresSourceValidationRule implements ValidationRule { + private static int indexOfMax(final Float[] array) { + if (array == null || array.length == 0) { + throw new IllegalArgumentException("Array must not be null or empty"); + } + int maxIndex = 0; + Float maxValue = array[0]; + + for (int i = 1; i < array.length; i++) { + if (array[i] > maxValue) { + maxValue = array[i]; + maxIndex = i; + } + } + return maxIndex; + } + + /* + * Index of platform choice for Postgres + */ + private final int postgresIndex = 3; + + public PostgresSourceValidationRule() { + } + + public void validate(final Float[][] choices, final long[][][] indexes, final TreeNode tree) { + // Start at 1, 0th platform choice is for null operators + for (int i = 1; i < choices.length; i++) { + final Float max = Arrays.stream(choices[i]).max(Comparator.naturalOrder()).orElse(-Float.MAX_VALUE); + + // Check if Postgres is to be chosen + if (indexOfMax(choices[i]) == postgresIndex) { + // Check if Postgres has been chosen in one of the preceeding inputs + if (!isPostgresAllowed(i, indexes, choices, tree)) { + for (int j = 0; j < choices[i].length; j++) { + if (max.equals(choices[i][j])) { + /* + * Set this choice to zero, identifying the platform choices later will take + * care of the rest + */ + choices[i][j] = -Float.MAX_VALUE; + break; + } + } + } + } + } + } + + /* + * Helper to retrieve the input indexes from a given index + */ + private Tuple, Optional> getInputIndexes(final long index, final long[][][] indexes, final TreeNode tree) { + final long[] flatIndexTree = Arrays.stream(indexes[0]).reduce(Longs::concat).orElseThrow(); + for (int i = 0; i < flatIndexTree.length; i += 3) { + final long rootId = flatIndexTree[i]; + final long leftId = flatIndexTree[i + 1]; + final long rightId = flatIndexTree[i + 2]; + + if (rootId == index) { + final Optional left = (leftId == 0 || tree.getNode((int) leftId).isNullOperator()) ? Optional.empty() + : Optional.of(leftId); + final Optional right = (rightId == 0 || tree.getNode((int) rightId).isNullOperator()) ? Optional.empty() + : Optional.of(rightId); + + // Optional left = leftId == 0 ? Optional.empty() : Optional.of(leftId); + // Optional right = rightId == 0 ? Optional.empty() : + // Optional.of(rightId); + + return new Tuple<>(left, right); + } + } + + return new Tuple<>(Optional.empty(), Optional.empty()); + } + + private boolean isPostgresAllowed(final int index, final long[][][] indexes, final Float[][] choices, final TreeNode tree) { + // Check if current operator choice is on PostgreSQL + if (indexOfMax(choices[index]) == postgresIndex) { + // Check for all children recursively + final Tuple, Optional> inputIndexes = getInputIndexes((long) index, indexes, tree); + + // Recurse left + if (inputIndexes.getField0().isPresent()) { + final int leftIndex = inputIndexes.getField0().get().intValue(); + if (!isPostgresAllowed(leftIndex, indexes, choices, tree)) { + return false; + } + } + + // Recurse right + if (inputIndexes.getField1().isPresent()) { + final int rightIndex = inputIndexes.getField1().get().intValue(); + + if (!isPostgresAllowed(rightIndex, indexes, choices, tree)) { + return false; + } + } + + return true; + } + + return false; + } +} diff --git a/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/ValidationRule.java b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/ValidationRule.java new file mode 100644 index 000000000..7597f1ed7 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/java/org/apache/wayang/ml/validation/ValidationRule.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.validation; + +import org.apache.wayang.ml.encoding.TreeNode; + +/** + * Class used for specifying validation rules on given platform choices + */ +public interface ValidationRule { + public void validate(final Float[][] choices, final long[][][] indexes, final TreeNode tree); +} diff --git a/wayang-plugins/wayang-ml/src/main/resources/wayang-ml-defaults.properties b/wayang-plugins/wayang-ml/src/main/resources/wayang-ml-defaults.properties new file mode 100644 index 000000000..822169eb8 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/main/resources/wayang-ml-defaults.properties @@ -0,0 +1,76 @@ + +#/ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Configure plan enumeration pruning. +wayang.core.optimizer.pruning.strategies = org.apache.wayang.core.optimizer.enumeration.LatentOperatorPruningStrategy +# wayang.core.optimizer.pruning.strategies = org.apache.wayang.core.optimizer.enumeration.TopKPruningStrategy +# wayang.core.optimizer.pruning.topk = 5 +# wayang.core.optimizer.channels.selection = org.apache.wayang.core.optimizer.channels.ChannelConversionGraph$CostbasedTreeSelectionStrategy +# wayang.core.optimizer.instrumentation = org.apache.wayang.core.profiling.OutboundInstrumentationStrategy +wayang.core.optimizer.enumeration.concatenationprio = plans2 +wayang.core.optimizer.enumeration.invertconcatenations = false +wayang.core.optimizer.enumeration.branchesfirst = false + +# Configure statistics collection. +wayang.core.log.enabled = true +# wayang.core.log.cardinalities = ~/.wayang/cardinalities.json +# wayang.core.log.executions = ~/.wayang/executions.json + +# Configure re-optimization. +wayang.core.optimizer.reoptimize = false +wayang.core.optimizer.reoptimize.proactive = false +wayang.core.optimizer.cardinality.maxspread = 10 +wayang.core.optimizer.cardinality.spreadsmoothing = 10000 +wayang.core.optimizer.cardinality.minconfidence = 0.5 + +# Settings for aggressive re-optimization. +#wayang.core.optimizer.instrumentation = org.apache.wayang.core.profiling.FullInstrumentationStrategy +#wayang.core.optimizer.reoptimize = true +#wayang.core.optimizer.reoptimize.proactive = true +#wayang.core.optimizer.cardinality.maxspread = 1 +#wayang.core.optimizer.cardinality.spreadsmoothing = 1 +#wayang.core.optimizer.cardinality.minconfidence = 1 + +# Configure fallback estimates. +wayang.core.fallback.udf.cpu.lower = 100 +wayang.core.fallback.udf.cpu.upper = 1000 +wayang.core.fallback.udf.cpu.confidence = 0.2 +wayang.core.fallback.udf.ram.lower = 100 +wayang.core.fallback.udf.ram.upper = 1000 +wayang.core.fallback.udf.ram.confidence = 0.2 +wayang.core.fallback.operator.cpu.lower = 100 +wayang.core.fallback.operator.cpu.upper = 1000 +wayang.core.fallback.operator.cpu.confidence = 0.2 +wayang.core.fallback.operator.ram.lower = 100 +wayang.core.fallback.operator.ram.upper = 1000 +wayang.core.fallback.operator.ram.confidence = 0.2 + +# Configure Monitor. +wayang.core.monitor.enabled = false + +# Configure parallelism. +wayang.core.optimizer.enumeration.parallel-tasks = false + +# Configure average input size +wayang.ml.tuple.average-size = 100 +wayang.ml.model.file = /wayang-plugins/wayang-ml/src/main/resources/linear_model.onnx +wayang.ml.experience.enabled = false +wayang.ml.executions.file = /var/www/html/data/executions.txt +wayang.ml.optimizations.file = /var/www/html/data/optmizations.txt +wayang.ml.experience.file = /var/www/html/data/experience/experience-vae.txt +org.apache.logging.log4j.level = INFO diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/JavaExecutionMLTest.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/JavaExecutionMLTest.java new file mode 100644 index 000000000..1dbd6f1b9 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/JavaExecutionMLTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +import org.apache.wayang.basic.data.Tuple2; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.core.optimizer.enumeration.PlanImplementation; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.java.Java; +import org.apache.wayang.ml.encoding.OneHotMappings; +import org.apache.wayang.ml.encoding.TreeEncoder; +import org.apache.wayang.ml.encoding.TreeNode; +import org.apache.wayang.spark.Spark; +import org.junit.jupiter.api.Test; + +public class JavaExecutionMLTest extends JavaExecutionTestBase { + @Test + public void testPlanImplementationEncoding() throws IOException, URISyntaxException { + final List> collector = new LinkedList<>(); + final Configuration config = new Configuration(); + final String filePath = JavaExecutionMLTest.class.getResource("/README.md").toURI().toString(); + final WayangPlan wayangPlan = createWayangPlan(filePath, collector); + final WayangContext wayangContext = new WayangContext(config); + wayangContext.register(Java.basicPlugin()); + wayangContext.register(Spark.basicPlugin()); + + final Collection planImplementations = buildPlanImplementations(wayangPlan, wayangContext); + + for (final PlanImplementation planImplementation : planImplementations) { + // Just a sanity check for determinism + final TreeEncoder encoder = new TreeEncoder(new OneHotMappings()); + final TreeNode encoded = encoder.encode(planImplementation); + assertArrayEquals(encoded.encoded, encoded.encoded); + } + } +} diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/JavaExecutionTestBase.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/JavaExecutionTestBase.java new file mode 100644 index 000000000..d89edcfa1 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/JavaExecutionTestBase.java @@ -0,0 +1,160 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.Collection; + +import org.apache.wayang.basic.data.Tuple2; +import org.apache.wayang.basic.operators.FilterOperator; +import org.apache.wayang.basic.operators.FlatMapOperator; +import org.apache.wayang.basic.operators.LocalCallbackSink; +import org.apache.wayang.basic.operators.MapOperator; +import org.apache.wayang.basic.operators.ReduceByOperator; +import org.apache.wayang.basic.operators.TextFileSource; +import org.apache.wayang.commons.util.profiledb.instrumentation.StopWatch; +import org.apache.wayang.commons.util.profiledb.model.Experiment; +import org.apache.wayang.commons.util.profiledb.model.Subject; +import org.apache.wayang.commons.util.profiledb.model.measurement.TimeMeasurement; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.core.function.FlatMapDescriptor; +import org.apache.wayang.core.function.ReduceDescriptor; +import org.apache.wayang.core.function.TransformationDescriptor; +import org.apache.wayang.core.optimizer.DefaultOptimizationContext; +import org.apache.wayang.core.optimizer.ProbabilisticDoubleInterval; +import org.apache.wayang.core.optimizer.enumeration.PlanEnumeration; +import org.apache.wayang.core.optimizer.enumeration.PlanEnumerator; +import org.apache.wayang.core.optimizer.enumeration.PlanImplementation; +import org.apache.wayang.core.plan.executionplan.ExecutionPlan; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.core.types.DataSetType; +import org.apache.wayang.core.types.DataUnitType; +import org.apache.wayang.java.execution.JavaExecutor; +import org.apache.wayang.java.operators.JavaExecutionOperator; +import org.apache.wayang.java.platform.JavaPlatform; +import org.apache.wayang.ml.costs.DefaultPointwiseCost; +import org.junit.BeforeClass; + +/** + * Superclass for tests of {@link JavaExecutionOperator}s. + */ +public class JavaExecutionTestBase { + + protected static Job job; + + protected static Configuration configuration; + + @BeforeClass + public static void init() { + configuration = new Configuration(); + configuration.setCostModel(DefaultPointwiseCost.FACTORY.makeCost()); + job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + final DefaultOptimizationContext optimizationContext = new DefaultOptimizationContext(job); + when(job.getOptimizationContext()).thenReturn(optimizationContext); + } + + protected static JavaExecutor createExecutor() { + final Job job = mock(Job.class); + when(job.getConfiguration()).thenReturn(configuration); + return new JavaExecutor(JavaPlatform.getInstance(), job); + } + + /** + * Creates the {@link WayangPlan} for the word count app. + * + * @param inputFileUrl the file whose words should be counted + */ + static WayangPlan createWayangPlan(final String inputFileUrl, final Collection> collector) + throws URISyntaxException, IOException { + // Assignment mode: none. + + final TextFileSource textFileSource = new TextFileSource(inputFileUrl); + textFileSource.setName("Load file"); + + // for each line (input) output an iterator of the words + final FlatMapOperator flatMapOperator = new FlatMapOperator<>( + new FlatMapDescriptor<>(line -> Arrays.asList(line.split("\\W+")), String.class, String.class, + new ProbabilisticDoubleInterval(100, 10000, 0.8))); + flatMapOperator.setName("Split words"); + + final FilterOperator filterOperator = new FilterOperator<>(str -> !str.isEmpty(), String.class); + filterOperator.setName("Filter empty words"); + + // for each word transform it to lowercase and output a key-value pair (word, 1) + final MapOperator> mapOperator = new MapOperator<>( + new TransformationDescriptor<>(word -> new Tuple2<>(word.toLowerCase(), 1), + DataUnitType.createBasic(String.class), DataUnitType.createBasicUnchecked(Tuple2.class)), + DataSetType.createDefault(String.class), DataSetType.createDefaultUnchecked(Tuple2.class)); + mapOperator.setName("To lower case, add counter"); + + // groupby the key (word) and add up the values (frequency) + final ReduceByOperator, String> reduceByOperator = new ReduceByOperator<>( + new TransformationDescriptor<>(pair -> pair.field0, DataUnitType.createBasicUnchecked(Tuple2.class), + DataUnitType.createBasic(String.class)), + new ReduceDescriptor<>(((a, b) -> { + a.field1 += b.field1; + return a; + }), DataUnitType.createGroupedUnchecked(Tuple2.class), DataUnitType.createBasicUnchecked(Tuple2.class)), + DataSetType.createDefaultUnchecked(Tuple2.class)); + reduceByOperator.setName("Add counters"); + + // write results to a sink + final LocalCallbackSink> sink = LocalCallbackSink.createCollectingSink(collector, + DataSetType.createDefaultUnchecked(Tuple2.class)); + sink.setName("Collect result"); + + // Build Rheem plan by connecting operators + textFileSource.connectTo(0, flatMapOperator, 0); + flatMapOperator.connectTo(0, filterOperator, 0); + filterOperator.connectTo(0, mapOperator, 0); + mapOperator.connectTo(0, reduceByOperator, 0); + reduceByOperator.connectTo(0, sink, 0); + + return new WayangPlan(sink); + } + + Collection buildPlanImplementations(final WayangPlan wayangPlan, + final WayangContext wayangContext) { + final Job job = wayangContext.createJob("encodingTestJob", wayangPlan, ""); + final ExecutionPlan baseplan = job.buildInitialExecutionPlan(); + final Experiment experiment = new Experiment("wayang-ml-test", new Subject("Wayang", "0.1")); + final StopWatch stopWatch = new StopWatch(experiment); + final TimeMeasurement optimizationRound = stopWatch.getOrCreateRound("optimization"); + final PlanEnumerator planEnumerator = new PlanEnumerator(wayangPlan, job.getOptimizationContext()); + + final TimeMeasurement enumerateMeasurment = optimizationRound.start("Create Initial Execution Plan", + "Enumerate"); + planEnumerator.setTimeMeasurement(enumerateMeasurment); + final PlanEnumeration comprehensiveEnumeration = planEnumerator.enumerate(true); + planEnumerator.setTimeMeasurement(null); + optimizationRound.stop("Create Initial Execution Plan", "Enumerate"); + + return comprehensiveEnumeration.getPlanImplementations(); + } + +} diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OneHotEncoderTest.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OneHotEncoderTest.java new file mode 100644 index 000000000..4e0cb00b8 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OneHotEncoderTest.java @@ -0,0 +1,68 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +import org.apache.wayang.basic.data.Tuple2; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.core.optimizer.enumeration.PlanImplementation; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.java.Java; +import org.apache.wayang.ml.encoding.OneHotEncoder; +import org.apache.wayang.spark.Spark; +import org.junit.jupiter.api.Test; + +public class OneHotEncoderTest extends JavaExecutionTestBase { + @Test + public void testOneHotEncoding() throws IOException, URISyntaxException { + final List> collector = new LinkedList<>(); + final Configuration config = new Configuration(); + config.setProperty("wayang.ml.tuple.average-size", "100"); + final String filePath = JavaExecutionMLTest.class.getResource("/README.md").toURI().toString(); + final WayangPlan wayangPlan = createWayangPlan(filePath, collector); + final WayangContext wayangContext = new WayangContext(config); + wayangContext.register(Java.basicPlugin()); + wayangContext.register(Spark.basicPlugin()); + + final Collection executionPlans = buildPlanImplementations(wayangPlan, wayangContext); + + for (final PlanImplementation plan : executionPlans) { + long[] previous = null; + for (int i = 0; i < 10; i++) { + final long[] encoded = OneHotEncoder.encode(plan); + if (previous != null) { + assertArrayEquals(previous, encoded); + } else { + assertEquals(true, true); + } + previous = encoded; + } + } + } +} diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OneHotVectorTest.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OneHotVectorTest.java new file mode 100644 index 000000000..d0e3aacc9 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OneHotVectorTest.java @@ -0,0 +1,34 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import org.apache.wayang.basic.operators.LocalCallbackSink; +import org.apache.wayang.ml.encoding.OneHotVector; +import org.junit.jupiter.api.Test; + +public class OneHotVectorTest { + @Test + public void testOneHotVector() { + final OneHotVector vector = new OneHotVector(); + final long[] encoded = new long[12]; + final LocalCallbackSink sink = LocalCallbackSink.createStdoutSink(Integer.class); + vector.addOperator(encoded, sink.getClass().getName()); + } +} diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OrtTensorEncoderTest.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OrtTensorEncoderTest.java new file mode 100644 index 000000000..06ffeaf42 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/OrtTensorEncoderTest.java @@ -0,0 +1,43 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; + +import org.apache.wayang.ml.encoding.OrtTensorEncoder; +import org.junit.jupiter.api.Test; + +public class OrtTensorEncoderTest extends JavaExecutionTestBase { + + @Test + public void testTranspose() { + final ArrayList input = new ArrayList<>(); + input.add(new long[][] { { 1, 2 }, { 3, 4 } }); + + final ArrayList result = OrtTensorEncoder.transpose(input); + + assertEquals(1, result.size()); + assertArrayEquals(new long[] { 1, 3 }, result.get(0)[0]); + assertArrayEquals(new long[] { 2, 4 }, result.get(0)[1]); + } +} diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/TreeEncoderTest.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/TreeEncoderTest.java new file mode 100644 index 000000000..31ea7de01 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/TreeEncoderTest.java @@ -0,0 +1,61 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.LinkedList; +import java.util.List; + +import org.apache.wayang.basic.data.Tuple2; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.Job; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.core.plan.executionplan.ExecutionPlan; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.java.Java; +import org.apache.wayang.ml.encoding.OneHotMappings; +import org.apache.wayang.ml.encoding.TreeEncoder; +import org.apache.wayang.ml.encoding.TreeNode; +import org.apache.wayang.spark.Spark; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class TreeEncoderTest extends JavaExecutionTestBase { + @Test + public void testTreeEncoding() throws IOException, URISyntaxException { + final List> collector = new LinkedList<>(); + final Configuration config = new Configuration(); + final String filePath = JavaExecutionMLTest.class.getResource("/README.md").toURI().toString(); + final WayangPlan wayangPlan = createWayangPlan(filePath, collector); + final WayangContext wayangContext = new WayangContext(config); + final Job wayangJob = wayangContext.createJob("", wayangPlan, ""); + wayangContext.register(Java.basicPlugin()); + wayangContext.register(Spark.basicPlugin()); + + final ExecutionPlan exPlan = wayangJob.buildInitialExecutionPlan(); + + final TreeEncoder encoder = new TreeEncoder(new OneHotMappings()); + final TreeNode encoded = encoder.encode(wayangPlan, wayangJob.getOptimizationContext(), false); + + Assertions.assertNotNull(exPlan); + Assertions.assertNotNull(encoded); + } +} diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/TreeNodeTest.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/TreeNodeTest.java new file mode 100644 index 000000000..38eee20ec --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/TreeNodeTest.java @@ -0,0 +1,39 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.io.IOException; +import java.net.URISyntaxException; + +import org.apache.wayang.ml.encoding.TreeNode; +import org.junit.jupiter.api.Test; + +public class TreeNodeTest { + @Test + public void testEncodingFromString() throws IOException, URISyntaxException { + String encoded = "((0,1,2,3),((4,5,6,7), ((8,9,10,11),((12,13,14,15),((16,17,18,19),((20,21,22,23),((24,25,26,27),),((28,29,30,31),)),((32,33,34,35),)),((36,37,38,39),)),((40,41,42,43),)),((44,45,46,47),)),((48,49,50,51),))"; + encoded = encoded.replaceAll("\\s+", ""); + final TreeNode decoded = TreeNode.fromString(encoded); + + assertEquals(encoded, decoded.toStringEncoding()); + } +} diff --git a/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/WordCountIntegerationTest.java b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/WordCountIntegerationTest.java new file mode 100644 index 000000000..f070f077e --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/java/org/apache/wayang/ml/test/WordCountIntegerationTest.java @@ -0,0 +1,53 @@ +/* + * + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.wayang.ml.test; + +import java.util.LinkedList; +import java.util.List; + +import org.apache.wayang.basic.data.Tuple2; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.core.plan.wayangplan.WayangPlan; +import org.apache.wayang.java.Java; +import org.apache.wayang.ml.costs.DefaultPointwiseCost; +import org.apache.wayang.spark.Spark; +import org.junit.jupiter.api.Test; + +public class WordCountIntegerationTest extends JavaExecutionTestBase { + @Test + void wordcount() throws Exception { + final List> collector = new LinkedList<>(); + final Configuration config = new Configuration(); + + final String modelPath = WordCountIntegerationTest.class.getResource("/cost_model.onnx").getPath(); + config.setProperty("wayang.ml.model.file", modelPath); + + config.setCostModel(new DefaultPointwiseCost.Factory().makeCost()); + final String filePath = JavaExecutionMLTest.class.getResource("/README.md").toURI().toString(); + final WayangPlan wayangPlan = createWayangPlan(filePath, collector); + final WayangContext wayangContext = new WayangContext(config); + + wayangContext.register(Java.basicPlugin()); + wayangContext.register(Spark.basicPlugin()); + + wayangContext.execute(wayangPlan); + } +} diff --git a/wayang-plugins/wayang-ml/src/test/resources/README.md b/wayang-plugins/wayang-ml/src/test/resources/README.md new file mode 100644 index 000000000..b64b9eec0 --- /dev/null +++ b/wayang-plugins/wayang-ml/src/test/resources/README.md @@ -0,0 +1,279 @@ + + +# Apache Wayang™ Wayang Logo + +## The first open-source cross-platform data processing system + +[![Maven central](https://img.shields.io/maven-central/v/org.apache.wayang/wayang-core.svg?style=for-the-badge)](https://img.shields.io/maven-central/v/org.apache.wayang/wayang-core.svg) +[![License](https://img.shields.io/github/license/apache/incubator-wayang.svg?style=for-the-badge)](http://www.apache.org/licenses/LICENSE-2.0) +[![Last commit](https://img.shields.io/github/last-commit/apache/incubator-wayang.svg?style=for-the-badge)]() +![GitHub commit activity (branch)](https://img.shields.io/github/commit-activity/m/apache/incubator-wayang?style=for-the-badge) +![GitHub forks](https://img.shields.io/github/forks/apache/incubator-wayang?style=for-the-badge) +![GitHub Repo stars](https://img.shields.io/github/stars/apache/incubator-wayang?style=for-the-badge) + +[![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Apache%20Wayang%20enables%20cross%20platform%20data%20processing,%20star%20it%20via:%20&url=https://github.com/apache/incubator-wayang&via=apachewayang&hashtags=dataprocessing,bigdata,analytics,hybridcloud,developers) [![Subreddit subscribers](https://img.shields.io/reddit/subreddit-subscribers/ApacheWayang?style=social)](https://www.reddit.com/r/ApacheWayang/) +## Table of contents + * [Description](#description) + * [Quick Guide for Running Wayang](#quick-guide-for-running-wayang) + * [Quick Guide for Developing with Wayang](#quick-guide-for-developing-with-wayang) + * [Installing Wayang](#installing-wayang) + + [Requirements at Runtime](#requirements-at-runtime) + + [Validating the installation](#validating-the-installation) + * [Getting Started](#getting-started) + + [Prerequisites](#prerequisites) + + [Building](#building) + * [Running the tests](#running-the-tests) + * [Example Applications](#example-applications) + * [Built With](#built-with) + * [Contributing](#contributing) + * [Authors](#authors) + * [License](#license) + +## Description + +In contrast to traditional data processing systems that provide one dedicated execution engine, Apache Wayang can transparently and seamlessly integrate multiple execution engines and use them to perform a single task. We call this *cross-platform data processing*. In Wayang, users can specify any data processing application using one of Wayang's APIs and then Wayang can choose the data processing platform(s), e.g., Postgres or Apache Spark, that best fits the application. Finally, Wayang will orchestrate the execution, thereby hiding the different platform-specific APIs and coordinating inter-platform communication. + +Apache Wayang aims at freeing data engineers and software developers from the burden of learning all different data processing systems, their APIs, strengths and weaknesses; the intricacies of coordinating and integrating different processing platforms; and the inflexibility when trying a fixed set of processing platforms. As of now, Wayang has built-in support for the following processing platforms: +- [Java Streams](https://docs.oracle.com/javase/8/docs/api/java/util/stream/Stream.html) +- [Apache Spark](https://spark.apache.org/) +- [Apache Flink](https://flink.apache.org/) +- [Apache Giraph](https://giraph.apache.org/) +- [Postgres](http://www.postgresql.org) +- [SQLite](https://www.sqlite.org/) +- [Apache Kafka](https://kafka.apache.org) +- [Tensorflow](https://www.tensorflow.org/) + +Apache Wayang can be used via the following APIs: +- Java scala-like +- Scala +- SQL +- Java native (recommended only for low level development) + +Apache Wayang provides a flexible architecture which enables easy addition of new operators and data processing platforms without requiring any change of the internals of the system. For details on how to add new operators, see [here](https://wayang.apache.org/docs/guide/adding-operators). + +## Quick Guide for Running Wayang + +For a quick guide on how to run WordCount see [here](guides/tutorial.md). + +### Spark Dataset / DataFrame pipelines + +Wayang’s Spark platform can now execute end-to-end pipelines on Spark `Dataset[Row]` (aka DataFrames). This is particularly useful when working with lakehouse-style storage (Parquet/Delta) or when you want to plug Spark ML stages into a Wayang plan without repeatedly falling back to RDDs. + +To build a Dataset-backed pipeline: + +1. **Use the Dataset-aware plan builder APIs.** + - `PlanBuilder.readParquet(..., preferDataset = true)` (or `JavaPlanBuilder.readParquet(..., ..., true)`) reads Parquet files directly into a Dataset channel. + - `DataQuanta.writeParquet(..., preferDataset = true)` writes a Dataset channel without converting it back to an RDD. +2. **Keep operators dataset-compatible.** Most operators continue to work unchanged; if an operator explicitly prefers RDDs, Wayang will insert the necessary conversions automatically (at an additional cost). Custom operators can expose `DatasetChannel` descriptors to stay in the dataframe world. +3. **Let the optimizer do the rest.** The optimizer now assigns a higher cost to Dataset↔RDD conversions, so once you opt into Dataset sources/sinks the plan will stay in Dataset form by default. + +No extra flags are required—just opt into the Dataset-based APIs where you want dataframe semantics. If you see unexpected conversions in your execution plan, check that the upstream/downstream operators you use can consume `DatasetChannel`s; otherwise Wayang will insert a conversion operator for you. + +## Quick Guide for Developing with Wayang + +For a quick guide on how to use Wayang in your Java/Scala project see [here](guides/develop-with-Wayang.md). + +## Installing Wayang + +You first have to build the binaries as shown [here](guides/tutorial.md). +Once you have the binaries built, follow these steps to install Wayang: + +```shell +tar -xvf wayang-1.0.1-SNAPSHOT.tar.gz +cd wayang-1.0.1-SNAPSHOT +``` + +In linux +```shell +echo "export WAYANG_HOME=$(pwd)" >> ~/.bashrc +echo "export PATH=${PATH}:${WAYANG_HOME}/bin" >> ~/.bashrc +source ~/.bashrc +``` +In MacOS +```shell +echo "export WAYANG_HOME=$(pwd)" >> ~/.zshrc +echo "export PATH=${PATH}:${WAYANG_HOME}/bin" >> ~/.zshrc +source ~/.zshrc +``` + +### Requirements at Runtime + +Apache Wayang relies on external execution engines and Java to function correctly. Below are the updated runtime requirements: + +- **Java 17**: Make sure `JAVA_HOME` is correctly set to your Java 17 installation. +- **Apache Spark 3.4.4**: Compatible with Scala 2.12. Set the `SPARK_HOME` environment variable. +- **Apache Hadoop 3+**: Set the `HADOOP_HOME` environment variable. + +> 🛠️ **Note:** When using Java 17, you _must_ add JVM flags to allow Wayang and Spark to access internal Java APIs, or you will encounter `IllegalAccessError`. See below. + +### Validating the installation + +To execute your first application with Apache Wayang, you need to execute your program with the 'wayang-submit' command: + +```shell +bin/wayang-submit org.apache.wayang.apps.wordcount.Main java file://$(pwd)/README.md +``` + +### ⚙️ Java 17 Compatibility + +When running Wayang applications using Java 17 (especially with Spark), you must add JVM flags to open specific internal Java modules. These flags resolve access issues with `sun.nio.ch.DirectBuffer` and others. + +Update your `wayang-submit` (wayang-assembly/target/wayang-1.0.1-SNAPSHOT/bin/wayang-submit) script (or command) with: + +```bash +eval "$RUNNER \ + --add-exports=java.base/sun.nio.ch=ALL-UNNAMED \ + --add-opens=java.base/java.nio=ALL-UNNAMED \ + --add-opens=java.base/java.lang=ALL-UNNAMED \ + --add-opens=java.base/java.util=ALL-UNNAMED \ + --add-opens=java.base/java.io=ALL-UNNAMED \ + --add-opens=java.base/java.lang.reflect=ALL-UNNAMED \ + --add-opens=java.base/java.util.concurrent=ALL-UNNAMED \ + --add-opens=java.base/java.net=ALL-UNNAMED \ + --add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ + $FLAGS -cp \"${WAYANG_CLASSPATH}\" $CLASS ${ARGS}" +``` + +## Getting Started + +Wayang is available via Maven Central. To use it with Maven, include the following code snippet into your POM file: +```xml + + org.apache.wayang + wayang-*** + 1.0.0 + +``` +Note the `***`: Wayang ships with multiple modules that can be included in your app, depending on how you want to use it: +* `wayang-core`: provides core data structures and the optimizer (required) +* `wayang-basic`: provides common operators and data types for your apps (recommended) +* `wayang-api-scala-java`: provides an easy-to-use Scala and Java API to assemble Wayang plans (recommended) +* `wayang-java`, `wayang-spark`, `wayang-graphchi`, `wayang-sqlite3`, `wayang-postgres`: adapters for the various supported processing platforms +* `wayang-profiler`: provides functionality to learn operator and UDF cost functions from historical execution data + +> **NOTE:** The module `wayang-api-scala-java` is intended to be used with Java 11 and Scala 2.12. + +For the sake of version flexibility, you still have to include in the POM file your Hadoop (`hadoop-hdfs` and `hadoop-common`) and Spark (`spark-core` and `spark-graphx`) version of choice. + +In addition, you can obtain the most recent snapshot version of Wayang via Sonatype's snapshot repository. Just include: +```xml + + + apache-snapshots + Apache Foundation Snapshot Repository + https://repository.apache.org/content/repositories/snapshots + + +``` + +### Prerequisites +Apache Wayang is built with Java 17 and Scala 2.12. However, to run Apache Wayang it is sufficient to have just Java 17 installed. Please also consider that processing platforms employed by Wayang might have further requirements. +``` +Java 17 +Scala 2.12.17 +Spark 3.4.4, Compatible with Scala 2.12. +Maven +``` + +> **NOTE:** In windows, you need to define the variable `HADOOP_HOME` with the winutils.exe, an not official option to obtain [this repository](https://github.com/steveloughran/winutils), or you can generate your winutils.exe following the instructions in the repository. Also, you may need to install [msvcr100.dll](https://www.microsoft.com/en-us/download/details.aspx?id=26999) + +> **NOTE:** Make sure that the JAVA_HOME environment variable is set correctly to Java 17 as the prerequisite checker script currently supports up to Java 17 and checks the latest version of Java if you have higher version installed. In Linux, it is preferably to use the export JAVA_HOME method inside the project folder. It is also recommended running './mvnw clean install' before opening the project using IntelliJ. + + +### Building + +If you need to rebuild Wayang, e.g., to use a different Scala version, you can simply do so via Maven: + +1. Adapt the version variables (e.g., `spark.version`) in the main `pom.xml` file. +2. Build Wayang with the adapted versions. + ```shell + git clone https://github.com/apache/incubator-wayang.git + cd incubator-wayang + ./mvnw clean install -DskipTests + ``` +> **NOTE:** If you receive an error about not finding `MathExBaseVisitor`, then the problem might be that you are trying to build from IntelliJ, without Maven. MathExBaseVisitor is generated code, and a Maven build should generate it automatically. + +> **NOTE:**: In the current Maven setup, Wayang supports Java 17. The default Scala version is 2.12.17, which is compatible with Java 17. Ensure that your Spark distribution is also built with Scala 2.12 (e.g., `spark-3.4.4-bin-hadoop3-scala2.12`). + +> **NOTE:** For compiling and testing the code it is required to have Hadoop installed on your machine. + +> **NOTE:** the `standalone` profile to fix Hadoop and Spark versions, so that Wayang apps do not explicitly need to declare the corresponding dependencies. + +> **NOTE**: When running applications (e.g., WordCount) with Java 17, you must pass additional flags to allow internal module access: + +>--add-exports=java.base/sun.nio.ch=ALL-UNNAMED \ +--add-opens=java.base/java.nio=ALL-UNNAMED \ +--add-opens=java.base/java.lang=ALL-UNNAMED \ +--add-opens=java.base/java.util=ALL-UNNAMED \ +--add-opens=java.base/java.io=ALL-UNNAMED \ +--add-opens=java.base/java.lang.reflect=ALL-UNNAMED \ +--add-opens=java.base/java.util.concurrent=ALL-UNNAMED \ +--add-opens=java.base/java.net=ALL-UNNAMED \ +--add-opens=java.base/java.lang.invoke=ALL-UNNAMED \ + +> +> Also, note the `distro` profile, which assembles a binary Wayang distribution. +To activate these profiles, you need to specify them when running maven, i.e., + +```shell +./mvnw clean install -DskipTests -P +``` + +## Running the tests +In the incubator-wayang root folder run: +```shell +./mvnw test +``` + +## Example Applications +You can see examples on how to start using Wayang [here](guides/wayang-examples.md) + +## Built With + +* [Java 17](https://www.oracle.com/java/technologies/javase/17-0-14-relnotes.html) +* [Scala 2.12.17](https://www.scala-lang.org/download/2.12.17.html) +* [Maven](https://maven.apache.org/) + +## Contributing +Before submitting a PR, please take a look on how to contribute with Apache Wayang contributing guidelines [here](CONTRIBUTING.md). + +There is also a guide on how to compile your code [here](guides/develop-in-Wayang.md). +## Authors +The list of [contributors](https://github.com/apache/incubator-wayang/graphs/contributors). + +## License +All files in this repository are licensed under the Apache Software License 2.0 + +Copyright 2020 - 2026 The Apache Software Foundation. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +## Acknowledgements +The [Logo](https://wayang.apache.org/img/wayang.png) was donated by Brian Vera. diff --git a/wayang-plugins/wayang-ml/src/test/resources/cost_model.onnx b/wayang-plugins/wayang-ml/src/test/resources/cost_model.onnx new file mode 100644 index 000000000..3bba58606 Binary files /dev/null and b/wayang-plugins/wayang-ml/src/test/resources/cost_model.onnx differ