1- from skorch import NeuralNetClassifier , NeuralNet
2- from skorch .dataset import Dataset as SkorchDataset
3- import torch .nn as nn
4- import torch_frame
5- from torch_frame .data .tensor_frame import TensorFrame
6- from torch_frame .utils import infer_df_stype
7- from torch_frame .data .dataset import DataFrameToTensorFrameConverter , Dataset
8- from torch_frame .data .loader import DataLoader
9- import torch
10- from torch_frame .typing import IndexSelectType
11- from torch import Tensor
12- from pandas import DataFrame
131from typing import Any
2+
143import pandas as pd
4+ import torch
5+ import torch .nn as nn
156from numpy .typing import ArrayLike
7+ from pandas import DataFrame
8+ from skorch import NeuralNet , NeuralNetClassifier
9+ from skorch .dataset import Dataset as SkorchDataset
10+ from torch import Tensor
11+
12+ import torch_frame
1613from torch_frame .config import (
1714 ImageEmbedderConfig ,
1815 TextEmbedderConfig ,
1916 TextTokenizerConfig ,
2017)
18+ from torch_frame .data .dataset import DataFrameToTensorFrameConverter , Dataset
19+ from torch_frame .data .loader import DataLoader
20+ from torch_frame .data .tensor_frame import TensorFrame
21+ from torch_frame .typing import IndexSelectType
22+ from torch_frame .utils import infer_df_stype
23+
2124
class NeuralNetPytorchFrameDataLoader(DataLoader):
    """Data loader that moves each collated batch onto a fixed device.

    Every batch is returned as a ``(tensor_frame, target)`` pair, where the
    target is the ``y`` attribute of the collated tensor frame.
    NOTE(review): the pair layout presumably matches the ``(X, y)`` batches
    skorch expects -- confirm against the skorch training loop.
    """

    def __init__(self, dataset: Dataset | TensorFrame, *args,
                 device: torch.device, **kwargs):
        """Create the loader.

        Args:
            dataset: Data to iterate over; forwarded to the parent loader.
            *args: Extra positional arguments forwarded to the parent loader.
            device: Device every collated batch is moved to (keyword-only).
            **kwargs: Extra keyword arguments forwarded to the parent loader.
        """
        super().__init__(dataset, *args, **kwargs)
        self.device = device

    def collate_fn(
            self, index: IndexSelectType) -> tuple[TensorFrame, Tensor | None]:
        """Collate the rows selected by ``index`` and move them to the device.

        Args:
            index: Row selection handed over by the underlying loader.

        Returns:
            The collated tensor frame on ``self.device`` together with its
            ``y`` attribute.
        """
        # ``as_tensor`` reuses an index that is already a Tensor instead of
        # copy-constructing it (``torch.tensor(tensor)`` copies and warns);
        # for lists of ints the result is the same as ``torch.tensor``.
        index = torch.as_tensor(index)
        res = super().collate_fn(index).to(self.device)
        return res, res.y
@@ -112,14 +115,18 @@ def create_dataset(self, df: DataFrame, _: Any) -> Dataset:
112115 dataset_ .materialize ()
113116 return dataset_
114117
def split_dataset(self,
                  dataset: Dataset) -> tuple[TensorFrame, TensorFrame]:
    """Return the tensor frames of the first two splits of ``dataset``.

    Args:
        dataset: Dataset whose ``split()`` result is consumed.
            NOTE(review): presumably materialized and carrying a split
            column so that ``split()`` succeeds -- confirm against callers.

    Returns:
        The ``tensor_frame`` of the first split and of the second split,
        in that order. Any further splits are ignored.
    """
    # Only the first two entries of the split result are used.
    datasets = dataset.split()[:2]
    return datasets[0].tensor_frame, datasets[1].tensor_frame
118122
def iterator_train_valid(self, dataset: Dataset,
                         **kwargs: Any) -> DataLoader:
    """Build the batch iterator for ``dataset``.

    Args:
        dataset: Data to iterate over.
        **kwargs: Extra keyword arguments forwarded unchanged to the
            loader constructor.

    Returns:
        A ``NeuralNetPytorchFrameDataLoader`` bound to ``self.device``.
    """
    return NeuralNetPytorchFrameDataLoader(dataset, device=self.device,
                                           **kwargs)
121127
122- def fit (self , X : Dataset | DataFrame , y : ArrayLike | None = None , ** fit_params ):
128+ def fit (self , X : Dataset | DataFrame , y : ArrayLike | None = None ,
129+ ** fit_params ):
123130 if isinstance (X , DataFrame ):
124131 if y is not None :
125132 X ["target_col" ] = y
@@ -138,9 +145,11 @@ def fit(self, X: Dataset | DataFrame, y: ArrayLike | None=None, **fit_params):
138145 self .dataset_ = X
139146 return super ().fit (self .dataset_ .df , None , ** fit_params )
140147
148+
141149# TODO: make this behave more like NeuralNetClassifier
class NeuralNetClassifierPytorchFrame(NeuralNetPytorchFrame):
    """Classifier variant that records the distinct target labels on fit."""

    def fit(self, X: Dataset | DataFrame, y: ArrayLike | None = None,
            **fit_params):
        """Fit the network, then store the unique target values.

        Args:
            X: Training data, either a prepared dataset or a data frame.
            y: Optional target values; forwarded to the parent ``fit``.
            **fit_params: Extra arguments forwarded to the parent ``fit``.

        Returns:
            Whatever the parent ``fit`` returns.
        """
        fit_result = super().fit(X, y, **fit_params)
        # A caller-supplied Dataset may name its target column differently,
        # so prefer the dataset's own ``target_col`` and only fall back to
        # the literal name used when the column is attached from ``y``.
        target_col = getattr(self.dataset_, "target_col", None) or "target_col"
        self.classes = self.dataset_.df[target_col].unique()
        return fit_result
0 commit comments