library(tidyverse)
library(tidymodels)
library(bonsai)
set.seed(1234)
12 LightGBM for regression
12.1 The basics
Reminder from before: LightGBM is another tree-based method, similar to XGBoost but different in ways that make it computationally more efficient. Where XGBoost and random forests grow their trees level by level, LightGBM grows leaf-wise. Think of it like this – XGBoost uses lots of short trees with branches. It comes to a fork, makes a decision that reduces the error the last tree left behind, and grows a new level of branches. LightGBM, on the other hand, splits whichever single leaf reduces error the most, which can mean lots of little splits instead of big ones. That can lead to overfitting on small datasets, but it also means it's much faster than XGBoost.
Did you catch that last part? LightGBM is prone to overfitting on small datasets. Our dataset is small.
LightGBM also uses histograms of the data to make choices, where XGBoost is computationally optimizing those choices. Roughly translated – LightGBM is looking at the fat part of a normal distribution to make choices, where XGBoost is tuning parameters to find the optimal path forward. It's another reason why LightGBM is faster, but also less reliable with small datasets.
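To make the histogram idea concrete, here's a base-R sketch of binning a continuous feature into a fixed number of buckets. This illustrates the concept only – it is not LightGBM's actual implementation – and the `yards` vector is made up:

```r
# A made-up continuous feature: 1,000 receiving-yard totals
set.seed(1234)
yards <- rnorm(1000, mean = 800, sd = 200)

# Discretize into 10 equal-width bins. LightGBM builds up to 255
# bins per feature by default, but the idea is the same.
bins <- cut(yards, breaks = 10)

# A histogram-based learner only scans bin edges as candidate
# split points -- about 10 candidates instead of about 1,000.
table(bins)
```

Scanning a handful of bin edges instead of every unique value is where the speed comes from; collapsing values into bins is also where precision – and reliability on small data – gets lost.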
What changes from before to now? Very little. Just the output – we’re predicting a number this time, not a category.
Let’s implement a LightGBM model. We start with libraries.
We’ll use the same data – wide receivers with college stats, draft information and fantasy points.
wrdraftedstats <- read_csv("https://mattwaite.github.io/sportsdatafiles/wrdraftedstats20132022.csv")
We thin down the inputs.
wrselected <- wrdraftedstats %>%
  select(
    name,
    year,
    college_team,
    nfl_team,
    overall,
    total_yards,
    total_touchdowns,
    FantPt
  ) %>%
  na.omit()
And split our data.
player_split <- initial_split(wrselected, prop = .8)

player_train <- training(player_split)
player_test <- testing(player_split)
Now a recipe.
player_recipe <-
  recipe(FantPt ~ ., data = player_train) %>%
  update_role(name, year, college_team, nfl_team, new_role = "ID")
summary(player_recipe)
# A tibble: 8 × 4
variable type role source
<chr> <chr> <chr> <chr>
1 name nominal ID original
2 year numeric ID original
3 college_team nominal ID original
4 nfl_team nominal ID original
5 overall numeric predictor original
6 total_yards numeric predictor original
7 total_touchdowns numeric predictor original
8 FantPt numeric outcome original
12.2 Implementing LightGBM
We’re going to implement XGBoost and LightGBM side by side. That will give us the chance to compare.
We start with model definition.
xg_mod <- boost_tree(
  trees = tune(),
  learn_rate = tune(),
  tree_depth = tune(),
  min_n = tune(),
  loss_reduction = tune(),
  sample_size = tune(),
  mtry = tune()
) %>%
  set_mode("regression") %>%
  set_engine("xgboost")
lightgbm_mod <-
  boost_tree() %>%
  set_engine("lightgbm") %>%
  set_mode(mode = "regression")
Now we create workflows.
xg_workflow <-
  workflow() %>%
  add_model(xg_mod) %>%
  add_recipe(player_recipe)

lightgbm_workflow <-
  workflow() %>%
  add_model(lightgbm_mod) %>%
  add_recipe(player_recipe)
We’ll tune the XGBoost model.
xgb_grid <- grid_latin_hypercube(
  trees(),
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), player_train),
  learn_rate()
)
player_folds <- vfold_cv(player_train)
xgb_res <- tune_grid(
  xg_workflow,
  resamples = player_folds,
  grid = xgb_grid,
  control = control_grid(save_pred = TRUE)
)
best_rmse <- select_best(xgb_res, metric = "rmse")

final_xgb <- finalize_workflow(
  xg_workflow,
  best_rmse
)
Now we make fits.
xg_fit <-
  final_xgb %>%
  fit(data = player_train)

lightgbm_fit <-
  lightgbm_workflow %>%
  fit(data = player_train)
With the fits in hand, we can bind the predictions to the data.
xgpredict <-
  xg_fit %>%
  predict(new_data = player_train) %>%
  bind_cols(player_train)

lightgbmpredict <-
  lightgbm_fit %>%
  predict(new_data = player_train) %>%
  bind_cols(player_train)
For your assignment: How do these two compare?
metrics(xgpredict, FantPt, .pred)
# A tibble: 3 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 rmse standard 38.3
2 rsq standard 0.449
3 mae standard 29.2
metrics(lightgbmpredict, FantPt, .pred)
# A tibble: 3 × 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 rmse standard 31.8
2 rsq standard 0.638
3 mae standard 24.4
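If you're wondering what `metrics()` is actually computing, the three numbers come from simple formulas. A base-R sketch with made-up truth and prediction vectors:

```r
truth <- c(100, 150, 200, 250, 300) # made-up FantPt values
pred  <- c(110, 140, 210, 240, 320) # made-up predictions

rmse <- sqrt(mean((truth - pred)^2)) # root mean squared error
rsq  <- cor(truth, pred)^2           # R-squared
mae  <- mean(abs(truth - pred))      # mean absolute error

c(rmse = rmse, rsq = rsq, mae = mae)
```

Lower rmse and mae are better; higher rsq is better. By all three, LightGBM looks stronger here – on the training data, at least.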
For your assignment: How do these models fare in testing?
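The pattern mirrors the training-set steps above: predict on `player_test` instead of `player_train`, then run `metrics()` on the result. A sketch using the fits already created – the `*testpredict` names are my own:

```r
xgtestpredict <- xg_fit %>%
  predict(new_data = player_test) %>%
  bind_cols(player_test)

lightgbmtestpredict <- lightgbm_fit %>%
  predict(new_data = player_test) %>%
  bind_cols(player_test)

metrics(xgtestpredict, FantPt, .pred)
metrics(lightgbmtestpredict, FantPt, .pred)
```

Expect both models to look worse on `player_test` than on `player_train`; the interesting question is which one falls off more.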