@@ -146,10 +146,64 @@ plt.show()
146146
147147
148148
149+ # Create example values
150+ height = 224
151+ width = 224
152+ color_channels = 3
153+ patch_size = 16
149154
155+ # Calculate the number of patches
156+ number_of_patches = int((height * width) / patch_size**2)
157+ number_of_patches
150158
151159
160+ # Input shape
161+ embedding_layer_input_shape = (height, width, color_channels)
152162
163+ # Output shape
164+ embedding_layer_output_shape = (number_of_patches, patch_size**2 * color_channels)
165+
166+ print(f"Input shape (single 2D image): {embedding_layer_input_shape}")
167+ print(f"Output shape (single 1D sequence of patches): {embedding_layer_output_shape} -> (number_of_patches, embedding_dimension)")
168+
169+
170+
171+
172+
173+ plt.imshow(imgs[0].permute(1,2,0))
174+
175+
176+ image = imgs[0]
177+
178+
179+ image_permuted = image.permute(1,2,0)
180+
181+
182+ patch_size = 16
183+
184+
185+ plt.figure(figsize=(patch_size,patch_size))
186+ plt.imshow(image_permuted[:patch_size, :, :])
187+
188+
189+ # Setup code to plot top row as patches
190+ img_size = 224
191+ patch_size = 16
192+ num_patches = img_size / patch_size
193+ assert img_size % patch_size == 0, "Image Size must be divisible by patch size"
194+
195+ num_patches
196+
197+
198+ img_size / patch_size, img_size // patch_size
199+
200+
201+ fig, axs = plt.subplots(nrows=1, ncols=img_size // patch_size, sharex=True, sharey=True)
202+ for i, patch in enumerate(range(0, img_size, patch_size)):
203+ axs[i].imshow(image_permuted[:patch_size, patch:patch+patch_size, :]);
204+ axs[i].set_xlabel(i+1)
205+ axs[i].set_xticks([])
206+ axs[i].set_yticks([])
153207
154208
155209
0 commit comments