Skip to content

Commit 1ee6a1d

Browse files
csferngtensorflow-copybara
authored andcommitted
Document an example of using AdversarialRegularization with tf.data.Dataset.
PiperOrigin-RevId: 268580505
1 parent 67e86e7 commit 1ee6a1d

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

neural_structured_learning/keras/adversarial_regularization.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,20 @@ class AdversarialRegularization(keras.Model):
392392
# The model minimizes (mean_squared_error + 0.2 * adversarial_regularization).
393393
adv_model.fit(x={'input': x_train, 'label': y_train}, batch_size=32)
394394
```
395+
396+
It is recommended to use `tf.data.Dataset` to handle the dictionary format
397+
requirement of the input, especially when using the `validation_data` argument
398+
in `fit()`.
399+
400+
```python
401+
train_data = tf.data.Dataset.from_tensor_slices(
402+
{'input': x_train, 'label': y_train}).batch(batch_size)
403+
val_data = tf.data.Dataset.from_tensor_slices(
404+
{'input': x_val, 'label': y_val}).batch(batch_size)
405+
val_steps = x_val.shape[0] / batch_size
406+
adv_model.fit(train_data, validation_data=val_data,
407+
validation_steps=val_steps, epochs=2, verbose=1)
408+
```
395409
"""
396410

397411
def __init__(self,

0 commit comments

Comments
 (0)