File tree Expand file tree Collapse file tree 1 file changed +14
-0
lines changed
neural_structured_learning/keras Expand file tree Collapse file tree 1 file changed +14
-0
lines changed Original file line number Diff line number Diff line change @@ -392,6 +392,20 @@ class AdversarialRegularization(keras.Model):
392392 # The model minimizes (mean_squared_error + 0.2 * adversarial_regularization).
393393 adv_model.fit(x={'input': x_train, 'label': y_train}, batch_size=32)
394394 ```
395+
396+ It is recommended to use `tf.data.Dataset` to handle the dictionary format
397+ requirement of the input, especially when using the `validation_data` argument
398+ in `fit()`.
399+
400+ ```python
401+ train_data = tf.data.Dataset.from_tensor_slices(
402+ {'input': x_train, 'label': y_train}).batch(batch_size)
403+ val_data = tf.data.Dataset.from_tensor_slices(
404+ {'input': x_val, 'label': y_val}).batch(batch_size)
405+ val_steps = x_val.shape[0] / batch_size
406+ adv_model.fit(train_data, validation_data=val_data,
407+ validation_steps=val_steps, epochs=2, verbose=1)
408+ ```
395409 """
396410
397411 def __init__ (self ,
You can’t perform that action at this time.
0 commit comments