"object 'optimizer' not found" error when calling fit() on a custom model #6
Comments
Hmm, that sounds like a bug. I will investigate. Thanks for reporting!
Which chapter?
ch07
It looks like, starting with TensorFlow 2.11, custom … We will need to investigate further (and update the corresponding guides on tensorflow.rstudio.com). Updating the …
## -------------------------------------------------------------------------
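library(tensorflow)
library(keras)

loss_fn <- loss_sparse_categorical_crossentropy()
loss_tracker <- metric_mean(name = "loss")

CustomModel <- new_model_class(
  classname = "CustomModel",

  # Override compile() so that the optimizer and loss passed by the caller are
  # stored on the model instead of being looked up as free variables.
  # The second argument is named `loss` to match what compile() passes in.
  compile = function(optimizer, loss, ...) {
    super$compile(...)
    # Build the optimizer for the model's variables ahead of time
    # (part of the TensorFlow 2.11 workaround)
    optimizer$build(self$variables)
    self$optimizer <- optimizer
    self$loss_fn <- loss
  },

  train_step = function(data) {
    c(inputs, targets) %<-% data
    with(tf$GradientTape() %as% tape, {
      predictions <- self(inputs, training = TRUE)
      loss <- self$loss_fn(targets, predictions)
    })
    # Use self (the model being trained), not a global `model` object
    gradients <- tape$gradient(loss, self$trainable_weights)
    self$optimizer$apply_gradients(zip_lists(gradients, self$trainable_weights))
    loss_tracker$update_state(loss)
    list(loss = loss_tracker$result())
  },

  metrics = mark_active(function() list(loss_tracker))
)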
## -------------------------------------------------------------------------
inputs <- layer_input(shape = c(28 * 28))
features <- inputs %>%
  layer_dense(512, activation = "relu") %>%
  layer_dropout(0.5)
outputs <- features %>%
  layer_dense(10, activation = "softmax")
model <- CustomModel(inputs = inputs, outputs = outputs)

model %>% compile(optimizer = optimizer_rmsprop(), loss = loss_fn)
# train_images and train_labels are assumed to be the MNIST training arrays
# prepared earlier in chapter 7 (not shown here)
model %>% fit(train_images, train_labels, epochs = 3)
Hi!
It looks like compile() ignores the optimizer argument when compiling/training a custom model. When I try this code:
model %>% compile(optimizer = optimizer_rmsprop())
(line 766 of the book's code), it fails with an error:
Error in py_call_impl(callable, call_args$unnamed, call_args$named) :
  RuntimeError: in user code:
  ....
  RuntimeError: object 'optimizer' not found
Instead of the passed argument, it picks up an optimizer variable from the parent environment (the global environment).
In other words, optimizer <- optimizer_rmsprop() has to be defined in advance; then the model trains as it should.
Is this OK?
Any thoughts?
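A base-R sketch of the scoping behaviour described in the report, assuming the book's train_step() refers to a bare optimizer object rather than self$optimizer. R resolves a free variable in the environment where the function was defined, not in the frame of whoever called it, so an optimizer argument received by compile() is invisible to the step unless it is stored somewhere the step can reach. The names make_optimizer, step_free, compile_free, model_state, compile_self and step_self are hypothetical stand-ins, not Keras APIs.

```r
# Hypothetical stand-in for an optimizer: a list with an `apply` method
make_optimizer <- function() list(apply = function(x) x * 2)

# Pattern 1: the step reads a bare `optimizer`, i.e. a free variable.
# R resolves it in step_free's defining environment (the global one).
step_free <- function(data) optimizer$apply(data)

compile_free <- function(optimizer) {
  # This `optimizer` lives only in compile_free's own frame, so step_free()
  # cannot see it; return the error message instead of stopping
  tryCatch(step_free(21), error = conditionMessage)
}

# Pattern 2: "compile" stores the optimizer on shared state that the step
# reads back, analogous to self$optimizer in a custom model class
model_state <- new.env()
compile_self <- function(optimizer) {
  model_state$optimizer <- optimizer
  invisible(NULL)
}
step_self <- function(data) model_state$optimizer$apply(data)

# In a fresh session with no global `optimizer`:
print(compile_free(make_optimizer()))
# [1] "object 'optimizer' not found"

compile_self(make_optimizer())
print(step_self(21))
# [1] 42

# The workaround from the report: a global `optimizer` makes pattern 1 work
optimizer <- make_optimizer()
print(compile_free(make_optimizer()))
# [1] 42
```

This mirrors why the CustomModel class in the maintainer's reply stores the optimizer with self$optimizer <- optimizer inside compile() and reads it back in train_step().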