
predcontrib = TRUE in xgboost classification learner causes errors #236

Open
TylerGrantSmith opened this issue May 5, 2022 · 0 comments


Using `predcontrib = TRUE` to get Shapley values from the xgboost classification learner causes errors, because mlr3 expects a different prediction format.

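With `predcontrib = TRUE`, xgboost's `predict()` returns a matrix of per-feature contributions (one column per feature plus a bias column) instead of a vector of probabilities. The prediction step of the learner, `mlr3learners:::LearnerClassifXgboost$private_methods$.predict`, flattens this matrix, which causes a mismatch in the number of rows.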
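The row mismatch can be reproduced in base R without mlr3. This is a minimal sketch that assumes the binary-class branch of the learner's prediction code pairs the xgboost output with its complement, roughly as `matrix(c(pred, 1 - pred), ncol = 2)`. The 333 x 10 shape is inferred from the task in the reprex (333 penguins with a recorded sex, 9 treatment-encoded features plus the bias column) and is consistent with the 2997 extra predictions reported in the error.

```r
# Hypothetical sizes: 333 penguins with a recorded sex, 9 treatment-encoded
# features (species and island give 2 dummies each, plus 5 numeric columns)
# and 1 bias column from xgboost's contribution output.
n_obs <- 333L
n_cols <- 10L

# Stand-in for predict(booster, newdata, predcontrib = TRUE): an
# n_obs x n_cols matrix (only the shape matters here, not the values).
contrib <- matrix(0.1, nrow = n_obs, ncol = n_cols)

# Assumed binary-task post-processing: the output is treated as a vector of
# positive-class probabilities and paired with its complement.
# c() first flattens the matrix into a vector of length n_obs * n_cols.
prob <- matrix(c(contrib, 1 - contrib), ncol = 2L)

nrow(prob)          # 3330 rows instead of 333
nrow(prob) - n_obs  # 2997 extra rows, matching the error message
```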

How can I get these Shapley contributions using xgboost's built-in method?

```r
library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

penguins = palmerpenguins::penguins
penguins = penguins[!is.na(penguins$sex), ]
task = as_task_classif(penguins, target = "sex", positive = "male")

learner = lrn("classif.xgboost",
               predict_type = "prob",
               predcontrib = TRUE,
               nrounds = 10)

fencoder = po("encode",
              method = "treatment",
              affect_columns = selector_type("factor"))

graph = fencoder %>>% learner
graph_learner = as_learner(graph)

graph_learner$train(task)
pred = graph_learner$predict(task)
#> Error: Predicted prob contains 2997 additional predictions without matching rows
#> This happened PipeOp classif.xgboost's $predict()
```
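One possible workaround, sketched under assumptions and not tested against the package versions in this report: train the pipeline without `predcontrib`, then call xgboost's own `predict()` with `predcontrib = TRUE` directly on the trained booster, feeding it the same encoded features the booster saw during training. This assumes an mlr3pipelines version where `GraphLearner$graph_model` exposes the trained pipeline, and that the PipeOps keep their default IDs `encode` and `classif.xgboost`.

```r
library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

penguins = palmerpenguins::penguins
penguins = penguins[!is.na(penguins$sex), ]
task = as_task_classif(penguins, target = "sex", positive = "male")

# Same pipeline, but without predcontrib so mlr3 only sees ordinary probabilities
learner = lrn("classif.xgboost", predict_type = "prob", nrounds = 10)
fencoder = po("encode", method = "treatment",
              affect_columns = selector_type("factor"))
graph_learner = as_learner(fencoder %>>% learner)
graph_learner$train(task)

# Trained copies of the PipeOps (assumes the default IDs)
trained = graph_learner$graph_model$pipeops
xgb_learner = trained$classif.xgboost$learner_model
booster = xgb_learner$model

# Re-apply the fitted encoding, then order columns exactly as during training
encoded_task = trained$encode$predict(list(task))[[1]]
X = as.matrix(encoded_task$data(cols = xgb_learner$state$feature_names))

# xgboost's built-in SHAP values: one column per feature plus a bias column,
# on the margin (log-odds) scale
shap = predict(booster, newdata = X, predcontrib = TRUE)
head(shap)
```

The row sums of `shap` should equal `predict(booster, X, outputmargin = TRUE)`, which gives a quick check that the feature columns line up with what the booster was trained on.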