
Commit 14b4e7b

set nd4j version to 0.7.2 & remove MaxPool
1 parent: 1acecf3

4 files changed: +144 -144 lines

DifferentiableINDArray/build.sbt

Lines changed: 3 additions & 3 deletions

@@ -28,8 +28,8 @@ libraryDependencies ++= {
   if (VersionNumber(scalaVersion.value).numbers >= Seq(2, 12)) {
     Nil
   } else {
-    Seq("org.nd4j" %% "nd4s" % "0.8.0",
-        "org.nd4j" % "nd4j-api" % "0.8.0",
-        "org.nd4j" % "nd4j-native-platform" % "0.8.0" % Test)
+    Seq("org.nd4j" %% "nd4s" % "0.7.2",
+        "org.nd4j" % "nd4j-api" % "0.7.2",
+        "org.nd4j" % "nd4j-native-platform" % "0.7.2" % Test)
   }
 }
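
The same version string appears three times in this file and once more in the root build.sbt below, so a bump has to touch four lines in lockstep. A minimal sketch of one way to factor it out; nd4jVersion is a hypothetical setting name, not something this commit introduces:

// Hypothetical refactoring (not part of this commit): keep the nd4j
// version in one place so a future bump touches a single line.
val nd4jVersion = "0.7.2"

libraryDependencies ++= {
  if (VersionNumber(scalaVersion.value).numbers >= Seq(2, 12)) {
    Nil // nd4s is skipped on Scala 2.12 in this build
  } else {
    Seq("org.nd4j" %% "nd4s" % nd4jVersion,
        "org.nd4j" % "nd4j-api" % nd4jVersion,
        "org.nd4j" % "nd4j-native-platform" % nd4jVersion % Test)
  }
}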

DifferentiableINDArray/src/main/scala/com/thoughtworks/deeplearning/DifferentiableINDArray.scala

Lines changed: 65 additions & 65 deletions

In the first hunk, the added (+) side is the removed block reproduced verbatim with every line commented out using //, so only the removed (-) side is shown in full.

@@ -816,68 +816,68 @@ object DifferentiableINDArray {
     }
   }

-  final case class MaxPool[Input0 <: Tape](override val operand: Layer.Aux[Input0, INDArrayPlaceholder.Tape],
-                                           poolSize: (Int, Int))
-      extends CumulativeLayer.Unary {
-    override type CumulativeTape = INDArraySemigroupTape with SemigroupTape with UnaryTape
-
-    override type Input = Input0
-
-    override protected def rawForward(input0: Input): CumulativeTape = {
-      new {
-        override val input = input0
-      } with INDArraySemigroupTape with SemigroupTape with UnaryTape {
-
-        private val upstreamShape = {
-          upstream.value.shape()
-        }
-
-        private val kernelAndStrideSize: Array[Int] = toArray(poolSize)
-
-        private val preMaxPool: INDArray =
-          Convolution
-            .im2col(upstream.value, kernelAndStrideSize, kernelAndStrideSize, Array(0, 0))
-            .permute(0, 1, 4, 5, 2, 3)
-
-        private val preShape: Seq[Int] = preMaxPool.shape().toSeq
-
-        private val lastDimensionSize: Int = preShape.takeRight(2).product
-
-        private val reshapedPreMaxPool: INDArray = preMaxPool
-          .reshape(preShape.take(preShape.length - 2) :+ lastDimensionSize: _*)
-
-        override val value = reshapedPreMaxPool.max(4)
-
-        override protected def rawBackward(outputDelta: INDArray): Unit = {
-
-          val a = reshapedPreMaxPool
-          val upStreamDup = a.dup()
-          val rows = ArrayUtil.prod(a.length())
-
-          val isMax: INDArray = Nd4j.getExecutioner
-            .execAndReturn(new IsMax(upStreamDup, 4))
-            .reshape(preShape.take(preShape.length - 2) :+ poolSize._2 :+ poolSize._1: _*)
-            .permute(0, 1, 2, 4, 3, 5)
-            .reshape('c', rows, 1)
-
-          val outputDelta1d = {
-            outputDelta
-              .repeat(-1, poolSize._1)
-              .permute(1, 0, 3, 2)
-              .repeat(-1, poolSize._2)
-              .permute(1, 0, 3, 2)
-              .reshape('c', upstreamShape.product, 1)
-          }
-
-          upstream.backward(
-            isMax
-              .muliColumnVector(outputDelta1d)
-              .reshape(upstreamShape: _*)
-          )
-        }
-      }
-    }
-  }

   final case class Shape[Input0 <: Tape](operand: Layer.Aux[Input0, INDArrayPlaceholder.Tape]) extends Layer {
     override def forward(input: Input0): Output = {

@@ -1373,9 +1373,9 @@ object DifferentiableINDArray {
       Permute(operand, DifferentiableSeq.Layers.ToSeq(newShape.map(toLayer.apply(_))))
     }

-    def maxPool(poolSize: (Int, Int)): Layer.Aux[Input, INDArrayPlaceholder.Tape] = {
-      MaxPool(operand, poolSize)
-    }
+    // def maxPool(poolSize: (Int, Int)): Layer.Aux[Input, INDArrayPlaceholder.Tape] = {
+    //   MaxPool(operand, poolSize)
+    // }

     /**
       * Returns shape of INDArray
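
The removed MaxPool built the pooled tensor with nd4j's im2col and routed gradients back through an IsMax mask. For reference, here is a dependency-free sketch of the same forward/backward computation on a plain 2-D array; the object and method names are illustrative, not part of this library:

// Plain-Scala sketch of non-overlapping ph x pw max pooling.
object MaxPoolSketch {
  // forward: keep the maximum of each window
  def forward(x: Array[Array[Double]], ph: Int, pw: Int): Array[Array[Double]] =
    Array.tabulate(x.length / ph, x(0).length / pw) { (i, j) =>
      (for (di <- 0 until ph; dj <- 0 until pw)
        yield x(i * ph + di)(j * pw + dj)).max
    }

  // backward: each output gradient flows only to the arg-max cell of its
  // window, the role the IsMax mask played in the removed nd4j code
  def backward(x: Array[Array[Double]],
               delta: Array[Array[Double]],
               ph: Int,
               pw: Int): Array[Array[Double]] = {
    val grad = Array.fill(x.length, x(0).length)(0.0)
    for (i <- delta.indices; j <- delta(i).indices) {
      val cells =
        for (di <- 0 until ph; dj <- 0 until pw)
          yield (i * ph + di, j * pw + dj)
      val (mi, mj) = cells.maxBy { case (r, c) => x(r)(c) }
      grad(mi)(mj) += delta(i)(j)
    }
    grad
  }
}

The nd4j version did the same routing in bulk: IsMax zeroed every non-maximal entry of each window, and muliColumnVector multiplied that mask by the broadcast output delta before reshaping back to the input shape.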

DifferentiableINDArray/src/test/scala/com/thoughtworks/deeplearning/LayerSpec.scala

Lines changed: 75 additions & 75 deletions

As in DifferentiableINDArray.scala, the added (+) side is the removed tests reproduced verbatim with every line commented out using //, so only the removed (-) side is shown.

@@ -318,81 +318,81 @@ final class LayerSpec extends FreeSpec with Matchers with Inside {
     shapeSeq(3) should be(9)
   }

-  "INDArrayPlaceholder maxPool poolSize --forward" in {
-
-    implicit val learningRate = new LearningRate {
-      override def currentLearningRate() = 0.03
-    }
-
-    def makeNetwork(poolSize: (Int, Int))(implicit x: INDArray @Symbolic) = {
-      val weightInitialValue = 1 to 96
-      weightInitialValue.toNDArray.reshape(2, 3, 4, 4).toWeight.maxPool(poolSize)
-    }
-
-    val network = makeNetwork((2, 2))
-
-    val inputData = 1 to 96
-
-    def train() = {
-      val outputTape = network.forward(
-        inputData.toNDArray.reshape(2, 3, 4, 4).toTape
-      )
-      try {
-        val loss = (outputTape.value: INDArray).sumT
-        outputTape.backward(outputTape.value)
-        loss
-      } finally {
-        outputTape.close()
-      }
-    }
-
-    train().value should be(1224.0)
-
-    for (_ <- 0 until 700) {
-      train().value
-    }
-
-    math.abs(train().value) should be < 10.0
-
-  }
-
-  "INDArrayPlaceholder maxPool poolSize --backward" in {
-
-    implicit val learningRate = new LearningRate {
-      override def currentLearningRate() = 1.0
-    }
-
-    def makeNetwork(poolSize: (Int, Int))(implicit x: INDArray @Symbolic) = {
-      val weightInitialValue = 1 to 96
-      weightInitialValue.toNDArray.reshape(2, 3, 4, 4).toWeight.maxPool(poolSize)
-    }
-
-    val network = makeNetwork((2, 2))
-
-    val inputData = 1 to 96
-
-    def train() = {
-      val outputTape = network.forward(
-        inputData.toNDArray.reshape(2, 3, 4, 4).toTape
-      )
-      try {
-        val loss = (outputTape.value: INDArray).sumT
-        outputTape.backward(outputTape.value)
-        loss
-      } finally {
-        outputTape.close()
-      }
-    }
-
-    train().value
-
-    val result = inside(network) {
-      case MaxPool(Weight(w), _) => w
-    }
-
-    result.sumT should be(3432)
-
-  }

   "INDArrayPlaceholder shape --only forward no backward" in {
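The expected values in the disabled tests can be checked by hand: the weight is 1 to 96 reshaped to (2, 3, 4, 4), so each of the six 4x4 channels holds consecutive values row-major, and the initial loss is the sum of all 2x2 window maxima. A standalone check in plain Scala, no nd4j (val names are illustrative):

// Sum the 2x2 window maxima over six 4x4 channels filled with 1..96.
val loss = (1 to 96).grouped(16).map { channel =>
  val rows = channel.grouped(4).toVector
  (for (i <- 0 until 4 by 2; j <- 0 until 4 by 2)
    yield Seq(rows(i)(j), rows(i)(j + 1),
              rows(i + 1)(j), rows(i + 1)(j + 1)).max).sum
}.sum
assert(loss == 1224) // matches train().value should be(1224.0)

The 3432 in the backward test is consistent with the same numbers: backward feeds the pooled output back as its own delta, so with learning rate 1.0 a single update subtracts each window maximum from its arg-max cell, taking the weight sum from 1 + 2 + ... + 96 = 4656 down to 4656 - 1224 = 3432.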
build.sbt

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ addCompilerPlugin("com.thoughtworks.implicit-dependent-type" %% "implicit-depend

 libraryDependencies += "com.thoughtworks.enableIf" %% "enableif" % "1.1.4" % Test

-libraryDependencies += "org.nd4j" % "nd4j-native-platform" % "0.8.0" % Test
+libraryDependencies += "org.nd4j" % "nd4j-native-platform" % "0.7.2" % Test

 crossScalaVersions := Seq("2.10.6", "2.11.8", "2.12.1")