Skip to content

Commit 3c8b464

Browse files
committed
renaming extract_keras_summary to extract_keras_model
1 parent 5bd7cdb commit 3c8b464

File tree

4 files changed

+23
-14
lines changed

4 files changed

+23
-14
lines changed

R/keras_tools.R

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,24 @@ keras_evaluate <- function(object, x, y = NULL, ...) {
3333
keras3::evaluate(keras_model, x = x_proc, y = y_proc, ...)
3434
}
3535

36-
#' Extract Keras Model Summary
36+
#' Extract the Raw Keras Model from a Kerasnip Fit
3737
#'
38+
#' @title Extract Keras Model from a Fitted Kerasnip Object
3839
#' @description
39-
#' Extracts and returns the summary of a Keras model fitted with `kerasnip`.
40+
#' Extracts and returns the underlying Keras model object from a `parsnip`
41+
#' `model_fit` object created by `kerasnip`.
42+
#'
43+
#' @details
44+
#' This is useful when you need to work directly with the Keras model object for
45+
#' tasks like inspecting layer weights, creating custom plots, or passing it to
46+
#' other Keras-specific functions.
4047
#'
4148
#' @param object A `model_fit` object produced by a `kerasnip` specification.
42-
#' @param ... Additional arguments passed on to `keras3::summary()`.
4349
#'
44-
#' @return A character vector, where each element is a line of the model summary.
50+
#' @return The raw Keras model object (`keras_model`).
51+
#' @seealso keras_evaluate, extract_keras_history
4552
#' @export
46-
extract_keras_summary <- function(object, ...) {
53+
extract_keras_model <- function(object) {
4754
object$fit$fit
4855
}
4956

tests/testthat/test_e2e_features.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ test_that("E2E: Setting num_blocks = 0 works for sequential models", {
139139
fit_obj <- parsnip::fit(spec, mpg ~ ., data = mtcars)
140140

141141
# Check that the dense layer is NOT in the model
142-
keras_model <- fit_obj |> extract_keras_summary()
142+
keras_model <- fit_obj |> extract_keras_model()
143143
expect_equal(length(keras_model$layers), 1) # Output layers only
144144

145145
# Check layer names explicitly
@@ -162,7 +162,7 @@ test_that("E2E: Error handling for reserved names works", {
162162
)
163163
})
164164

165-
test_that("E2E: extract_keras_summary works", {
165+
test_that("E2E: extract_keras_model works", {
166166
skip_if_no_keras()
167167

168168
# Reuse model setup from previous tests
@@ -194,7 +194,7 @@ test_that("E2E: extract_keras_summary works", {
194194

195195
fit_obj <- parsnip::fit(spec, mpg ~ ., data = mtcars)
196196

197-
summary_output <- extract_keras_summary(fit_obj)
197+
summary_output <- extract_keras_model(fit_obj)
198198

199199
expect_type(summary_output, "closure")
200200
expect_true(any(grepl("Layer ", summary_output)))

tests/testthat/test_e2e_functional.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ test_that("E2E: Block repetition works for functional models", {
210210
set_engine("keras")
211211
fit_1 <- fit(spec_1, mpg ~ ., data = mtcars)
212212
model_1_layers <- fit_1 |>
213-
extract_keras_summary() |>
213+
extract_keras_model() |>
214214
pluck("layers")
215215

216216
# Expect 3 layers: Input, Dense, Output
@@ -221,7 +221,7 @@ test_that("E2E: Block repetition works for functional models", {
221221
set_engine("keras")
222222
fit_2 <- fit(spec_2, mpg ~ ., data = mtcars)
223223
model_2_layers <- fit_2 |>
224-
extract_keras_summary() |>
224+
extract_keras_model() |>
225225
pluck("layers")
226226
# Expect 4 layers: Input, Dense, Dense, Output
227227
expect_equal(length(model_2_layers), 4)
@@ -231,7 +231,7 @@ test_that("E2E: Block repetition works for functional models", {
231231
set_engine("keras")
232232
fit_3 <- fit(spec_3, mpg ~ ., data = mtcars)
233233
model_3_layers <- fit_3 |>
234-
extract_keras_summary() |>
234+
extract_keras_model() |>
235235
pluck("layers")
236236
# Expect 2 layers: Input, Output
237237
expect_equal(length(model_3_layers), 2)

vignettes/getting_started.Rmd

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,13 @@ mlp_fit <- fit(mlp_spec, y ~ x, data = train_df)
181181

182182
```{r model-summarize}
183183
mlp_fit |>
184-
extract_keras_summary()
184+
extract_keras_model() |>
185+
summary()
185186
```
186187

187188
```{r model-plot, eval=FALSE}
188189
mlp_fit |>
189-
extract_keras_summary() |>
190+
extract_keras_model() |>
190191
plot(show_shapes = TRUE)
191192
```
192193

@@ -320,7 +321,8 @@ We can now inspect our final, tuned model.
320321
# Print the model summary
321322
final_fit |>
322323
extract_fit_parsnip() |>
323-
extract_keras_summary()
324+
extract_keras_model() |>
325+
summary()
324326
325327
# Plot the training history
326328
final_fit |>

0 commit comments

Comments
 (0)