@@ -56,7 +56,7 @@ Check [Compute.scala on Scaladex](https://index.scala-lang.org/thoughtworksinc/c

### Creating an N-dimensional array

- Import different the namespace object `gpu` or `cpu`, according to the OpenCL runtime you want to use.
+ Import the types in the `gpu` or `cpu` object, according to the OpenCL runtime you want to use.

```scala
// For N-dimensional array on GPU
@@ -68,7 +68,7 @@ import com.thoughtworks.compute.gpu._
import com.thoughtworks.compute.cpu._
```

- In Compute.scala, an N-dimensional array is typed as `Tensor`, which can be created from `Seq` or `scala.Array`.
+ In Compute.scala, an N-dimensional array is typed as `Tensor`, which can be created from `Seq` or `Array`.

```scala
val my2DArray: Tensor = Tensor(Array(Seq(1.0f, 2.0f, 3.0f), Seq(4.0f, 5.0f, 6.0f)))
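+
+ // A quick sanity check (an assumption based on the `shape` pattern match shown
+ // later in this README): the outer dimension comes first, so this 2×3 array
+ // should have shape Array(2, 3).
+ val Array(2, 3) = my2DArray.shape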
@@ -203,7 +203,7 @@ By combining pure `Tensor`s along with the impure `cache` mechanism, we achieved

A `Tensor` can be `split` into smaller `Tensor`s along a specific dimension.

- For example, given a 3D tensor whose `shape` is 2x3x4,
+ For example, given a 3D tensor whose `shape` is 2×3×4,

```scala
val my3DTensor = Tensor((0.0f until 24.0f by 1.0f).grouped(4).toSeq.grouped(3).toSeq)
@@ -214,10 +214,10 @@ val Array(2, 3, 4) = my3DTensor.shape
when `split`ting it at dimension #0,

```scala
- val subtensors0 = my3DTensor.split(dimension = 0)
+ val subtensors0: Seq[Tensor] = my3DTensor.split(dimension = 0)
```

- then the result should be a `Seq` of two 3x4 tensors.
+ then the result should be a `Seq` of two 3×4 tensors.

```scala
// Output: TensorSeq([[0.0,1.0,2.0,3.0],[4.0,5.0,6.0,7.0],[8.0,9.0,10.0,11.0]], [[12.0,13.0,14.0,15.0],[16.0,17.0,18.0,19.0],[20.0,21.0,22.0,23.0]])
@@ -227,21 +227,47 @@ println(subtensors0)
when `split`ting it at dimension #1,

```scala
- val subtensors1 = my3DTensor.split(dimension = 1)
+ val subtensors1: Seq[Tensor] = my3DTensor.split(dimension = 1)
```

- then the result should be a `Seq` of three 2x4 tensors.
+ then the result should be a `Seq` of three 2×4 tensors.

```scala
// Output: TensorSeq([[0.0,1.0,2.0,3.0],[12.0,13.0,14.0,15.0]], [[4.0,5.0,6.0,7.0],[16.0,17.0,18.0,19.0]], [[8.0,9.0,10.0,11.0],[20.0,21.0,22.0,23.0]])
println(subtensors1)
```

- Then you can use arbitrary Scala collection functions on Seq of subtensors.
+ Then you can use arbitrary Scala collection functions on the `Seq` of subtensors.
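+
+ For example, here is a minimal sketch that sums the three 2×4 subtensors with the ordinary collection method `reduce`, assuming the element-wise `+` operator on `Tensor`s:
+
+ ```scala
+ // Element-wise sum of all subtensors, via a standard Scala collection method.
+ val sum: Tensor = subtensors1.reduce[Tensor](_ + _)
+
+ // Expected output: [[12.0,15.0,18.0,21.0],[48.0,51.0,54.0,57.0]]
+ println(sum)
+ ```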

#### `join`

- TODO
+ Multiple `Tensor`s of the same `shape` can be merged into a larger `Tensor` via the `Tensor.join` function.
+
+ Given a `Seq` of three 2×2 `Tensor`s,
+
+ ```scala
+ val mySubtensors: Seq[Tensor] = Seq(
+   Tensor(Seq(Seq(1.0f, 2.0f), Seq(3.0f, 4.0f))),
+   Tensor(Seq(Seq(5.0f, 6.0f), Seq(7.0f, 8.0f))),
+   Tensor(Seq(Seq(9.0f, 10.0f), Seq(11.0f, 12.0f)))
+ )
+ ```
+
+ when `join`ing them,
+
+ ```scala
+ val merged: Tensor = Tensor.join(mySubtensors)
+ ```
+
+ then the result should be a 2×2×3 `Tensor`.
262+
263+ ``` scala
264+ // Output: [[[1.0,5.0,9.0],[2.0,6.0,10.0]],[[3.0,7.0,11.0],[4.0,8.0,12.0]]]
265+ println(merged.toString)
266+ ```
+
+ Generally, when `join`ing *n* `Tensor`s of shape *a*<sub>0</sub> × *a*<sub>1</sub> × *a*<sub>2</sub> × ⋯ × *a*<sub>*i*</sub>, the shape of the resulting `Tensor` is *a*<sub>0</sub> × *a*<sub>1</sub> × *a*<sub>2</sub> × ⋯ × *a*<sub>*i*</sub> × *n*.
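+
+ As a quick check of this rule, the `shape` pattern match shown earlier in this README can verify that joining three 2×2 tensors yields a 2×2×3 result:
+
+ ```scala
+ // a0 × a1 × n = 2 × 2 × 3
+ val Array(2, 2, 3) = merged.shape
+ ```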

#### Fast matrix multiplication from `split` and `join`
