Skip to content

Commit 05e6377

Browse files
committed
[OPS] Minor edits.
1 parent aa1726f commit 05e6377

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

api/src/main/scala/org/platanios/tensorflow/api/ops/io/data/Dataset.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,9 @@ object Dataset {
235235
val GroupByWindowDataset: data.GroupByWindowDataset.type = data.GroupByWindowDataset
236236

237237
def fromGenerator[T, O, DA, D, S](
238-
generator: () => Iterable[T], outputDataType: DA, outputShape: S = null
238+
generator: () => Iterable[T],
239+
outputDataType: DA,
240+
outputShape: S = null
239241
)(implicit
240242
evDAToD: DataTypeAuxToDataType.Aux[DA, D],
241243
evData: Data.Aux[T, O, D, S],
@@ -285,7 +287,9 @@ object Dataset {
285287
* @return Constructed dataset.
286288
*/
287289
private[api] def fromGenerator[T, O, DA, D, S](
288-
generator: () => Iterable[T], outputDataType: DA, outputShape: S = null
290+
generator: () => Iterable[T],
291+
outputDataType: DA,
292+
outputShape: S = null
289293
)(implicit
290294
evDAToD: DataTypeAuxToDataType.Aux[DA, D],
291295
evData: Data.Aux[T, O, D, S],

tpu/src/main/scala/org/platanios/tensorflow/tpu/Ops.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,5 +158,11 @@ private[tpu] object Ops {
158158
tf.gradientsRegistry.registerNonDifferentiable("InfeedEnqueueTuple")
159159
tf.gradientsRegistry.registerNonDifferentiable("InfeedDequeue")
160160
tf.gradientsRegistry.registerNonDifferentiable("InfeedDequeueTuple")
161+
162+
tf.gradientsRegistry.register("CrossReplicaSum", crossReplicaSumGradient)
163+
164+
private[this] def crossReplicaSumGradient(op: Op, outputGradients: Seq[OutputLike]): Seq[OutputLike] = {
165+
Seq(crossReplicaSum(outputGradients.head))
166+
}
161167
}
162168
}
Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
1-
package org.platanios.tensorflow.tpu
2-
3-
class Topology {
4-
5-
}
1+
///* Copyright 2017-18, Emmanouil Antonios Platanios. All Rights Reserved.
2+
// *
3+
// * Licensed under the Apache License, Version 2.0 (the "License"); you may not
4+
// * use this file except in compliance with the License. You may obtain a copy of
5+
// * the License at
6+
// *
7+
// * http://www.apache.org/licenses/LICENSE-2.0
8+
// *
9+
// * Unless required by applicable law or agreed to in writing, software
10+
// * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11+
// * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12+
// * License for the specific language governing permissions and limitations under
13+
// * the License.
14+
// */
15+
//
16+
//package org.platanios.tensorflow.tpu
17+
//
18+
///**
19+
// * @author Emmanouil Antonios Platanios
20+
// */
21+
//case class Topology(meshShape: (Int, Int, Int), deviceCoordinates) {
22+
//
23+
//}

0 commit comments

Comments
 (0)