Skip to content

Commit 94f429e

Browse files
committed
Changing fit interface to formula to support list columns
1 parent a049fa2 commit 94f429e

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

R/generic_functional_fit.R

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,22 @@
6464
#' @keywords internal
6565
#' @export
6666
generic_functional_fit <- function(
67-
x,
68-
y,
67+
formula,
68+
data,
6969
layer_blocks,
7070
...
7171
) {
72+
# Separate predictors and outcomes from the processed data frame provided by parsnip
73+
y_names <- all.vars(formula[[2]])
74+
x_names <- all.vars(formula[[3]])
75+
76+
# Handle the `.` case for predictors
77+
if ("." %in% x_names) {
78+
x <- data[, !(names(data) %in% y_names), drop = FALSE]
79+
} else {
80+
x <- data[, x_names, drop = FALSE]
81+
}
82+
y <- data[, y_names, drop = FALSE]
7283
# --- 1. Build and Compile Model ---
7384
model <- build_and_compile_functional_model(x, y, layer_blocks, ...)
7485

R/generic_sequential_fit.R

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,22 @@
6060
#' @keywords internal
6161
#' @export
6262
generic_sequential_fit <- function(
63-
x,
64-
y,
63+
formula,
64+
data,
6565
layer_blocks,
6666
...
6767
) {
68+
# Separate predictors and outcomes from the processed data frame provided by parsnip
69+
y_names <- all.vars(formula[[2]])
70+
x_names <- all.vars(formula[[3]])
71+
72+
# Handle the `.` case for predictors
73+
if ("." %in% x_names) {
74+
x <- data[, !(names(data) %in% y_names), drop = FALSE]
75+
} else {
76+
x <- data[, x_names, drop = FALSE]
77+
}
78+
y <- data[, y_names, drop = FALSE]
6879
# --- 1. Build and Compile Model ---
6980
model <- build_and_compile_sequential_model(x, y, layer_blocks, ...)
7081

R/register_fit_predict.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ register_fit_predict <- function(model_name, mode, layer_blocks, functional) {
3030
eng = "keras",
3131
mode = mode,
3232
value = list(
33-
interface = "data.frame",
34-
protect = c("x", "y"),
33+
interface = "formula",
34+
protect = c("formula", "data"),
3535
func = c(
3636
pkg = "kerasnip",
3737
fun = if (functional) {

0 commit comments

Comments
 (0)