Skip to content

Commit 3f3759f

Browse files
committed
sciwork demo code
1 parent 5012f4d commit 3f3759f

File tree

6 files changed

+90
-72
lines changed

6 files changed

+90
-72
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
uTensor

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"python.pythonPath": ".venv/bin/python"
3+
}

gen_inputs_header.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
import argparse
4+
5+
6+
def main(num_samples=5, seed=None):
7+
mnist = tf.keras.datasets.mnist
8+
(_, _), (x_test, y_test) = mnist.load_data()
9+
x_test = x_test / 255.0
10+
11+
# Add a channels dimension
12+
x_test = x_test[..., np.newaxis]
13+
total_N = y_test.shape[0]
14+
num_samples = min(num_samples, total_N)
15+
np.random.seed(seed)
16+
idxs = np.random.choice(range(total_N), num_samples, replace=False)
17+
x_selected = x_test[idxs].reshape(num_samples, -1)
18+
y_selected = y_test[idxs]
19+
with open("input_image.h", "w") as fid:
20+
fid.write("// clang-format off\n")
21+
fid.write(
22+
"const float arr_input_image[{}][{}] = {{\n".format(
23+
x_selected.shape[0], x_selected.shape[1]
24+
)
25+
)
26+
for i in range(x_selected.shape[0]):
27+
arr = x_selected[i]
28+
fid.write(" {{ {}".format(", ".join(map(str, arr))))
29+
fid.write("},\n")
30+
fid.write("};\n")
31+
fid.write("const int ref_labels[{}] = {{\n".format(y_selected.shape[0]))
32+
fid.write(" " + ", ".join(map(str, y_selected)) + "\n")
33+
fid.write("};\n\n")
34+
35+
36+
if __name__ == "__main__":
37+
parser = argparse.ArgumentParser()
38+
parser.add_argument(
39+
"--num-samples",
40+
dest="num_samples",
41+
default=5,
42+
help="the number of inpute samples [default: %(default)s]",
43+
type=int,
44+
metavar="INTEGER",
45+
)
46+
parser.add_argument("--seed", default=None, help="the random seed", type=int)
47+
args = vars(parser.parse_args())
48+
main(**args)

0 commit comments

Comments
 (0)