Internal measure functions and tests #23

Open · wants to merge 9 commits into base: main
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -4,7 +4,8 @@ Version: 0.0.1
Authors@R: c(
person("Damir", "Pulatov", , "[email protected]", role = c("cre", "aut")),
person("Marc", "Becker", , "[email protected]", role = "aut",
comment = c(ORCID = "0000-0002-8115-0400"))
comment = c(ORCID = "0000-0002-8115-0400")),
person("Baisu", "Zhou", , "[email protected]", role = "aut")
)
Description: Flexible AutoML system for the 'mlr3' ecosystem.
License: LGPL-3
89 changes: 52 additions & 37 deletions R/internal_measure.R
@@ -1,8 +1,8 @@
#' @title Internal Measure XGBoost
#' @title Internal Measure Catboost
#'
#' @description
#' Function to get the internal xgboost measure for a given [mlr3::Task] and [mlr3::Measure].
#' For example, the measure "classif.auc" will return "auc" for a binary classification task.
#' Function to get the internal catboost measure for a given [mlr3::Task] and [mlr3::Measure].
#' For example, the measure "classif.auc" will return "AUC" for a binary classification task.
#'
#' @param measure [mlr3::Measure]\cr
#' Measure to get the internal measure for.
@@ -11,42 +11,53 @@
#'
#' @export
#' @examples
#' internal_measure_xgboost(msr("classif.auc"), tsk("pima"))
internal_measure_xgboost = function(measure, task) {
#' internal_measure_catboost(msr("classif.auc"), tsk("pima"))
internal_measure_catboost = function(measure, task) {
id = measure$id

metric = if (task$task_type == "regr") {
switch(id,
"regr.rmse" = "rmse",
"regr.rmsle" = "rmsle",
"regr.mae" = "mae",
"regr.mape" = "mape",
"regr.logloss" = "logloss",
"regr.rmse" = "RMSE",
"regr.mae" = "MAE",
"regr.mape" = "MAPE",
"regr.smape" = "SMAPE",
"regr.medae" = "MedianAbsoluteError",
"rsq" = "R2", # regr.rsq has id `rsq`
NULL
)
} else if ("twoclass" %in% task$properties) {
switch(id,
"classif.ce" = "error",
"classif.acc" = "error",
"classif.auc" = "auc",
"classif.ce" = "Accuracy",
"classif.acc" = "Accuracy",
"classif.bacc" = "BalancedAccuracy",
"classif.auc" = "AUC",
"classif.prauc" = "PRAUC",
"classif.bbrier" = "BrierScore",
"classif.logloss" = "Logloss",
"classif.precision" = "Precision",
"classif.recall" = "Recall",
"classif.mcc" = "MCC",
NULL
)
} else if ("multiclass" %in% task$properties) {
switch(id,
"classif.ce" = "merror",
"classif.acc" = "merror",
"classif.ce" = "Accuracy",
"classif.acc" = "Accuracy",
"classif.mauc_mu" = "AUC",
"classif.logloss" = "MultiClass",
"classif.mcc" = "MCC",
NULL
)
}

return(metric %??% NA_character_)
}

#' @title Internal Measure Catboost
#' @title Internal Measure LightGBM
#'
#' @description
#' Function to get the internal catboost measure for a given [mlr3::Task] and [mlr3::Measure].
#' For example, the measure "classif.auc" will return "AUC" for a binary classification task.
#' Function to get the internal lightgbm measure for a given [mlr3::Task] and [mlr3::Measure].
#' For example, the measure "classif.auc" will return "auc" for a binary classification task.
#'
#' @param measure [mlr3::Measure]\cr
#' Measure to get the internal measure for.
@@ -55,40 +66,41 @@ internal_measure_xgboost = function(measure, task) {
#'
#' @export
#' @examples
#' internal_measure_catboost(msr("classif.auc"), tsk("pima"))
internal_measure_catboost = function(measure, task) {
#' internal_measure_lightgbm(msr("classif.auc"), tsk("pima"))
internal_measure_lightgbm = function(measure, task) {
id = measure$id

metric = if (task$task_type == "regr") {
switch(id,
"regr.rmse" = "RMSE",
"regr.rmsle" = "RMSLE",
"regr.mae" = "MAE",
"regr.mape" = "MAPE",
"regr.logloss" = "Logloss",
"regr.mse" = "mse",
"regr.rmse" = "rmse",
"regr.mae" = "mae",
"regr.mape" = "mape",
NULL
)
} else if ("twoclass" %in% task$properties) {
switch(id,
"classif.ce" = "Accuracy",
"classif.acc" = "Accuracy",
"classif.auc" = "AUC",
"classif.ce" = "binary_error",
"classif.acc" = "binary_error",
"classif.auc" = "auc",
"classif.logloss" = "binary_logloss",
NULL
)
} else if ("multiclass" %in% task$properties) {
switch(id,
"classif.ce" = "Accuracy",
"classif.acc" = "Accuracy",
"classif.auc" = "AUC",
"classif.logloss" = "MultiLogloss",
"classif.ce" = "multi_error",
"classif.acc" = "multi_error",
"classif.mauc_mu" = "auc_mu",
"classif.logloss" = "multi_logloss",
NULL
)
}

return(metric %??% NA_character_)
}

#' @title Internal Measure LightGBM

#' @title Internal Measure XGBoost
#'
#' @description
#' Function to get the internal xgboost measure for a given [mlr3::Task] and [mlr3::Measure].
@@ -101,30 +113,33 @@ internal_measure_catboost = function(measure, task) {
#'
#' @export
#' @examples
#' internal_measure_lightgbm(msr("classif.auc"), tsk("pima"))
internal_measure_lightgbm = function(measure, task) {
#' internal_measure_xgboost(msr("classif.auc"), tsk("pima"))
internal_measure_xgboost = function(measure, task) {
id = measure$id

metric = if (task$task_type == "regr") {
switch(id,
"regr.rmse" = "rmse",
"regr.rmsle" = "rmsle",
"regr.mae" = "mae",
"regr.mape" = "mape",
"regr.logloss" = "logloss",
NULL
)
} else if ("twoclass" %in% task$properties) {
switch(id,
"classif.ce" = "error",
"classif.acc" = "error",
"classif.auc" = "auc",
"classif.prauc" = "aucpr",
"classif.logloss" = "logloss",
NULL
)
} else if ("multiclass" %in% task$properties) {
switch(id,
"classif.ce" = "merror",
"classif.acc" = "merror",
"classif.auc" = "auc_mu",
"classif.mauc_aunp" = "auc",
"classif.logloss" = "mlogloss",
NULL
)
}
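A quick way to sanity-check the new lookups interactively — a minimal sketch, assuming `mlr3` and this package are attached; the expected outputs in the comments follow the mappings in the diff above:

```r
library(mlr3)

task = tsk("pima")  # binary classification task

internal_measure_catboost(msr("classif.auc"), task)   # "AUC"
internal_measure_lightgbm(msr("classif.auc"), task)   # "auc"
internal_measure_xgboost(msr("classif.auc"), task)    # "auc"

# measures without a native counterpart fall through to NA_character_,
# e.g. balanced accuracy has no lightgbm mapping
internal_measure_lightgbm(msr("classif.bacc"), task)  # NA_character_
```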
2 changes: 1 addition & 1 deletion R/train_auto.R
@@ -101,7 +101,7 @@ train_auto = function(self, private, task) {
graph_learner$param_set$values$lightgbm.callbacks = list(cb_timeout_lightgbm(pv$learner_timeout * 0.8))
eval_metric = pv$lightgbm_eval_metric %??% internal_measure_lightgbm(pv$measure, task)
if (is.na(eval_metric)) eval_metric = pv$lightgbm_eval_metric
graph_learner$param_set$values$lightgbm.eval = eval_metric
graph_learner$param_set$values$lightgbm.eval = eval_metric # maybe change this to `lightgbm.eval_metric` for consistency?
}

# initialize search space
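For context on the hunk above, the eval metric is resolved in three steps: an explicitly supplied `lightgbm_eval_metric` wins, otherwise the measure is translated to lightgbm's native metric name, and if no native counterpart exists the code falls back to the (possibly `NULL`) user value. A minimal sketch of that resolution order; the null-coalescing operator and the helper name are defined locally here so the snippet is self-contained, not taken from the package:

```r
# Null-coalescing helper; the package uses an equivalent operator.
`%??%` = function(lhs, rhs) if (is.null(lhs)) rhs else lhs

# Hypothetical helper mirroring the resolution logic in train_auto()
resolve_lightgbm_eval_metric = function(user_metric, measure, task) {
  # 1. an explicitly supplied metric takes precedence
  # 2. otherwise translate the mlr3 measure to lightgbm's native metric name
  metric = user_metric %??% internal_measure_lightgbm(measure, task)
  # 3. unsupported measures yield NA; fall back to the user value (possibly NULL)
  if (is.na(metric)) metric = user_metric
  metric
}

resolve_lightgbm_eval_metric(NULL, msr("classif.auc"), tsk("pima"))    # "auc"
resolve_lightgbm_eval_metric("auc", msr("classif.bacc"), tsk("pima"))  # "auc" (user override)
```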
84 changes: 72 additions & 12 deletions tests/testthat/test_LearnerClassifAutoCatboost.R
@@ -29,27 +29,87 @@ test_that("LearnerClassifAutoCatboost is trained", {
expect_equal(learner$model$instance$result$branch.selection, "catboost")
})

test_that("internal eval metric is found", {
test_that("LearnerClassifAutoCatboost twoclass internal eval metric is found", {
skip_on_cran()
skip_if_not_installed("rush")
flush_redis()

rush_plan(n_workers = 2)


task_twoclass = tsk("pima")
msrs_twoclass = rbindlist(list(
list(measure = "classif.ce", metric = "Accuracy"),
list(measure = "classif.acc", metric = "Accuracy"),
list(measure = "classif.bacc", metric = "BalancedAccuracy"),
list(measure = "classif.auc", metric = "AUC"),
list(measure = "classif.prauc", metric = "PRAUC"),
list(measure = "classif.bbrier", metric = "BrierScore"),
list(measure = "classif.logloss", metric = "Logloss"),
list(measure = "classif.precision", metric = "Precision"),
list(measure = "classif.recall", metric = "Recall"),
list(measure = "classif.mcc", metric = "MCC")
))
walk(seq_row(msrs_twoclass), function(i) {
learner = lrn("classif.auto_catboost",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr(msrs_twoclass$measure[[i]]),
terminator = trm("evals", n_evals = 6),
store_benchmark_result = TRUE,
store_models = TRUE
)
learner$train(task_twoclass)

task = tsk("penguins")
learner = lrn("classif.auto_catboost",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr("classif.ce"),
terminator = trm("evals", n_evals = 6),
store_benchmark_result = TRUE,
store_models = TRUE
)
expect_equal(
learner$instance$archive$benchmark_result$resample_result(1)$learners[[1]]$model$catboost$param_vals$eval_metric,
msrs_twoclass$metric[[i]]
)
})
})

learner$train(task)
expect_equal(learner$instance$archive$benchmark_result$resample_result(1)$learners[[1]]$model$catboost$param_vals$eval_metric, "Accuracy")
test_that("LearnerClassifAutoCatboost multiclass internal eval metric is found", {
skip_on_cran()
skip_if_not_installed("rush")
flush_redis()

rush_plan(n_workers = 2)


task_multiclass = tsk("penguins")
msrs_multiclass = rbindlist(list(
list(measure = "classif.ce", metric = "Accuracy"),
list(measure = "classif.acc", metric = "Accuracy"),
list(measure = "classif.mauc_mu", metric = "AUC"),
list(measure = "classif.logloss", metric = "MultiClass"),
list(measure = "classif.mcc", metric = "MCC")
))
walk(seq_row(msrs_multiclass), function(i) {
learner = lrn("classif.auto_catboost",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr(msrs_multiclass$measure[[i]]),
terminator = trm("evals", n_evals = 6),
store_benchmark_result = TRUE,
store_models = TRUE
)
learner$train(task_multiclass)

expect_equal(
learner$instance$archive$benchmark_result$resample_result(1)$learners[[1]]$model$catboost$param_vals$eval_metric,
msrs_multiclass$metric[[i]]
)
})
})

test_that("catboost not supported internal eval metric throws error", {
skip_on_cran()
skip_if_not_installed("rush")
flush_redis()

rush_plan(n_workers = 2)

task = tsk("penguins")
learner = lrn("classif.auto_catboost",
small_data_size = 1,
resampling = rsmp("holdout"),
68 changes: 68 additions & 0 deletions tests/testthat/test_LearnerClassifAutoLightGBM.R
@@ -28,3 +28,71 @@ test_that("LearnerClassifAutoLightGBM is trained", {
expect_equal(learner$graph$param_set$values$branch.selection, "lightgbm")
expect_equal(learner$model$instance$result$branch.selection, "lightgbm")
})

test_that("LearnerClassifAutoLightGBM twoclass internal eval metric is found", {
skip_on_cran()
skip_if_not_installed("rush")
flush_redis()

rush_plan(n_workers = 2)


task_twoclass = tsk("pima")
msrs_twoclass = rbindlist(list(
list(measure = "classif.ce", metric = "binary_error"),
list(measure = "classif.acc", metric = "binary_error"),
list(measure = "classif.logloss", metric = "binary_logloss"),
list(measure = "classif.auc", metric = "auc")
))
walk(seq_row(msrs_twoclass), function(i) {
learner = lrn("classif.auto_lightgbm",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr(msrs_twoclass$measure[[i]]),
terminator = trm("evals", n_evals = 6),
store_benchmark_result = TRUE,
store_models = TRUE
)
learner$train(task_twoclass)

expect_equal(
learner$instance$archive$benchmark_result$resample_result(1)$learners[[1]]$model$lightgbm$param_vals$eval,
# only for lightgbm, it is called `eval` instead of `eval.metric`
msrs_twoclass$metric[[i]]
)
})
})

test_that("LearnerClassifAutoLightGBM multiclass internal eval metric is found", {
skip_on_cran()
skip_if_not_installed("rush")
flush_redis()

rush_plan(n_workers = 2)


task_multiclass = tsk("penguins")
msrs_multiclass = rbindlist(list(
list(measure = "classif.ce", metric = "multi_error"),
list(measure = "classif.acc", metric = "multi_error"),
list(measure = "classif.logloss", metric = "multi_logloss"),
list(measure = "classif.mauc_mu", metric = "auc_mu")
))
walk(seq_row(msrs_multiclass), function(i) {
learner = lrn("classif.auto_lightgbm",
small_data_size = 1,
resampling = rsmp("holdout"),
measure = msr(msrs_multiclass$measure[[i]]),
terminator = trm("evals", n_evals = 6),
store_benchmark_result = TRUE,
store_models = TRUE
)
learner$train(task_multiclass)

expect_equal(
learner$instance$archive$benchmark_result$resample_result(1)$learners[[1]]$model$lightgbm$param_vals$eval,
# only for lightgbm, it is called `eval` instead of `eval.metric`
msrs_multiclass$metric[[i]]
)
})
})
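Running these test files locally requires a Redis server, since `rush_plan()` coordinates workers through Redis and the `flush_redis()` helper clears it between runs. A minimal sketch from the package root (file paths are the ones added in this PR):

```r
# Assumes a local Redis instance is reachable with rush's default connection settings.
testthat::test_file("tests/testthat/test_LearnerClassifAutoCatboost.R")
testthat::test_file("tests/testthat/test_LearnerClassifAutoLightGBM.R")
```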