@@ -33,9 +33,10 @@ struct Tensor {
3333 unsigned int data_index (unsigned int const indices[N]) const {
3434 unsigned int index = 0 ;
3535 for (unsigned int i = 0 ; i < N; ++i) {
36- ASSERT (indices[i] < shape[i]);
36+ ASSERT (indices[i] < shape[i], " Invalid index " );
3737 // TODO: 计算 index
3838 }
39+ return index;
3940 }
4041};
4142
@@ -48,10 +49,12 @@ int main(int argc, char **argv) {
4849 unsigned int i0[]{0 , 0 , 0 , 0 };
4950 tensor[i0] = 1 ;
5051 ASSERT (tensor[i0] == 1 , " tensor[i0] should be 1" );
52+ ASSERT (tensor.data [0 ] == 1 , " tensor[i0] should be 1" );
5153
5254 unsigned int i1[]{1 , 2 , 3 , 4 };
53- tensor[i0] = 2 ;
54- ASSERT (tensor[i0] == 2 , " tensor[i1] should be 2" );
55+ tensor[i1] = 2 ;
56+ ASSERT (tensor[i1] == 2 , " tensor[i1] should be 2" );
57+ ASSERT (tensor.data [119 ] == 2 , " tensor[i1] should be 2" );
5558 }
5659 {
5760 unsigned int shape[]{7 , 8 , 128 };
@@ -60,10 +63,12 @@ int main(int argc, char **argv) {
6063 unsigned int i0[]{0 , 0 , 0 };
6164 tensor[i0] = 1 .f ;
6265 ASSERT (tensor[i0] == 1 .f , " tensor[i0] should be 1" );
66+ ASSERT (tensor.data [0 ] == 1 .f , " tensor[i0] should be 1" );
6367
6468 unsigned int i1[]{3 , 4 , 99 };
65- tensor[i0] = 2 .f ;
66- ASSERT (tensor[i0] == 2 .f , " tensor[i1] should be 2" );
69+ tensor[i1] = 2 .f ;
70+ ASSERT (tensor[i1] == 2 .f , " tensor[i1] should be 2" );
71+ ASSERT (tensor.data [3683 ] == 2 .f , " tensor[i1] should be 2" );
6772 }
6873 return 0 ;
6974}
0 commit comments