|
| 1 | +from typing import Any, Callable, Optional, Tuple |
| 2 | + |
| 3 | +import tensorflow as tf |
| 4 | +import xarray as xr |
| 5 | + |
| 6 | +# Notes: |
| 7 | +# This module provides one Keras dataset, which can be passed to model.fit(). |
| 8 | +# - CustomTFDataset provides an indexable interface. |
| 9 | +# Assumptions made: |
| 10 | +# - The dataset takes pre-configured X/y xbatcher generators (callers may not always want two generators in one dataset). |
| 11 | + |
| 12 | + |
class CustomTFDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` adapter that serves xbatcher batches as tensors.

    Indexing the dataset returns an ``(X, y)`` pair in which every variable
    of the corresponding generator batch has been concatenated along
    ``dim`` and converted to a TensorFlow tensor.
    """

    def __init__(
        self,
        X_generator,
        y_generator,
        *,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        dim: str = 'new_dim',
    ) -> None:
        """
        Keras Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
            Generator providing the feature batches.
        y_generator : xbatcher.BatchGenerator
            Generator providing the target batches.
        transform : callable, optional
            A function/transform that takes in an array and returns a
            transformed version.
        target_transform : callable, optional
            A function/transform that takes in the target and transforms it.
        dim : str, 'new_dim'
            Name of dim to pass to :func:`xarray.concat` as the dimension
            to concatenate all variables along.
        """
        # Newer Keras versions expect Sequence/PyDataset subclasses to run
        # the base initializer; it is a harmless no-op on older versions.
        super().__init__()
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform
        self.concat_dim = dim

    def __len__(self) -> int:
        """Return the number of batches, taken from ``X_generator``."""
        # NOTE(review): y_generator is presumably the same length; this is
        # not checked here.
        return len(self.X_generator)

    def _batch_to_tensor(self, batch) -> Any:
        """Concatenate all variables of ``batch`` and convert to a tensor."""
        # ``batch`` is assumed to be mapping-like (e.g. an xarray.Dataset):
        # it must expose ``keys()`` and support ``batch[key]``.
        stacked = xr.concat(
            [batch[key] for key in batch.keys()],
            self.concat_dim,
        )
        return tf.convert_to_tensor(stacked.data)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Return the ``idx``-th ``(X, y)`` batch as tensors.

        Parameters
        ----------
        idx : int
            Index of the batch to fetch from both generators.

        Returns
        -------
        tuple
            ``(X_batch, y_batch)``, each passed through its optional
            transform after tensor conversion.
        """
        # Index each generator exactly once: every ``generator[idx]`` call
        # builds the batch again, so re-indexing per variable is wasteful.
        X_batch = self._batch_to_tensor(self.X_generator[idx])
        y_batch = self._batch_to_tensor(self.y_generator[idx])

        # TODO: Should the transformations be applied before tensor conversion?
        if self.transform:
            X_batch = self.transform(X_batch)

        if self.target_transform:
            y_batch = self.target_transform(y_batch)
        return X_batch, y_batch
0 commit comments