From 16852cf877f64d325015800d75e83d18766584e0 Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 17:18:51 +0200 Subject: [PATCH 01/11] Support for Tensorflow Linear Algebra API: -- Added matrix determinant op (eager & lazy) --- .../tensorflow/api/core/types/package.scala | 6 + .../platanios/tensorflow/api/ops/Linalg.scala | 43 +++ .../tensorflow/api/ops/package.scala | 1 + .../platanios/tensorflow/api/package.scala | 262 +++++++++--------- .../tensorflow/api/tensors/ops/Linalg.scala | 41 +++ .../tensorflow/api/tensors/ops/package.scala | 1 + .../native/generated/tensor_linalg_ops.cc | 56 ++++ .../main/native/generated/tensor_linalg_ops.h | 21 ++ .../jni/generated/tensors/Linalg.scala | 26 ++ 9 files changed, 326 insertions(+), 131 deletions(-) create mode 100644 modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala create mode 100644 modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala create mode 100644 modules/jni/src/main/native/generated/tensor_linalg_ops.cc create mode 100644 modules/jni/src/main/native/generated/tensor_linalg_ops.h create mode 100644 modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala index 8546282c4..c6646427d 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/core/types/package.scala @@ -165,6 +165,7 @@ package object types { type Quantized = Union[QByte]#or[QShort]#or[QInt]#or[QUByte]#or[QUShort]#create type Numeric = Union[TruncatedHalf]#or[Half]#or[Float]#or[Double]#or[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#or[ComplexFloat]#or[ComplexDouble]#or[QByte]#or[QShort]#or[QInt]#or[QUByte]#or[QUShort]#create type BooleanOrNumeric = Union[Boolean]#or[Half]#or[Float]#or[Double]#or[Byte]#or[Short]#or[Int]#or[Long]#or[UByte]#or[UShort]#or[UInt]#or[ULong]#or[ComplexFloat]#or[ComplexDouble]#or[QByte]#or[QShort]#or[QInt]#or[QUByte]#or[QUShort]#create + type RealOrComplex = Union[Float]#or[Double]#or[ComplexFloat]#or[ComplexDouble]#create type IsFloatOrDouble[T] = Contains[T, FloatOrDouble] type IsHalfOrFloat[T] = Contains[T, HalfOrFloat] @@ -186,6 +187,7 @@ package object types { type IsQuantized[T] = Contains[T, Quantized] type IsNumeric[T] = Contains[T, Numeric] type IsBooleanOrNumeric[T] = Contains[T, BooleanOrNumeric] + type IsRealOrComplex[T] = Contains[T, RealOrComplex] object IsFloatOrDouble { def apply[T: IsFloatOrDouble]: IsFloatOrDouble[T] = implicitly[IsFloatOrDouble[T]] @@ -266,4 +268,8 @@ package object types { object IsBooleanOrNumeric { def apply[T: IsBooleanOrNumeric]: IsBooleanOrNumeric[T] = implicitly[IsBooleanOrNumeric[T]] } + + object IsRealOrComplex { + def apply[T: IsRealOrComplex]: IsRealOrComplex[T] = implicitly[IsRealOrComplex[T]] + } } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala new file mode 100644 index 000000000..d56d6290f --- /dev/null +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -0,0 +1,43 @@ +/* Copyright 2019, T.AI Labs. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.platanios.tensorflow.api.ops + +import org.platanios.tensorflow.api.core.Shape +import org.platanios.tensorflow.api.core.exception.InvalidArgumentException +import org.platanios.tensorflow.api.core.types._ +import org.platanios.tensorflow.api.implicits.Implicits._ +import org.platanios.tensorflow.api.tensors +import org.platanios.tensorflow.api.tensors.Tensor +import org.platanios.tensorflow.api.utilities.DefaultsTo.IntDefault + +import scala.language.postfixOps + +/** + * Defines linear algebra ops similar to the + * ones defined in tf.linalg package of the Python TF API + * + * + */ +trait Linalg { + + def matrixDeterminant[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "MatrixDeterminant"): Output[T] = { + Op.Builder[Output[T], Output[T]]( + opType = "MatrixDeterminant", + name = name, + input = matrix + ).build().output + } + +} diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/package.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/package.scala index 48041f186..cbb5b98c5 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/package.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/package.scala @@ -67,6 +67,7 @@ package object ops { with Logging with Math with NN + with Linalg with Parsing with Random with Resources diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/package.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/package.scala index 7badc279c..dcf9f0cd8 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/package.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/package.scala @@ -72,76 +72,76 @@ package object api extends implicits.Implicits with Documentation { val Shape: core.Shape.type = core.Shape type Indexer = core.Indexer - type Index = core.Index - type Slice = core.Slice + type Index = core.Index + type Slice = core.Slice val --- : Indexer = core.Ellipsis val NewAxis: Indexer = core.NewAxis val :: : Slice = core.Slice.:: - type TensorLike[T] = tensors.TensorLike[T] - type Tensor[T] = tensors.Tensor[T] + type TensorLike[T] = tensors.TensorLike[T] + type Tensor[T] = tensors.Tensor[T] type TensorIndexedSlices[T] = tensors.TensorIndexedSlices[T] - type SparseTensor[T] = tensors.SparseTensor[T] + type SparseTensor[T] = tensors.SparseTensor[T] - val Tensor : tensors.Tensor.type = tensors.Tensor + val Tensor: tensors.Tensor.type = tensors.Tensor val TensorIndexedSlices: tensors.TensorIndexedSlices.type = tensors.TensorIndexedSlices - val SparseTensor : tensors.SparseTensor.type = tensors.SparseTensor + val SparseTensor: tensors.SparseTensor.type = tensors.SparseTensor - type Op[I, O] = ops.Op[I, O] + type Op[I, O] = ops.Op[I, O] type UntypedOp = ops.UntypedOp - type OutputLike[T] = ops.OutputLike[T] - type Output[T] = ops.Output[T] + type OutputLike[T] = ops.OutputLike[T] + type Output[T] = ops.Output[T] type OutputIndexedSlices[T] = ops.OutputIndexedSlices[T] - type SparseOutput[T] = ops.SparseOutput[T] + type SparseOutput[T] = ops.SparseOutput[T] type TensorArray[T] = ops.TensorArray[T] - val Op : ops.Op.type = ops.Op - val Output : ops.Output.type = ops.Output + val Op: ops.Op.type = ops.Op + val Output: ops.Output.type = ops.Output val OutputIndexedSlices: ops.OutputIndexedSlices.type = ops.OutputIndexedSlices - val SparseOutput : ops.SparseOutput.type = ops.SparseOutput - val TensorArray : ops.TensorArray.type = ops.TensorArray + val SparseOutput: ops.SparseOutput.type = ops.SparseOutput + val TensorArray: ops.TensorArray.type = ops.TensorArray type VariableLike[T] = ops.variables.VariableLike[T] - type Variable[T] = ops.variables.Variable[T] + type Variable[T] = ops.variables.Variable[T] //region Types //region Value Classes - type Half = core.types.Half + type Half = core.types.Half type TruncatedHalf = core.types.TruncatedHalf - type ComplexFloat = core.types.ComplexFloat + type ComplexFloat = core.types.ComplexFloat type ComplexDouble = core.types.ComplexDouble - type UByte = core.types.UByte - type UShort = core.types.UShort - type UInt = core.types.UInt - type ULong = core.types.ULong - type QByte = core.types.QByte - type QShort = core.types.QShort - type QInt = core.types.QInt - type QUByte = core.types.QUByte - type QUShort = core.types.QUShort - type Resource = core.types.Resource - type Variant = core.types.Variant - - val Half : core.types.Half.type = core.types.Half + type UByte = core.types.UByte + type UShort = core.types.UShort + type UInt = core.types.UInt + type ULong = core.types.ULong + type QByte = core.types.QByte + type QShort = core.types.QShort + type QInt = core.types.QInt + type QUByte = core.types.QUByte + type QUShort = core.types.QUShort + type Resource = core.types.Resource + type Variant = core.types.Variant + + val Half: core.types.Half.type = core.types.Half val TruncatedHalf: core.types.TruncatedHalf.type = core.types.TruncatedHalf - val ComplexFloat : core.types.ComplexFloat.type = core.types.ComplexFloat + val ComplexFloat: core.types.ComplexFloat.type = core.types.ComplexFloat val ComplexDouble: core.types.ComplexDouble.type = core.types.ComplexDouble - val UByte : core.types.UByte.type = core.types.UByte - val UShort : core.types.UShort.type = core.types.UShort - val UInt : core.types.UInt.type = core.types.UInt - val ULong : core.types.ULong.type = core.types.ULong - val QByte : core.types.QByte.type = core.types.QByte - val QShort : core.types.QShort.type = core.types.QShort - val QInt : core.types.QInt.type = core.types.QInt - val QUByte : core.types.QUByte.type = core.types.QUByte - val QUShort : core.types.QUShort.type = core.types.QUShort - val Resource : core.types.Resource.type = core.types.Resource - val Variant : core.types.Variant.type = core.types.Variant + val UByte: core.types.UByte.type = core.types.UByte + val UShort: core.types.UShort.type = core.types.UShort + val UInt: core.types.UInt.type = core.types.UInt + val ULong: core.types.ULong.type = core.types.ULong + val QByte: core.types.QByte.type = core.types.QByte + val QShort: core.types.QShort.type = core.types.QShort + val QInt: core.types.QInt.type = core.types.QInt + val QUByte: core.types.QUByte.type = core.types.QUByte + val QUShort: core.types.QUShort.type = core.types.QUShort + val Resource: core.types.Resource.type = core.types.Resource + val Variant: core.types.Variant.type = core.types.Variant //endregion Value Classes @@ -149,96 +149,100 @@ package object api extends implicits.Implicits with Documentation { type DataType[T] = core.types.DataType[T] - type STRING = core.types.DataType[String] - type BOOLEAN = core.types.DataType[Boolean] - type FLOAT16 = core.types.DataType[core.types.Half] - type FLOAT32 = core.types.DataType[Float] - type FLOAT64 = core.types.DataType[Double] - type BFLOAT16 = core.types.DataType[core.types.TruncatedHalf] - type COMPLEX64 = core.types.DataType[core.types.ComplexFloat] + type STRING = core.types.DataType[String] + type BOOLEAN = core.types.DataType[Boolean] + type FLOAT16 = core.types.DataType[core.types.Half] + type FLOAT32 = core.types.DataType[Float] + type FLOAT64 = core.types.DataType[Double] + type BFLOAT16 = core.types.DataType[core.types.TruncatedHalf] + type COMPLEX64 = core.types.DataType[core.types.ComplexFloat] type COMPLEX128 = core.types.DataType[core.types.ComplexDouble] - type INT8 = core.types.DataType[Byte] - type INT16 = core.types.DataType[Short] - type INT32 = core.types.DataType[Int] - type INT64 = core.types.DataType[Long] - type UINT8 = core.types.DataType[core.types.UByte] - type UINT16 = core.types.DataType[core.types.UShort] - type UINT32 = core.types.DataType[core.types.UInt] - type UINT64 = core.types.DataType[core.types.ULong] - type QINT8 = core.types.DataType[core.types.QByte] - type QINT16 = core.types.DataType[core.types.QShort] - type QINT32 = core.types.DataType[core.types.QInt] - type QUINT8 = core.types.DataType[core.types.QUByte] - type QUINT16 = core.types.DataType[core.types.QUShort] - type RESOURCE = core.types.DataType[core.types.Resource] - type VARIANT = core.types.DataType[core.types.Variant] - - val STRING : STRING = core.types.STRING - val BOOLEAN : BOOLEAN = core.types.BOOLEAN - val FLOAT16 : FLOAT16 = core.types.FLOAT16 - val FLOAT32 : FLOAT32 = core.types.FLOAT32 - val FLOAT64 : FLOAT64 = core.types.FLOAT64 - val BFLOAT16 : BFLOAT16 = core.types.BFLOAT16 - val COMPLEX64 : COMPLEX64 = core.types.COMPLEX64 + type INT8 = core.types.DataType[Byte] + type INT16 = core.types.DataType[Short] + type INT32 = core.types.DataType[Int] + type INT64 = core.types.DataType[Long] + type UINT8 = core.types.DataType[core.types.UByte] + type UINT16 = core.types.DataType[core.types.UShort] + type UINT32 = core.types.DataType[core.types.UInt] + type UINT64 = core.types.DataType[core.types.ULong] + type QINT8 = core.types.DataType[core.types.QByte] + type QINT16 = core.types.DataType[core.types.QShort] + type QINT32 = core.types.DataType[core.types.QInt] + type QUINT8 = core.types.DataType[core.types.QUByte] + type QUINT16 = core.types.DataType[core.types.QUShort] + type RESOURCE = core.types.DataType[core.types.Resource] + type VARIANT = core.types.DataType[core.types.Variant] + + val STRING: STRING = core.types.STRING + val BOOLEAN: BOOLEAN = core.types.BOOLEAN + val FLOAT16: FLOAT16 = core.types.FLOAT16 + val FLOAT32: FLOAT32 = core.types.FLOAT32 + val FLOAT64: FLOAT64 = core.types.FLOAT64 + val BFLOAT16: BFLOAT16 = core.types.BFLOAT16 + val COMPLEX64: COMPLEX64 = core.types.COMPLEX64 val COMPLEX128: COMPLEX128 = core.types.COMPLEX128 - val INT8 : INT8 = core.types.INT8 - val INT16 : INT16 = core.types.INT16 - val INT32 : INT32 = core.types.INT32 - val INT64 : INT64 = core.types.INT64 - val UINT8 : UINT8 = core.types.UINT8 - val UINT16 : UINT16 = core.types.UINT16 - val UINT32 : UINT32 = core.types.UINT32 - val UINT64 : UINT64 = core.types.UINT64 - val QINT8 : QINT8 = core.types.QINT8 - val QINT16 : QINT16 = core.types.QINT16 - val QINT32 : QINT32 = core.types.QINT32 - val QUINT8 : QUINT8 = core.types.QUINT8 - val QUINT16 : QUINT16 = core.types.QUINT16 - val RESOURCE : RESOURCE = core.types.RESOURCE - val VARIANT : VARIANT = core.types.VARIANT + val INT8: INT8 = core.types.INT8 + val INT16: INT16 = core.types.INT16 + val INT32: INT32 = core.types.INT32 + val INT64: INT64 = core.types.INT64 + val UINT8: UINT8 = core.types.UINT8 + val UINT16: UINT16 = core.types.UINT16 + val UINT32: UINT32 = core.types.UINT32 + val UINT64: UINT64 = core.types.UINT64 + val QINT8: QINT8 = core.types.QINT8 + val QINT16: QINT16 = core.types.QINT16 + val QINT32: QINT32 = core.types.QINT32 + val QUINT8: QUINT8 = core.types.QUINT8 + val QUINT16: QUINT16 = core.types.QUINT16 + val RESOURCE: RESOURCE = core.types.RESOURCE + val VARIANT: VARIANT = core.types.VARIANT //endregion Data Type Instances //region Type Traits - type TF[T] = core.types.TF[T] - type IsFloatOrDouble[T] = core.types.IsFloatOrDouble[T] - type IsHalfOrFloatOrDouble[T] = core.types.IsHalfOrFloatOrDouble[T] - type IsTruncatedHalfOrFloatOrDouble[T] = core.types.IsTruncatedHalfOrFloatOrDouble[T] - type IsTruncatedHalfOrHalfOrFloat[T] = core.types.IsTruncatedHalfOrHalfOrFloat[T] - type IsDecimal[T] = core.types.IsDecimal[T] - type IsIntOrLong[T] = core.types.IsIntOrLong[T] - type IsIntOrLongOrFloatOrDouble[T] = core.types.IsIntOrLongOrFloatOrDouble[T] + type TF[T] = core.types.TF[T] + type IsFloatOrDouble[T] = core.types.IsFloatOrDouble[T] + type IsHalfOrFloatOrDouble[T] = core.types.IsHalfOrFloatOrDouble[T] + type IsTruncatedHalfOrFloatOrDouble[T] = core.types.IsTruncatedHalfOrFloatOrDouble[T] + type IsTruncatedHalfOrHalfOrFloat[T] = core.types.IsTruncatedHalfOrHalfOrFloat[T] + type IsDecimal[T] = core.types.IsDecimal[T] + type IsIntOrLong[T] = core.types.IsIntOrLong[T] + type IsIntOrLongOrFloatOrDouble[T] = core.types.IsIntOrLongOrFloatOrDouble[T] type IsIntOrLongOrHalfOrFloatOrDouble[T] = core.types.IsIntOrLongOrHalfOrFloatOrDouble[T] - type IsIntOrLongOrUByte[T] = core.types.IsIntOrLongOrUByte[T] - type IsIntOrUInt[T] = core.types.IsIntOrUInt[T] - type IsStringOrInteger[T] = core.types.IsStringOrInteger[T] - type IsStringOrFloatOrLong[T] = core.types.IsStringOrFloatOrLong[T] - type IsReal[T] = core.types.IsReal[T] - type IsComplex[T] = core.types.IsComplex[T] - type IsNotQuantized[T] = core.types.IsNotQuantized[T] - type IsQuantized[T] = core.types.IsQuantized[T] - type IsNumeric[T] = core.types.IsNumeric[T] - type IsBooleanOrNumeric[T] = core.types.IsBooleanOrNumeric[T] - - val TF : core.types.TF.type = core.types.TF - val IsFloatOrDouble : core.types.IsFloatOrDouble.type = core.types.IsFloatOrDouble - val IsHalfOrFloatOrDouble : core.types.IsHalfOrFloatOrDouble.type = core.types.IsHalfOrFloatOrDouble - val IsTruncatedHalfOrFloatOrDouble : core.types.IsTruncatedHalfOrFloatOrDouble.type = core.types.IsTruncatedHalfOrFloatOrDouble - val IsTruncatedHalfOrHalfOrFloat : core.types.IsTruncatedHalfOrHalfOrFloat.type = core.types.IsTruncatedHalfOrHalfOrFloat - val IsDecimal : core.types.IsDecimal.type = core.types.IsDecimal - val IsIntOrLong : core.types.IsIntOrLong.type = core.types.IsIntOrLong - val IsIntOrLongOrFloatOrDouble : core.types.IsIntOrLongOrFloatOrDouble.type = core.types.IsIntOrLongOrFloatOrDouble - val IsIntOrLongOrHalfOrFloatOrDouble: core.types.IsIntOrLongOrHalfOrFloatOrDouble.type = core.types.IsIntOrLongOrHalfOrFloatOrDouble - val IsIntOrLongOrUByte : core.types.IsIntOrLongOrUByte.type = core.types.IsIntOrLongOrUByte - val IsIntOrUInt : core.types.IsIntOrUInt.type = core.types.IsIntOrUInt - val IsStringOrFloatOrLong : core.types.IsStringOrFloatOrLong.type = core.types.IsStringOrFloatOrLong - val IsReal : core.types.IsReal.type = core.types.IsReal - val IsComplex : core.types.IsComplex.type = core.types.IsComplex - val IsNotQuantized : core.types.IsNotQuantized.type = core.types.IsNotQuantized - val IsQuantized : core.types.IsQuantized.type = core.types.IsQuantized - val IsNumeric : core.types.IsNumeric.type = core.types.IsNumeric - val IsBooleanOrNumeric : core.types.IsBooleanOrNumeric.type = core.types.IsBooleanOrNumeric + type IsIntOrLongOrUByte[T] = core.types.IsIntOrLongOrUByte[T] + type IsIntOrUInt[T] = core.types.IsIntOrUInt[T] + type IsStringOrInteger[T] = core.types.IsStringOrInteger[T] + type IsStringOrFloatOrLong[T] = core.types.IsStringOrFloatOrLong[T] + type IsReal[T] = core.types.IsReal[T] + type IsComplex[T] = core.types.IsComplex[T] + type IsNotQuantized[T] = core.types.IsNotQuantized[T] + type IsQuantized[T] = core.types.IsQuantized[T] + type IsNumeric[T] = core.types.IsNumeric[T] + type IsBooleanOrNumeric[T] = core.types.IsBooleanOrNumeric[T] + + val TF: core.types.TF.type = core.types.TF + val IsFloatOrDouble: core.types.IsFloatOrDouble.type = core.types.IsFloatOrDouble + val IsHalfOrFloatOrDouble: core.types.IsHalfOrFloatOrDouble.type = core.types.IsHalfOrFloatOrDouble + val IsTruncatedHalfOrFloatOrDouble: core.types.IsTruncatedHalfOrFloatOrDouble.type = + core.types.IsTruncatedHalfOrFloatOrDouble + val IsTruncatedHalfOrHalfOrFloat: core.types.IsTruncatedHalfOrHalfOrFloat.type = + core.types.IsTruncatedHalfOrHalfOrFloat + val IsDecimal: core.types.IsDecimal.type = core.types.IsDecimal + val IsIntOrLong: core.types.IsIntOrLong.type = core.types.IsIntOrLong + val IsIntOrLongOrFloatOrDouble: core.types.IsIntOrLongOrFloatOrDouble.type = core.types.IsIntOrLongOrFloatOrDouble + val IsIntOrLongOrHalfOrFloatOrDouble: core.types.IsIntOrLongOrHalfOrFloatOrDouble.type = + core.types.IsIntOrLongOrHalfOrFloatOrDouble + val IsIntOrLongOrUByte: core.types.IsIntOrLongOrUByte.type = core.types.IsIntOrLongOrUByte + val IsIntOrUInt: core.types.IsIntOrUInt.type = core.types.IsIntOrUInt + val IsStringOrFloatOrLong: core.types.IsStringOrFloatOrLong.type = core.types.IsStringOrFloatOrLong + val IsReal: core.types.IsReal.type = core.types.IsReal + val IsComplex: core.types.IsComplex.type = core.types.IsComplex + val IsNotQuantized: core.types.IsNotQuantized.type = core.types.IsNotQuantized + val IsQuantized: core.types.IsQuantized.type = core.types.IsQuantized + val IsNumeric: core.types.IsNumeric.type = core.types.IsNumeric + val IsBooleanOrNumeric: core.types.IsBooleanOrNumeric.type = core.types.IsBooleanOrNumeric + val IsRealOrComplex: core.types.IsRealOrComplex.type = core.types.IsRealOrComplex //endregion Type Traits @@ -291,9 +295,7 @@ package object api extends implicits.Implicits with Documentation { * @groupname CallbackOps Ops / Callback * @groupprio CallbackOps 280 */ - object tf - extends core.API - with ops.API { + object tf extends core.API with ops.API { object learn extends api.learn.API } @@ -336,7 +338,5 @@ package object api extends implicits.Implicits with Documentation { * @groupname CallbackOps Ops / Callback * @groupprio CallbackOps 280 */ - object tfi - extends core.API - with tensors.API + object tfi extends core.API with tensors.API } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala new file mode 100644 index 000000000..d8a967097 --- /dev/null +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -0,0 +1,41 @@ +/* Copyright 2019, T.AI Labs. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.platanios.tensorflow.api.tensors.ops + +import org.platanios.tensorflow.api.core.Shape +import org.platanios.tensorflow.api.core.exception.InvalidShapeException +import org.platanios.tensorflow.api.core.types._ +import org.platanios.tensorflow.api.implicits.Implicits._ +import org.platanios.tensorflow.api.tensors._ +import org.platanios.tensorflow.api.utilities.DefaultsTo.IntDefault +import org.platanios.tensorflow.jni.generated.tensors.{Linalg => NativeTensorOpsLinAlg} + +import scala.language.postfixOps + +/** + * Defines linear algebra ops similar to the + * ones defined in tf.linalg package of the Python TF API + * + * + */ +trait Linalg { + + def matrixDeterminant[T: TF: IsRealOrComplex](matrix: Tensor[T]): Tensor[T] = { + Tensor.fromNativeHandle[T]( + NativeTensorOpsLinAlg.matrixDeterminant(executionContext.value.nativeHandle, matrix.nativeHandle) + ) + } + +} diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/package.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/package.scala index 2c6bb2cc9..418f2e301 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/package.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/package.scala @@ -25,4 +25,5 @@ package object ops { with Math with NN with Random + with Linalg } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc new file mode 100644 index 000000000..513538186 --- /dev/null +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -0,0 +1,56 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ + +/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +#include "tensor_linalg_ops.h" +#include "exception.h" +#include "utilities.h" + +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/eager/c_api.h" + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixDeterminant( + JNIEnv* env, jobject object, jlong context_handle, jlong input) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixDeterminant", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, 0); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h new file mode 100644 index 000000000..4d43f7507 --- /dev/null +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -0,0 +1,21 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_platanios_tensorflow_jni_generated_tensors_Linalg__ */ + +#ifndef _Included_org_platanios_tensorflow_jni_generated_tensors_Linalg__ +#define _Included_org_platanios_tensorflow_jni_generated_tensors_Linalg__ +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixDeterminant + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixDeterminant + (JNIEnv *, jobject, jlong, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala new file mode 100644 index 000000000..9c54b34c4 --- /dev/null +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -0,0 +1,26 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ + +/* Copyright 2017-19, Emmanouil Antonios Platanios. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package org.platanios.tensorflow.jni.generated.tensors + +import org.platanios.tensorflow.jni.TensorFlow + +object Linalg { + TensorFlow.load() + + @native def matrixDeterminant(contextHandle: Long, input: Long): Long +} From 23a9a07989fc287e56fa6a625cc090f8858b458d Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 17:30:39 +0200 Subject: [PATCH 02/11] Added export rule for Linalg in build.sbt --- build.sbt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/build.sbt b/build.sbt index a016809f6..e448c6571 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ organization in ThisBuild := "org.platanios" autoCompilerPlugins in ThisBuild := true -val tensorFlowVersion = "1.11.0" +val tensorFlowVersion = "1.12.0" val circeVersion = "0.10.1" // Use for working with JSON. // addCompilerPlugin(MetalsPlugin.semanticdbScalac) @@ -157,7 +157,8 @@ lazy val jni = (project in file("./modules/jni")) "Sparse" -> Seq("SparseToDense"), "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", - "StringToHashBucketStrong") + "StringToHashBucketStrong"), + "Linalg" -> Seq("MatrixDeterminant") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings From 297f12dc09a1241a87aac580eef031f7c22c7d71 Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 17:56:00 +0200 Subject: [PATCH 03/11] Linear Algebra support: -- Added logMatrixDeterminant op --- build.sbt | 2 +- .../platanios/tensorflow/api/ops/Linalg.scala | 8 +++++ .../tensorflow/api/tensors/ops/Linalg.scala | 8 +++++ .../native/generated/tensor_linalg_ops.cc | 34 +++++++++++++++++++ .../main/native/generated/tensor_linalg_ops.h | 8 +++++ .../jni/generated/tensors/Linalg.scala | 1 + 6 files changed, 60 insertions(+), 1 deletion(-) diff --git a/build.sbt b/build.sbt index e448c6571..083abbf04 100644 --- a/build.sbt +++ b/build.sbt @@ -158,7 +158,7 @@ lazy val jni = (project in file("./modules/jni")) "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong"), - "Linalg" -> Seq("MatrixDeterminant") + "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index d56d6290f..da179bf60 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -40,4 +40,12 @@ trait Linalg { ).build().output } + def logMatrixDeterminant[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "LogMatrixDeterminant"): (Output[T], Output[T]) = { + Op.Builder[Output[T], (Output[T], Output[T])]( + opType = "LogMatrixDeterminant", + name = name, + input = matrix + ).build().output + } + } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index d8a967097..7812eb2e2 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -38,4 +38,12 @@ trait Linalg { ) } + def logMatrixDeterminant[T: TF: IsRealOrComplex](matrix: Tensor[T]): (Tensor[T], Tensor[T]) = { + val results = NativeTensorOpsLinAlg + .logMatrixDeterminant(executionContext.value.nativeHandle, matrix.nativeHandle).map( + h => Tensor.fromNativeHandle[T](h) + ) + (results.head, results.last) + } + } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc index 513538186..77be85b73 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -27,6 +27,40 @@ #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_logMatrixDeterminant( + JNIEnv* env, jobject object, jlong context_handle, jlong input) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "LogMatrixDeterminant", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + const int num_outputs = 2; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} + JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixDeterminant( JNIEnv* env, jobject object, jlong context_handle, jlong input) { REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h index 4d43f7507..f229fbbfc 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.h +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -7,6 +7,14 @@ #ifdef __cplusplus extern "C" { #endif +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: logMatrixDeterminant + * Signature: (JJ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_logMatrixDeterminant + (JNIEnv *, jobject, jlong, jlong); + /* * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ * Method: matrixDeterminant diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala index 9c54b34c4..ad7692711 100644 --- a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -22,5 +22,6 @@ import org.platanios.tensorflow.jni.TensorFlow object Linalg { TensorFlow.load() + @native def logMatrixDeterminant(contextHandle: Long, input: Long): Array[Long] @native def matrixDeterminant(contextHandle: Long, input: Long): Long } From 66ce2c8d1a3e032f4b495d1d9524012552c4b41c Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 18:28:22 +0200 Subject: [PATCH 04/11] Added scaladoc comments. --- .../org/platanios/tensorflow/api/ops/Linalg.scala | 14 +++++++++++++- .../tensorflow/api/tensors/ops/Linalg.scala | 8 ++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index da179bf60..1f97117a0 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -40,7 +40,19 @@ trait Linalg { ).build().output } - def logMatrixDeterminant[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "LogMatrixDeterminant"): (Output[T], Output[T]) = { + /** + * Computes (sign(det(x)) log(|det(x)|)) for an input x. + * + * @param matrix A matrix of shape [N, M, M] + * @param name An optional name to assign to the op. + * + * @return A tuple having the results. + * + */ + def logMatrixDeterminant[T: TF: IsRealOrComplex]( + matrix: Output[T], + name: String = "LogMatrixDeterminant" + ): (Output[T], Output[T]) = { Op.Builder[Output[T], (Output[T], Output[T])]( opType = "LogMatrixDeterminant", name = name, diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index 7812eb2e2..4a1ae0c13 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -38,6 +38,14 @@ trait Linalg { ) } + /** + * Computes (sign(det(x)) log(|det(x)|)) for an input x. + * + * @param matrix A matrix of shape [N, M, M] + * + * @return A tuple having the results. + * + */ def logMatrixDeterminant[T: TF: IsRealOrComplex](matrix: Tensor[T]): (Tensor[T], Tensor[T]) = { val results = NativeTensorOpsLinAlg .logMatrixDeterminant(executionContext.value.nativeHandle, matrix.nativeHandle).map( From 6d7a635136141649cba9d97d6dd8201e74636978 Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 19:21:58 +0200 Subject: [PATCH 05/11] Linear Algebra support: -- Added matrix inversion op. --- build.sbt | 2 +- .../platanios/tensorflow/api/ops/Linalg.scala | 26 ++++++++++++++++ .../tensorflow/api/tensors/ops/Linalg.scala | 18 ++++++++++- .../native/generated/tensor_linalg_ops.cc | 30 +++++++++++++++++++ .../main/native/generated/tensor_linalg_ops.h | 8 +++++ .../jni/generated/tensors/Linalg.scala | 1 + 6 files changed, 83 insertions(+), 2 deletions(-) diff --git a/build.sbt b/build.sbt index 083abbf04..eed80d1e9 100644 --- a/build.sbt +++ b/build.sbt @@ -158,7 +158,7 @@ lazy val jni = (project in file("./modules/jni")) "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong"), - "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant") + "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index 1f97117a0..daea7827e 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -21,8 +21,12 @@ import org.platanios.tensorflow.api.implicits.Implicits._ import org.platanios.tensorflow.api.tensors import org.platanios.tensorflow.api.tensors.Tensor import org.platanios.tensorflow.api.utilities.DefaultsTo.IntDefault +//import com.google.protobuf.ByteString.Output + +import org.tensorflow.framework.AttrValue import scala.language.postfixOps +import com.google.protobuf.Descriptors.FieldDescriptor /** * Defines linear algebra ops similar to the @@ -43,6 +47,8 @@ trait Linalg { /** * Computes (sign(det(x)) log(|det(x)|)) for an input x. * + * @tparam T The underlying scala type of the matrix elements. + * * @param matrix A matrix of shape [N, M, M] * @param name An optional name to assign to the op. * @@ -60,4 +66,24 @@ trait Linalg { ).build().output } + /** + * Computes inv(A), assuming matrix A is invertible and of shape [..., M, M] + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The matrix to invert. + * @param adjoint If set to true, returns the adjoint, defaults to false. + * @param name An optional name to assign to the op. + * + */ + def matrixInverse[T: TF: IsRealOrComplex]( + matrix: Output[T], + adjoint: Boolean = false, + name: String = "MatrixInverse" + ): Output[T] = + Op.Builder[Output[T], Output[T]]( + opType = "MatrixInverse", + name = name, + input = matrix + ).setAttribute("adjoint", adjoint).build().output + } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index 4a1ae0c13..a384ab400 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -41,8 +41,9 @@ trait Linalg { /** * Computes (sign(det(x)) log(|det(x)|)) for an input x. * + * @tparam T The underlying scala type of the matrix elements. * @param matrix A matrix of shape [N, M, M] - * + * * @return A tuple having the results. * */ @@ -54,4 +55,19 @@ trait Linalg { (results.head, results.last) } + /** + * Computes inv(A), assuming matrix A is invertible and of shape [..., M, M] + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The matrix to invert. + * @param adjoint If set to true, returns the adjoint, defaults to false. + * + * + */ + def matrixInverse[T: TF: IsRealOrComplex](matrix: Tensor[T], adjoint: Boolean = false): Tensor[T] = { + Tensor.fromNativeHandle[T]( + NativeTensorOpsLinAlg.matrixInverse(executionContext.value.nativeHandle, matrix.nativeHandle, adjoint) + ) + } + } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc index 77be85b73..9281e5297 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -88,3 +88,33 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina return reinterpret_cast(outputs[0]); } + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixInverse( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean adjoint) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixInverse", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, 0); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "adjoint", static_cast(adjoint)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h index f229fbbfc..c8af97765 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.h +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -23,6 +23,14 @@ JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixDeterminant (JNIEnv *, jobject, jlong, jlong); +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixInverse + * Signature: (JJZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixInverse + (JNIEnv *, jobject, jlong, jlong, jboolean); + #ifdef __cplusplus } #endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala index ad7692711..9691d4f31 100644 --- a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -24,4 +24,5 @@ object Linalg { @native def logMatrixDeterminant(contextHandle: Long, input: Long): Array[Long] @native def matrixDeterminant(contextHandle: Long, input: Long): Long + @native def matrixInverse(contextHandle: Long, input: Long, adjoint: Boolean): Long } From c8e4d3c64a3f70c78954d4b83adb0abb976ddf6f Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 19:45:29 +0200 Subject: [PATCH 06/11] Linear Algebra support: -- Added matrix solve function --- build.sbt | 2 +- .../platanios/tensorflow/api/ops/Linalg.scala | 35 +++++++++++++- .../tensorflow/api/tensors/ops/Linalg.scala | 26 ++++++++++ .../native/generated/tensor_linalg_ops.cc | 47 +++++++++++++++++++ .../main/native/generated/tensor_linalg_ops.h | 8 ++++ .../jni/generated/tensors/Linalg.scala | 1 + 6 files changed, 117 insertions(+), 2 deletions(-) diff --git a/build.sbt b/build.sbt index eed80d1e9..f73fc8d95 100644 --- a/build.sbt +++ b/build.sbt @@ -158,7 +158,7 @@ lazy val jni = (project in file("./modules/jni")) "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong"), - "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse") + "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse", "MatrixSolve") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index daea7827e..a1f3776d6 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -48,7 +48,7 @@ trait Linalg { * Computes (sign(det(x)) log(|det(x)|)) for an input x. * * @tparam T The underlying scala type of the matrix elements. - * + * * @param matrix A matrix of shape [N, M, M] * @param name An optional name to assign to the op. * @@ -86,4 +86,37 @@ trait Linalg { input = matrix ).setAttribute("adjoint", adjoint).build().output + /** + * Solves systems of linear equations Ax = b. + * The matrix M must be of shape [..., M, M] whose inner-most 2 dimensions + * form square matrices. + * + * The right hand side b is a tensor of shape [..., M, K]. + * The output x is a tensor shape [..., M, K] + * + * If `adjoint` is `True` then each output matrix satisfies + * adjoint(A[..., :, :]) * x[..., :, :] = b[..., :, :]. + * + * If `adjoint` is `False` then each output matrix satisfies + * A[..., :, :] * x[..., :, :] = b[..., :, :]. + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The matrix (A) on the left hand side. + * @param rhs The right hand side (b). + * @param adjoint Defaults to false. + * @param name An optional name to assign to the op. + * + */ + def matrixSolve[T: TF: IsRealOrComplex]( + matrix: Output[T], + rhs: Output[T], + adjoint: Boolean = false, + name: String = "MatrixSolve" + ): Output[T] = + Op.Builder[(Output[T], Output[T]), Output[T]]( + opType = "MatrixSolve", + name = name, + input = (matrix, rhs) + ).setAttribute("adjoint", adjoint).build().output + } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index a384ab400..253a159e7 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -70,4 +70,30 @@ trait Linalg { ) } + /** + * Solves systems of linear equations Ax = b. + * The matrix M must be of shape [..., M, M] whose inner-most 2 dimensions + * form square matrices. + * + * The right hand side b is a tensor of shape [..., M, K]. + * The output x is a tensor shape [..., M, K] + * + * If `adjoint` is `True` then each output matrix satisfies + * adjoint(A[..., :, :]) * x[..., :, :] = b[..., :, :]. + * + * If `adjoint` is `False` then each output matrix satisfies + * A[..., :, :] * x[..., :, :] = b[..., :, :]. + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The matrix (A) on the left hand side. + * @param rhs The right hand side (b). + * @param adjoint Defaults to false. + * + */ + def matrixSolve[T: TF: IsRealOrComplex](matrix: Tensor[T], rhs: Tensor[T], adjoint: Boolean = false): Tensor[T] = { + Tensor.fromNativeHandle[T]( + NativeTensorOpsLinAlg.matrixSolve(executionContext.value.nativeHandle, matrix.nativeHandle, rhs.nativeHandle, adjoint) + ) + } + } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc index 9281e5297..459a52467 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -118,3 +118,50 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina return reinterpret_cast(outputs[0]); } + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolve( + JNIEnv* env, jobject object, jlong context_handle, jlong matrix, jlong rhs, jboolean adjoint) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixSolve", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(matrix_handle, TFE_TensorHandle, matrix, 0); + TFE_OpAddInput(op.get(), matrix_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(rhs_handle, TFE_TensorHandle, rhs, 0); + TFE_OpAddInput(op.get(), rhs_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_matrix_handle, TFE_TensorHandle, matrix, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_matrix_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_rhs_handle, TFE_TensorHandle, rhs, 0); + const TF_DataType attr_T_rhs = TFE_TensorHandleDataType(attr_T_rhs_handle); + if (attr_T != attr_T_rhs) { + std::stringstream error_msg; + error_msg + << "Argument 'rhs' of 'matrixSolve' op with data type '" + << attr_T_rhs + << "' must match data type '" + << attr_T + << "' of argument 'matrix'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + TFE_OpSetAttrBool(op.get(), "adjoint", static_cast(adjoint)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h index c8af97765..812400b53 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.h +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -31,6 +31,14 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixInverse (JNIEnv *, jobject, jlong, jlong, jboolean); +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixSolve + * Signature: (JJJZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolve + (JNIEnv *, jobject, jlong, jlong, jlong, jboolean); + #ifdef __cplusplus } #endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala index 9691d4f31..62b078b13 100644 --- a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -25,4 +25,5 @@ object Linalg { @native def logMatrixDeterminant(contextHandle: Long, input: Long): Array[Long] @native def matrixDeterminant(contextHandle: Long, input: Long): Long @native def matrixInverse(contextHandle: Long, input: Long, adjoint: Boolean): Long + @native def matrixSolve(contextHandle: Long, matrix: Long, rhs: Long, adjoint: Boolean): Long } From f4455846f8a166b358030f091915e0c569e25f3e Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 20:10:10 +0200 Subject: [PATCH 07/11] Linear Algebra support: -- Added least squares matrix solve op. --- build.sbt | 2 +- .../platanios/tensorflow/api/ops/Linalg.scala | 46 ++++++++++++++--- .../tensorflow/api/tensors/ops/Linalg.scala | 36 +++++++++---- .../native/generated/tensor_linalg_ops.cc | 51 +++++++++++++++++++ .../main/native/generated/tensor_linalg_ops.h | 8 +++ .../jni/generated/tensors/Linalg.scala | 1 + 6 files changed, 127 insertions(+), 17 deletions(-) diff --git a/build.sbt b/build.sbt index f73fc8d95..69161f086 100644 --- a/build.sbt +++ b/build.sbt @@ -158,7 +158,7 @@ lazy val jni = (project in file("./modules/jni")) "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong"), - "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse", "MatrixSolve") + "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse", "MatrixSolve", "MatrixSolveLs") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index a1f3776d6..1f3fc237f 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -88,16 +88,16 @@ trait Linalg { /** * Solves systems of linear equations Ax = b. - * The matrix M must be of shape [..., M, M] whose inner-most 2 dimensions + * The matrix M must be of shape [..., M, M] whose inner-most 2 dimensions * form square matrices. - * - * The right hand side b is a tensor of shape [..., M, K]. + * + * The right hand side b is a tensor of shape [..., M, K]. * The output x is a tensor shape [..., M, K] - * - * If `adjoint` is `True` then each output matrix satisfies + * + * If `adjoint` is `True` then each output matrix satisfies * adjoint(A[..., :, :]) * x[..., :, :] = b[..., :, :]. - * - * If `adjoint` is `False` then each output matrix satisfies + * + * If `adjoint` is `False` then each output matrix satisfies * A[..., :, :] * x[..., :, :] = b[..., :, :]. * * @tparam T The underlying scala type of the matrix elements. @@ -119,4 +119,36 @@ trait Linalg { input = (matrix, rhs) ).setAttribute("adjoint", adjoint).build().output + /** + * Solves systems of linear equations Ax = b, in the regularised + * least squares sense. + * + * The matrix M must be of shape [..., M, N] whose inner-most 2 dimensions + * form square matrices. + * + * The right hand side b is a tensor of shape [..., M, K]. + * The output x is a tensor shape [..., N, K] + * + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The matrix (A) on the left hand side. + * @param rhs The right hand side (b). + * @param reg The L2 regularisation constant. + * @param fast Defaults to true. + * @param name An optional name to assign to the op. + * + */ + def matrixSolveLS[T: TF: IsRealOrComplex]( + matrix: Output[T], + rhs: Output[T], + reg: Output[T], + fast: Boolean = true, + name: String = "MatrixSolveLs" + ): Output[T] = + Op.Builder[(Output[T], Output[T], Output[T]), Output[T]]( + opType = "MatrixSolve", + name = name, + input = (matrix, rhs, reg) + ).setAttribute("fast", fast).build().output + } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index 253a159e7..b1db37725 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -72,27 +72,45 @@ trait Linalg { /** * Solves systems of linear equations Ax = b. - * The matrix M must be of shape [..., M, M] whose inner-most 2 dimensions + * The matrix M must be of shape [..., M, M] whose inner-most 2 dimensions * form square matrices. - * - * The right hand side b is a tensor of shape [..., M, K]. + * + * The right hand side b is a tensor of shape [..., M, K]. * The output x is a tensor shape [..., M, K] - * - * If `adjoint` is `True` then each output matrix satisfies + * + * If `adjoint` is `True` then each output matrix satisfies * adjoint(A[..., :, :]) * x[..., :, :] = b[..., :, :]. - * - * If `adjoint` is `False` then each output matrix satisfies + * + * If `adjoint` is `False` then each output matrix satisfies * A[..., :, :] * x[..., :, :] = b[..., :, :]. * * @tparam T The underlying scala type of the matrix elements. * @param matrix The matrix (A) on the left hand side. * @param rhs The right hand side (b). * @param adjoint Defaults to false. - * + * */ def matrixSolve[T: TF: IsRealOrComplex](matrix: Tensor[T], rhs: Tensor[T], adjoint: Boolean = false): Tensor[T] = { Tensor.fromNativeHandle[T]( - NativeTensorOpsLinAlg.matrixSolve(executionContext.value.nativeHandle, matrix.nativeHandle, rhs.nativeHandle, adjoint) + NativeTensorOpsLinAlg + .matrixSolve(executionContext.value.nativeHandle, matrix.nativeHandle, rhs.nativeHandle, adjoint) + ) + } + + def matrixSolveLS[T: TF: IsRealOrComplex]( + matrix: Tensor[T], + rhs: Tensor[T], + reg: Tensor[T], + fast: Boolean = true + ): Tensor[T] = { + Tensor.fromNativeHandle[T]( + NativeTensorOpsLinAlg.matrixSolveLs( + executionContext.value.nativeHandle, + matrix.nativeHandle, + rhs.nativeHandle, + reg.nativeHandle, + fast + ) ) } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc index 459a52467..783982b63 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -165,3 +165,54 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina return reinterpret_cast(outputs[0]); } + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolveLs( + JNIEnv* env, jobject object, jlong context_handle, jlong matrix, jlong rhs, jlong l2_regularizer, jboolean fast) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixSolveLs", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(matrix_handle, TFE_TensorHandle, matrix, 0); + TFE_OpAddInput(op.get(), matrix_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(rhs_handle, TFE_TensorHandle, rhs, 0); + TFE_OpAddInput(op.get(), rhs_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(l2_regularizer_handle, TFE_TensorHandle, l2_regularizer, 0); + TFE_OpAddInput(op.get(), l2_regularizer_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_matrix_handle, TFE_TensorHandle, matrix, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_matrix_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_rhs_handle, TFE_TensorHandle, rhs, 0); + const TF_DataType attr_T_rhs = TFE_TensorHandleDataType(attr_T_rhs_handle); + if (attr_T != attr_T_rhs) { + std::stringstream error_msg; + error_msg + << "Argument 'rhs' of 'matrixSolveLs' op with data type '" + << attr_T_rhs + << "' must match data type '" + << attr_T + << "' of argument 'matrix'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + TFE_OpSetAttrBool(op.get(), "fast", static_cast(fast)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h index 812400b53..b511bc16b 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.h +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -39,6 +39,14 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolve (JNIEnv *, jobject, jlong, jlong, jlong, jboolean); +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixSolveLs + * Signature: (JJJJZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolveLs + (JNIEnv *, jobject, jlong, jlong, jlong, jlong, jboolean); + #ifdef __cplusplus } #endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala index 62b078b13..6efeae794 100644 --- a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -26,4 +26,5 @@ object Linalg { @native def matrixDeterminant(contextHandle: Long, input: Long): Long @native def matrixInverse(contextHandle: Long, input: Long, adjoint: Boolean): Long @native def matrixSolve(contextHandle: Long, matrix: Long, rhs: Long, adjoint: Boolean): Long + @native def matrixSolveLs(contextHandle: Long, matrix: Long, rhs: Long, l2_regularizer: Long, fast: Boolean): Long } From 9edaecd68bce3c429e69d680e851e52412df8025 Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Tue, 1 Oct 2019 20:29:29 +0200 Subject: [PATCH 08/11] Linear Algebra support: -- Added triangular matrix solve op. --- build.sbt | 2 +- .../platanios/tensorflow/api/ops/Linalg.scala | 27 ++++++++-- .../tensorflow/api/tensors/ops/Linalg.scala | 17 +++++++ .../native/generated/tensor_linalg_ops.cc | 49 +++++++++++++++++++ .../main/native/generated/tensor_linalg_ops.h | 8 +++ .../jni/generated/tensors/Linalg.scala | 1 + 6 files changed, 100 insertions(+), 4 deletions(-) diff --git a/build.sbt b/build.sbt index 69161f086..c15c45522 100644 --- a/build.sbt +++ b/build.sbt @@ -158,7 +158,7 @@ lazy val jni = (project in file("./modules/jni")) "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong"), - "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse", "MatrixSolve", "MatrixSolveLs") + "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse", "MatrixSolve", "MatrixSolveLs", /* "MatrixSquareRoot", */ "MatrixTriangularSolve") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index 1f3fc237f..c5461ae69 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -120,9 +120,9 @@ trait Linalg { ).setAttribute("adjoint", adjoint).build().output /** - * Solves systems of linear equations Ax = b, in the regularised + * Solves systems of linear equations Ax = b, in the regularised * least squares sense. - * + * * The matrix M must be of shape [..., M, N] whose inner-most 2 dimensions * form square matrices. * @@ -146,9 +146,30 @@ trait Linalg { name: String = "MatrixSolveLs" ): Output[T] = Op.Builder[(Output[T], Output[T], Output[T]), Output[T]]( - opType = "MatrixSolve", + opType = "MatrixSolveLs", name = name, input = (matrix, rhs, reg) ).setAttribute("fast", fast).build().output + /* def matrixSquareRoot[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "MatrixSquareRoot"): Output[T] = { + Op.Builder[Output[T], Output[T]]( + opType = "MatrixSquareRoot", + name = name, + input = matrix + ).build().output + } */ + + def matrixTriangularSolve[T: TF: IsRealOrComplex]( + matrix: Output[T], + rhs: Output[T], + lower: Boolean = true, + adjoint: Boolean = false, + name: String = "MatrixTriangularSolve" + ): Output[T] = + Op.Builder[(Output[T], Output[T]), Output[T]]( + opType = "MatrixTriangularSolve", + name = name, + input = (matrix, rhs) + ).setAttribute("lower", lower).setAttribute("adjoint", adjoint).build().output + } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index b1db37725..b6a3059b9 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -114,4 +114,21 @@ trait Linalg { ) } + def matrixTriangularSolve[T: TF: IsRealOrComplex]( + matrix: Tensor[T], + rhs: Tensor[T], + lower: Boolean = true, + adjoint: Boolean = false + ): Tensor[T] = { + Tensor.fromNativeHandle[T]( + NativeTensorOpsLinAlg.matrixTriangularSolve( + executionContext.value.nativeHandle, + matrix.nativeHandle, + rhs.nativeHandle, + lower, + adjoint + ) + ) + } + } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc index 783982b63..d1b29d1eb 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -216,3 +216,52 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina return reinterpret_cast(outputs[0]); } + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixTriangularSolve( + JNIEnv* env, jobject object, jlong context_handle, jlong matrix, jlong rhs, jboolean lower, jboolean adjoint) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "MatrixTriangularSolve", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(matrix_handle, TFE_TensorHandle, matrix, 0); + TFE_OpAddInput(op.get(), matrix_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(rhs_handle, TFE_TensorHandle, rhs, 0); + TFE_OpAddInput(op.get(), rhs_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_matrix_handle, TFE_TensorHandle, matrix, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_matrix_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_rhs_handle, TFE_TensorHandle, rhs, 0); + const TF_DataType attr_T_rhs = TFE_TensorHandleDataType(attr_T_rhs_handle); + if (attr_T != attr_T_rhs) { + std::stringstream error_msg; + error_msg + << "Argument 'rhs' of 'matrixTriangularSolve' op with data type '" + << attr_T_rhs + << "' must match data type '" + << attr_T + << "' of argument 'matrix'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + TFE_OpSetAttrBool(op.get(), "lower", static_cast(lower)); + + TFE_OpSetAttrBool(op.get(), "adjoint", static_cast(adjoint)); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h index b511bc16b..6f6490bae 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.h +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -47,6 +47,14 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixSolveLs (JNIEnv *, jobject, jlong, jlong, jlong, jlong, jboolean); +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: matrixTriangularSolve + * Signature: (JJJZZ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixTriangularSolve + (JNIEnv *, jobject, jlong, jlong, jlong, jboolean, jboolean); + #ifdef __cplusplus } #endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala index 6efeae794..81010a357 100644 --- a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -27,4 +27,5 @@ object Linalg { @native def matrixInverse(contextHandle: Long, input: Long, adjoint: Boolean): Long @native def matrixSolve(contextHandle: Long, matrix: Long, rhs: Long, adjoint: Boolean): Long @native def matrixSolveLs(contextHandle: Long, matrix: Long, rhs: Long, l2_regularizer: Long, fast: Boolean): Long + @native def matrixTriangularSolve(contextHandle: Long, matrix: Long, rhs: Long, lower: Boolean, adjoint: Boolean): Long } From 5d0e0d02c5c1b452991cd10fda66931c211bda43 Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Wed, 2 Oct 2019 11:37:50 +0200 Subject: [PATCH 09/11] Linear Algebra support: -- Added Cholesky and GradCholesky op. --- build.sbt | 5 +- .../platanios/tensorflow/api/ops/Linalg.scala | 30 ++++- .../tensorflow/api/tensors/ops/Linalg.scala | 14 +++ .../native/generated/tensor_linalg_ops.cc | 109 ++++++++++++++++++ .../main/native/generated/tensor_linalg_ops.h | 24 ++++ .../jni/generated/tensors/Linalg.scala | 3 + 6 files changed, 183 insertions(+), 2 deletions(-) diff --git a/build.sbt b/build.sbt index c15c45522..957aaaca2 100644 --- a/build.sbt +++ b/build.sbt @@ -158,7 +158,10 @@ lazy val jni = (project in file("./modules/jni")) "Text" -> Seq( "StringJoin", "StringSplit", "EncodeBase64", "DecodeBase64", "StringToHashBucket", "StringToHashBucketFast", "StringToHashBucketStrong"), - "Linalg" -> Seq("LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse", "MatrixSolve", "MatrixSolveLs", /* "MatrixSquareRoot", */ "MatrixTriangularSolve") + "Linalg" -> Seq( + "Cholesky", "CholeskyGrad", "LogMatrixDeterminant", "MatrixDeterminant", + "MatrixInverse", "MatrixSolve", "MatrixSolveLs", + /* "MatrixSquareRoot", */ "MatrixTriangularSolve", "Qr") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index c5461ae69..6c5723e0f 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -36,6 +36,23 @@ import com.google.protobuf.Descriptors.FieldDescriptor */ trait Linalg { + def cholesky[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "Cholesky"): Output[T] = + Op.Builder[Output[T], Output[T]]( + opType = "Cholesky", + name = name, + input = matrix + ).setGradientFn(choleskyGrad(_, _)(TF[T], IsRealOrComplex[T])).build().output + + protected def choleskyGrad[T: TF: IsRealOrComplex]( + l: Op[Output[T], Output[T]], + outputGradient: Output[T] + ): Output[T] = + Op.Builder[(Output[T], Output[T]), Output[T]]( + opType = "CholeskyGrad", + name = "CholeskyGrad", + input = (l.output, outputGradient) + ).build().output + def matrixDeterminant[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "MatrixDeterminant"): Output[T] = { Op.Builder[Output[T], Output[T]]( opType = "MatrixDeterminant", @@ -138,7 +155,7 @@ trait Linalg { * @param name An optional name to assign to the op. * */ - def matrixSolveLS[T: TF: IsRealOrComplex]( + def matrixSolveLS[T: TF: IsReal]( matrix: Output[T], rhs: Output[T], reg: Output[T], @@ -172,4 +189,15 @@ trait Linalg { input = (matrix, rhs) ).setAttribute("lower", lower).setAttribute("adjoint", adjoint).build().output + def qr[T: TF: IsReal]( + matrix: Output[T], + full_matrices: Boolean = false, + name: String = "Qr" + ): (Output[T], Output[T]) = + Op.Builder[Output[T], (Output[T], Output[T])]( + opType = "Qr", + name = name, + input = matrix + ).setAttribute("full_matrices", full_matrices).build().output + } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index b6a3059b9..8d002ba4d 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -32,6 +32,11 @@ import scala.language.postfixOps */ trait Linalg { + def cholesky[T: TF: IsRealOrComplex](matrix: Tensor[T]): Tensor[T] = + Tensor.fromNativeHandle[T]( + NativeTensorOpsLinAlg.cholesky(executionContext.value.nativeHandle, matrix.nativeHandle) + ) + def matrixDeterminant[T: TF: IsRealOrComplex](matrix: Tensor[T]): Tensor[T] = { Tensor.fromNativeHandle[T]( NativeTensorOpsLinAlg.matrixDeterminant(executionContext.value.nativeHandle, matrix.nativeHandle) @@ -131,4 +136,13 @@ trait Linalg { ) } + def qr[T: TF: IsReal](matrix: Tensor[T], full_matrices: Boolean = false): (Tensor[T], Tensor[T]) = { + + val results = NativeTensorOpsLinAlg + .qr(executionContext.value.nativeHandle, matrix.nativeHandle, full_matrices).map( + h => Tensor.fromNativeHandle[T](h) + ) + (results.head, results.last) + } + } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc index d1b29d1eb..01851b1b4 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -27,6 +27,79 @@ #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_cholesky( + JNIEnv* env, jobject object, jlong context_handle, jlong input) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "Cholesky", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, 0); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_choleskyGrad( + JNIEnv* env, jobject object, jlong context_handle, jlong l, jlong grad) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, 0); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "CholeskyGrad", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), 0); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(l_handle, TFE_TensorHandle, l, 0); + TFE_OpAddInput(op.get(), l_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(grad_handle, TFE_TensorHandle, grad, 0); + TFE_OpAddInput(op.get(), grad_handle, status.get()); + CHECK_STATUS(env, status.get(), 0); + + REQUIRE_HANDLE(attr_T_l_handle, TFE_TensorHandle, l, 0); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_l_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + REQUIRE_HANDLE(attr_T_grad_handle, TFE_TensorHandle, grad, 0); + const TF_DataType attr_T_grad = TFE_TensorHandleDataType(attr_T_grad_handle); + if (attr_T != attr_T_grad) { + std::stringstream error_msg; + error_msg + << "Argument 'grad' of 'choleskyGrad' op with data type '" + << attr_T_grad + << "' must match data type '" + << attr_T + << "' of argument 'l'"; + throw_exception(env, tf_invalid_argument_exception, error_msg.str().c_str()); + } + + const int num_outputs = 1; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), 0); + + return reinterpret_cast(outputs[0]); +} + JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_logMatrixDeterminant( JNIEnv* env, jobject object, jlong context_handle, jlong input) { REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); @@ -265,3 +338,39 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina return reinterpret_cast(outputs[0]); } + +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_qr( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean full_matrices) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "Qr", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "full_matrices", static_cast(full_matrices)); + + const int num_outputs = 2; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h index 6f6490bae..4eda5d99f 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.h +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -7,6 +7,22 @@ #ifdef __cplusplus extern "C" { #endif +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: cholesky + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_cholesky + (JNIEnv *, jobject, jlong, jlong); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: choleskyGrad + * Signature: (JJJ)J + */ +JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_choleskyGrad + (JNIEnv *, jobject, jlong, jlong, jlong); + /* * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ * Method: logMatrixDeterminant @@ -55,6 +71,14 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_matrixTriangularSolve (JNIEnv *, jobject, jlong, jlong, jlong, jboolean, jboolean); +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: qr + * Signature: (JJZ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_qr + (JNIEnv *, jobject, jlong, jlong, jboolean); + #ifdef __cplusplus } #endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala index 81010a357..62720dc82 100644 --- a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -22,10 +22,13 @@ import org.platanios.tensorflow.jni.TensorFlow object Linalg { TensorFlow.load() + @native def cholesky(contextHandle: Long, input: Long): Long + @native def choleskyGrad(contextHandle: Long, l: Long, grad: Long): Long @native def logMatrixDeterminant(contextHandle: Long, input: Long): Array[Long] @native def matrixDeterminant(contextHandle: Long, input: Long): Long @native def matrixInverse(contextHandle: Long, input: Long, adjoint: Boolean): Long @native def matrixSolve(contextHandle: Long, matrix: Long, rhs: Long, adjoint: Boolean): Long @native def matrixSolveLs(contextHandle: Long, matrix: Long, rhs: Long, l2_regularizer: Long, fast: Boolean): Long @native def matrixTriangularSolve(contextHandle: Long, matrix: Long, rhs: Long, lower: Boolean, adjoint: Boolean): Long + @native def qr(contextHandle: Long, input: Long, full_matrices: Boolean): Array[Long] } From 3e533ec3179e3e885361e5aec7755b36fa28b1e1 Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Wed, 2 Oct 2019 17:40:35 +0200 Subject: [PATCH 10/11] Linear Algebra support: -- Added SelfAdjointEig (v2), QR, SVD ops. --- build.sbt | 3 +- .../platanios/tensorflow/api/ops/Linalg.scala | 26 ++++++- .../tensorflow/api/tensors/ops/Linalg.scala | 22 +++++- .../native/generated/tensor_linalg_ops.cc | 74 +++++++++++++++++++ .../main/native/generated/tensor_linalg_ops.h | 16 ++++ .../jni/generated/tensors/Linalg.scala | 2 + 6 files changed, 139 insertions(+), 4 deletions(-) diff --git a/build.sbt b/build.sbt index 957aaaca2..53959b59e 100644 --- a/build.sbt +++ b/build.sbt @@ -161,7 +161,8 @@ lazy val jni = (project in file("./modules/jni")) "Linalg" -> Seq( "Cholesky", "CholeskyGrad", "LogMatrixDeterminant", "MatrixDeterminant", "MatrixInverse", "MatrixSolve", "MatrixSolveLs", - /* "MatrixSquareRoot", */ "MatrixTriangularSolve", "Qr") + /* "MatrixSquareRoot", */ "MatrixTriangularSolve", "Qr", + "SelfAdjointEigV2", "Svd") ), scalaPackage in generateTensorOps := "tensors", // Native bindings compilation settings diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index 6c5723e0f..5f471f0a9 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -155,7 +155,7 @@ trait Linalg { * @param name An optional name to assign to the op. * */ - def matrixSolveLS[T: TF: IsReal]( + def matrixSolveLS[T: TF: IsRealOrComplex]( matrix: Output[T], rhs: Output[T], reg: Output[T], @@ -189,7 +189,7 @@ trait Linalg { input = (matrix, rhs) ).setAttribute("lower", lower).setAttribute("adjoint", adjoint).build().output - def qr[T: TF: IsReal]( + def qr[T: TF: IsRealOrComplex]( matrix: Output[T], full_matrices: Boolean = false, name: String = "Qr" @@ -200,4 +200,26 @@ trait Linalg { input = matrix ).setAttribute("full_matrices", full_matrices).build().output + def selfAdjointEig[T: TF: IsRealOrComplex]( + matrix: Output[T], + compute_v: Boolean = true, + name: String = "SelfAdjointEigV2" + ): (Output[T], Output[T]) = + Op.Builder[Output[T], (Output[T], Output[T])]( + opType = "SelfAdjointEigV2", + name = name, + input = matrix + ).setAttribute("compute_v", compute_v).build().output + + def svd[T: TF: IsRealOrComplex]( + matrix: Output[T], + compute_uv: Boolean = true, + full_matrices: Boolean = false, + name: String = "Svd" + ): (Output[T], Output[T], Output[T]) = + Op.Builder[Output[T], (Output[T], Output[T], Output[T])]( + opType = "Svd", + name = name, + input = matrix + ).setAttribute("compute_uv", compute_uv).setAttribute("full_matrices", full_matrices).build().output } diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala index 8d002ba4d..a17990553 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/tensors/ops/Linalg.scala @@ -136,7 +136,7 @@ trait Linalg { ) } - def qr[T: TF: IsReal](matrix: Tensor[T], full_matrices: Boolean = false): (Tensor[T], Tensor[T]) = { + def qr[T: TF: IsRealOrComplex](matrix: Tensor[T], full_matrices: Boolean = false): (Tensor[T], Tensor[T]) = { val results = NativeTensorOpsLinAlg .qr(executionContext.value.nativeHandle, matrix.nativeHandle, full_matrices).map( @@ -145,4 +145,24 @@ trait Linalg { (results.head, results.last) } + def selfAdjointEig[T: TF: IsRealOrComplex](matrix: Tensor[T], compute_v: Boolean = true): (Tensor[T], Tensor[T]) = { + val results = NativeTensorOpsLinAlg + .selfAdjointEigV2(executionContext.value.nativeHandle, matrix.nativeHandle, compute_v).map( + h => Tensor.fromNativeHandle[T](h) + ) + (results.head, results.last) + } + + def svd[T: TF: IsRealOrComplex]( + matrix: Tensor[T], + compute_uv: Boolean = true, + full_matrices: Boolean = false + ): (Tensor[T], Tensor[T], Tensor[T]) = { + + val results = NativeTensorOpsLinAlg + .svd(executionContext.value.nativeHandle, matrix.nativeHandle, compute_uv, full_matrices).map( + h => Tensor.fromNativeHandle[T](h) + ) + (results.head, results(1), results.last) + } } diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc index 01851b1b4..e7da7f08b 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.cc +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.cc @@ -374,3 +374,77 @@ JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors env->ReleaseLongArrayElements(outputs_array, output_elems, 0); return outputs_array; } + +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_selfAdjointEigV2( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean compute_v) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "SelfAdjointEigV2", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "compute_v", static_cast(compute_v)); + + const int num_outputs = 2; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} + +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_svd( + JNIEnv* env, jobject object, jlong context_handle, jlong input, jboolean compute_uv, jboolean full_matrices) { + REQUIRE_HANDLE(context, TFE_Context, context_handle, nullptr); + std::unique_ptr status(TF_NewStatus(), TF_DeleteStatus); + + std::unique_ptr op( + TFE_NewOp(context, "Svd", status.get()), TFE_DeleteOp); + CHECK_STATUS(env, status.get(), nullptr); + TFE_OpSetDevice(op.get(), "/job:localhost/replica:0/task:0/device:CPU:0", status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(input_handle, TFE_TensorHandle, input, nullptr); + TFE_OpAddInput(op.get(), input_handle, status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + REQUIRE_HANDLE(attr_T_input_handle, TFE_TensorHandle, input, nullptr); + const TF_DataType attr_T = TFE_TensorHandleDataType(attr_T_input_handle); + TFE_OpSetAttrType(op.get(), "T", attr_T); + + TFE_OpSetAttrBool(op.get(), "compute_uv", static_cast(compute_uv)); + + TFE_OpSetAttrBool(op.get(), "full_matrices", static_cast(full_matrices)); + + const int num_outputs = 3; + std::unique_ptr outputs(new TFE_TensorHandle* [num_outputs]); + std::unique_ptr actual_num_outputs(new int[1] {num_outputs}); + TFE_Execute(op.get(), outputs.get(), actual_num_outputs.get(), status.get()); + CHECK_STATUS(env, status.get(), nullptr); + + jlongArray outputs_array = env->NewLongArray(static_cast(num_outputs)); + jlong* output_elems = env->GetLongArrayElements(outputs_array, nullptr); + for (int i = 0; i < num_outputs; ++i) { + output_elems[i] = reinterpret_cast(outputs[i]); + } + env->ReleaseLongArrayElements(outputs_array, output_elems, 0); + return outputs_array; +} diff --git a/modules/jni/src/main/native/generated/tensor_linalg_ops.h b/modules/jni/src/main/native/generated/tensor_linalg_ops.h index 4eda5d99f..0ebd73aac 100644 --- a/modules/jni/src/main/native/generated/tensor_linalg_ops.h +++ b/modules/jni/src/main/native/generated/tensor_linalg_ops.h @@ -79,6 +79,22 @@ JNIEXPORT jlong JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Lina JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_qr (JNIEnv *, jobject, jlong, jlong, jboolean); +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: selfAdjointEigV2 + * Signature: (JJZ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_selfAdjointEigV2 + (JNIEnv *, jobject, jlong, jlong, jboolean); + +/* + * Class: org_platanios_tensorflow_jni_generated_tensors_Linalg__ + * Method: svd + * Signature: (JJZZ)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_platanios_tensorflow_jni_generated_tensors_Linalg_00024_svd + (JNIEnv *, jobject, jlong, jlong, jboolean, jboolean); + #ifdef __cplusplus } #endif diff --git a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala index 62720dc82..123b20223 100644 --- a/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala +++ b/modules/jni/src/main/scala/org/platanios/tensorflow/jni/generated/tensors/Linalg.scala @@ -31,4 +31,6 @@ object Linalg { @native def matrixSolveLs(contextHandle: Long, matrix: Long, rhs: Long, l2_regularizer: Long, fast: Boolean): Long @native def matrixTriangularSolve(contextHandle: Long, matrix: Long, rhs: Long, lower: Boolean, adjoint: Boolean): Long @native def qr(contextHandle: Long, input: Long, full_matrices: Boolean): Array[Long] + @native def selfAdjointEigV2(contextHandle: Long, input: Long, compute_v: Boolean): Array[Long] + @native def svd(contextHandle: Long, input: Long, compute_uv: Boolean, full_matrices: Boolean): Array[Long] } From 790a838a45ac9f801e15503145639199289453bc Mon Sep 17 00:00:00 2001 From: Mandar Chandorkar Date: Wed, 2 Oct 2019 18:09:57 +0200 Subject: [PATCH 11/11] Added scaladoc comments to linear algebra routines. --- .../platanios/tensorflow/api/ops/Linalg.scala | 73 ++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala index 5f471f0a9..0a5afd60a 100644 --- a/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala +++ b/modules/api/src/main/scala/org/platanios/tensorflow/api/ops/Linalg.scala @@ -36,6 +36,29 @@ import com.google.protobuf.Descriptors.FieldDescriptor */ trait Linalg { + /** + * Performs cholesky decomposition of one or more self-adjoint matrices. + * + * The input is a tensor of shape [..., M, M] whose inner-most 2 + * dimensions form square matrices. + * + * The input has to be symmetric and positive definite. Only the lower-triangular + * part of the input will be used for this operation. The upper-triangular part + * will not be read. The output is a tensor of the same shape as the input + * containing the Cholesky decompositions for all input submatrices `[..., :, :]`. + * **Note**: The gradient computation on GPU is faster for large matrices but + * not for large batch dimensions when the submatrices are small. In this + * case it might be faster to use the CPU. + * + * Returns: + * Output of shape [M, M] + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The input. + * + * @param name An optional name to assign to the op. + * + */ def cholesky[T: TF: IsRealOrComplex](matrix: Output[T], name: String = "Cholesky"): Output[T] = Op.Builder[Output[T], Output[T]]( opType = "Cholesky", @@ -140,7 +163,7 @@ trait Linalg { * Solves systems of linear equations Ax = b, in the regularised * least squares sense. * - * The matrix M must be of shape [..., M, N] whose inner-most 2 dimensions + * The matrix A must be of shape [..., M, N] whose inner-most 2 dimensions * form square matrices. * * The right hand side b is a tensor of shape [..., M, K]. @@ -189,6 +212,28 @@ trait Linalg { input = (matrix, rhs) ).setAttribute("lower", lower).setAttribute("adjoint", adjoint).build().output + /** + * Performs QR decomposition of a matrix. + * + * The matrix must be of [..., M, N] whose inner-most 2 dimensions + * form matrices of size [M, N]. Let P be the minimum of M and N. + * + * Returns: + * q: Orthonormal basis for range of the input matrix. If + * full_matrices is `False` then shape is [..., M, P]; + * if full_matrices is `True` then shape is [..., M, M]. + * r: Triangular factor. If full_matrices is `False` then shape is + * [..., P, N]. If full_matrices is `True` then shape is [..., M, N]. + * + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The input. + * @param full_matrices If true, compute full-sized q and r. + * If false (the default), compute only the + * leading P columns of q. + * @param name An optional name to assign to the op. + * + */ def qr[T: TF: IsRealOrComplex]( matrix: Output[T], full_matrices: Boolean = false, @@ -211,6 +256,32 @@ trait Linalg { input = matrix ).setAttribute("compute_v", compute_v).build().output + /** + * Performs singular value decomposition of a matrix. + * + * The matrix must be of [..., M, N] whose inner-most 2 dimensions + * form matrices of size [M, N]. Let P be the minimum of M and N. + * + * Returns: + * s: Singular values. Shape is [..., P]. + * u: Left singular vectors. If full_matrices is False then shape is + * [..., M, P]; if full_matrices is True then shape is + * [..., M, M]. Undefined if compute_uv is False. + * v: Left singular vectors. If full_matrices is False then shape is + * [..., N, P]. If full_matrices is True then shape is [..., N, N]. + * Undefined if compute_uv is false. + * + * + * @tparam T The underlying scala type of the matrix elements. + * @param matrix The matrix to decompose. + * @param full_matrices If true, compute full-sized u and v. + * If false (the default), compute only the + * leading P singular vectors. + * @param compute_uv If true, left and right singular vectors will be + * computed and returned in u and v, respectively. + * @param name An optional name to assign to the op. + * + */ def svd[T: TF: IsRealOrComplex]( matrix: Output[T], compute_uv: Boolean = true,