Machine Learning
By Tainá Carreira da Rocha
24/8/2022
The final report of Machine Learning course from curso-r company is a prediction about who will purchase at Google Virtual Store in the next month.
Packages
library(tidymodels)
## ── Attaching packages ────────────────────────────────────── tidymodels 1.0.0 ──
## ✔ broom 1.0.0 ✔ recipes 1.0.1
## ✔ dials 1.0.0 ✔ rsample 1.1.0
## ✔ dplyr 1.0.9 ✔ tibble 3.1.8
## ✔ ggplot2 3.3.6 ✔ tidyr 1.2.0
## ✔ infer 1.0.2 ✔ tune 1.0.0
## ✔ modeldata 1.0.0 ✔ workflows 1.0.0
## ✔ parsnip 1.0.1 ✔ workflowsets 1.0.0
## ✔ purrr 0.3.4 ✔ yardstick 1.0.0
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
## ✖ recipes::step() masks stats::step()
## • Learn how to get started at https://www.tidymodels.org/start/
library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✔ readr 2.1.2 ✔ forcats 0.5.1
## ✔ stringr 1.4.0
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ readr::col_factor() masks scales::col_factor()
## ✖ purrr::discard() masks scales::discard()
## ✖ dplyr::filter() masks stats::filter()
## ✖ stringr::fixed() masks recipes::fixed()
## ✖ dplyr::lag() masks stats::lag()
## ✖ readr::spec() masks yardstick::spec()
library(rpart)
##
## Attaching package: 'rpart'
## The following object is masked from 'package:dials':
##
## prune
library(rpart.plot)
library(pROC)
## Type 'citation("pROC")' for a citation.
##
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
##
## cov, smooth, var
library(vip)
##
## Attaching package: 'vip'
## The following object is masked from 'package:utils':
##
## vi
Read the data
ga = readr::read_csv("data/ga_train.csv")
## Rows: 1061278 Columns: 38
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (15): fullVisitorId, last_channel_grouping, last_browser, last_deviceCa...
## dbl (22): last_ses_from_the_period_end, interval_dates, unique_date_num, ma...
## date (1): month
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
glimpse(ga)
## Rows: 1,061,278
## Columns: 38
## $ month <date> 2016-09-01, 2016-09-01, 2016-09-01, 2016…
## $ fullVisitorId <chr> "000005103959234087", "000011415654313568…
## $ last_channel_grouping <chr> "Organic Search", "Social", "Social", "So…
## $ last_ses_from_the_period_end <dbl> 11, 24, 23, 12, 5, 15, 7, 17, 13, 6, 17, …
## $ interval_dates <dbl> 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,…
## $ unique_date_num <dbl> 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1,…
## $ max_visit_num <dbl> 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 5,…
## $ last_browser <chr> "Chrome", "Safari", "Opera Mini", "Chrome…
## $ last_deviceCategory <chr> "mobile", "desktop", "mobile", "desktop",…
## $ last_continent <chr> "Americas", "Asia", "Africa", "Asia", "Eu…
## $ last_operatingSystem <chr> "Android", "Macintosh", "(not set)", "Win…
## $ last_subContinent <chr> "Northern America", "Western Asia", "Nort…
## $ last_country <chr> "United States", "Turkey", "Sudan", "Phil…
## $ last_region <chr> "not available in demo dataset", "Istanbu…
## $ last_metro <chr> "not available in demo dataset", "(not se…
## $ last_city <chr> "not available in demo dataset", "Istanbu…
## $ last_networkDomain <chr> "comcast.net", "ttnet.com.tr", "opera-min…
## $ last_source <chr> "google", "youtube.com", "youtube.com", "…
## $ last_medium <chr> "organic", "referral", "referral", "refer…
## $ prop_isMobile <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
## $ prop_isTrueDirect <dbl> 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0…
## $ sum_hits <dbl> 10, 1, 1, 1, 1, 2, 2, 3, 2, 1, 2, 4, 46, …
## $ mean_hits <dbl> 10.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0, …
## $ min_hits <dbl> 10, 1, 1, 1, 1, 2, 1, 3, 2, 1, 2, 4, 46, …
## $ max_hits <dbl> 10, 1, 1, 1, 1, 2, 1, 3, 2, 1, 2, 4, 46, …
## $ median_hits <dbl> 10.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0, …
## $ sd_hits <dbl> NA, NA, NA, NA, NA, NA, 0.00000, NA, NA, …
## $ sum_pageviews <dbl> 8, 1, 1, 1, 1, 2, 2, 3, 2, 1, 2, 3, 31, 2…
## $ mean_pageviews <dbl> 8.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 2…
## $ min_pageviews <dbl> 8, 1, 1, 1, 1, 2, 1, 3, 2, 1, 2, 3, 31, 2…
## $ max_pageviews <dbl> 8, 1, 1, 1, 1, 2, 1, 3, 2, 1, 2, 3, 31, 2…
## $ median_pageviews <dbl> 8.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0, 2…
## $ sd_pageviews <dbl> NA, NA, NA, NA, NA, NA, 0.00000, NA, NA, …
## $ bounce_sessions <dbl> 0, 1, 1, 1, 1, 0, 2, 0, 0, 1, 0, 0, 0, 0,…
## $ session_cnt <dbl> 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1,…
## $ totalTransactionRevenue <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 97…
## $ transactions <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1,…
## $ comprou <chr> "não", "não", "não", "não", "não", "não",…
ga |>
count(comprou)
## # A tibble: 2 × 2
## comprou n
## <chr> <int>
## 1 não 1058330
## 2 sim 2948
Train and test data
ga_initial_split = make_splits(
x = list(
analysis = which(!as.character(ga$month) %in% c("2018-01-01", "2018-02-01")),
assessment = which(as.character(ga$month) %in% c("2018-01-01", "2018-02-01"))
),
data = ga
)
ga_train = training(ga_initial_split)
ga_valid = testing(ga_initial_split)
Resample
ga_resamples = sliding_period(ga_train, index = month, period = "month",
lookback = 5, step=2)
Exploratory analysis
Smikr
skimr::skim(ga_train)
## Warning in inline_hist(min_pageviews, 5): Variable contains Inf or -Inf value(s)
## that were converted to NA.
## Warning in inline_hist(max_pageviews, 5): Variable contains Inf or -Inf value(s)
## that were converted to NA.
Table: Table 1: Data summary
Name | ga_train |
Number of rows | 930624 |
Number of columns | 38 |
_______________________ | |
Column type frequency: | |
character | 15 |
Date | 1 |
numeric | 22 |
________________________ | |
Group variables | None |
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
fullVisitorId | 0 | 1 | 14 | 20 | 0 | 885756 | 0 |
last_channel_grouping | 0 | 1 | 6 | 14 | 0 | 8 | 0 |
last_browser | 0 | 1 | 1 | 26 | 0 | 64 | 0 |
last_deviceCategory | 0 | 1 | 6 | 7 | 0 | 3 | 0 |
last_continent | 0 | 1 | 4 | 9 | 0 | 6 | 0 |
last_operatingSystem | 0 | 1 | 3 | 16 | 0 | 22 | 0 |
last_subContinent | 0 | 1 | 9 | 18 | 0 | 23 | 0 |
last_country | 0 | 1 | 4 | 24 | 0 | 226 | 0 |
last_region | 0 | 1 | 4 | 33 | 0 | 402 | 0 |
last_metro | 0 | 1 | 6 | 55 | 0 | 99 | 0 |
last_city | 0 | 1 | 3 | 33 | 0 | 745 | 0 |
last_networkDomain | 0 | 1 | 2 | 64 | 0 | 32301 | 0 |
last_source | 0 | 1 | 3 | 31 | 0 | 293 | 0 |
last_medium | 0 | 1 | 3 | 9 | 0 | 7 | 0 |
comprou | 0 | 1 | 3 | 3 | 0 | 2 | 0 |
Variable type: Date
skim_variable | n_missing | complete_rate | min | max | median | n_unique |
---|---|---|---|---|---|---|
month | 0 | 1 | 2016-09-01 | 2017-12-01 | 2017-04-01 | 16 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
last_ses_from_the_period_end | 0 | 1.00 | 13.70 | 7.84 | 1 | 7.00 | 14.00 | 20.00 | 2.700000e+01 | ▇▆▆▆▇ |
interval_dates | 0 | 1.00 | 0.47 | 2.34 | 0 | 0.00 | 0.00 | 0.00 | 2.600000e+01 | ▇▁▁▁▁ |
unique_date_num | 0 | 1.00 | 1.11 | 0.52 | 1 | 1.00 | 1.00 | 1.00 | 2.600000e+01 | ▇▁▁▁▁ |
max_visit_num | 0 | 1.00 | 1.49 | 3.41 | 1 | 1.00 | 1.00 | 1.00 | 4.080000e+02 | ▇▁▁▁▁ |
prop_isMobile | 0 | 1.00 | 1.00 | 0.00 | 1 | 1.00 | 1.00 | 1.00 | 1.000000e+00 | ▁▁▇▁▁ |
prop_isTrueDirect | 0 | 1.00 | 0.23 | 0.40 | 0 | 0.00 | 0.00 | 0.50 | 1.000000e+00 | ▇▁▁▁▂ |
sum_hits | 0 | 1.00 | 5.42 | 13.45 | 1 | 1.00 | 2.00 | 4.00 | 1.541000e+03 | ▇▁▁▁▁ |
mean_hits | 0 | 1.00 | 4.04 | 7.53 | 1 | 1.00 | 1.50 | 4.00 | 5.000000e+02 | ▇▁▁▁▁ |
min_hits | 0 | 1.00 | 3.58 | 6.97 | 1 | 1.00 | 1.00 | 3.00 | 5.000000e+02 | ▇▁▁▁▁ |
max_hits | 0 | 1.00 | 4.67 | 9.70 | 1 | 1.00 | 2.00 | 4.00 | 5.000000e+02 | ▇▁▁▁▁ |
median_hits | 0 | 1.00 | 3.96 | 7.45 | 1 | 1.00 | 1.00 | 4.00 | 5.000000e+02 | ▇▁▁▁▁ |
sd_hits | 819003 | 0.12 | 5.29 | 9.61 | 0 | 0.55 | 2.08 | 6.36 | 3.507200e+02 | ▇▁▁▁▁ |
sum_pageviews | 0 | 1.00 | 4.54 | 10.03 | 0 | 1.00 | 2.00 | 4.00 | 1.445000e+03 | ▇▁▁▁▁ |
mean_pageviews | 12 | 1.00 | 3.43 | 5.50 | 1 | 1.00 | 1.00 | 3.00 | 4.310000e+02 | ▇▁▁▁▁ |
min_pageviews | 0 | 1.00 | Inf | NaN | 1 | 1.00 | 1.00 | 3.00 | Inf | ▇▁▁▁▁ |
max_pageviews | 0 | 1.00 | -Inf | NaN | -Inf | 1.00 | 1.00 | 4.00 | 4.830000e+02 | ▇▁▁▁▁ |
median_pageviews | 12 | 1.00 | 3.37 | 5.45 | 1 | 1.00 | 1.00 | 3.00 | 4.310000e+02 | ▇▁▁▁▁ |
sd_pageviews | 819051 | 0.12 | 4.01 | 6.78 | 0 | 0.50 | 1.41 | 4.95 | 2.432400e+02 | ▇▁▁▁▁ |
bounce_sessions | 0 | 1.00 | 0.62 | 0.73 | 0 | 0.00 | 1.00 | 1.00 | 6.300000e+01 | ▇▁▁▁▁ |
session_cnt | 0 | 1.00 | 1.21 | 0.86 | 1 | 1.00 | 1.00 | 1.00 | 8.100000e+01 | ▇▁▁▁▁ |
totalTransactionRevenue | 0 | 1.00 | 2129194.73 | 119342544.05 | 0 | 0.00 | 0.00 | 0.00 | 9.277596e+10 | ▇▁▁▁▁ |
transactions | 0 | 1.00 | 0.01 | 0.14 | 0 | 0.00 | 0.00 | 0.00 | 1.500000e+01 | ▇▁▁▁▁ |
Correlation
ga_train |>
select(where(is.numeric)) |>
cor(use = "pairwise.complete.obs") |>
corrplot::corrplot()
## Warning in cor(select(ga_train, where(is.numeric)), use =
## "pairwise.complete.obs"): the standard deviation is zero
Decision tree
Data prep
ga_dt_recipe = recipe(comprou ~ ., data = ga_train) |>
update_role(month, new_role = "date") |>
update_role(fullVisitorId, new_role = "id") |>
step_rm(skip = TRUE,
last_region,
last_metro,
last_city,
last_networkDomain,
last_source,
last_browser
) |>
themis::step_downsample(comprou, under_ratio = 10) |>
step_novel(all_nominal_predictors()) |>
step_zv(all_predictors()) |>
step_other(
last_subContinent,
last_operatingSystem
)
Model
ga_dt_model = decision_tree(
cost_complexity = tune(),
tree_depth = tune(),
min_n = tune()
) |>
set_mode("classification") |>
set_engine("rpart")
Workflow
ga_dt_wf = workflow() |>
add_model(ga_dt_model) |>
add_recipe(ga_dt_recipe)
Tune
grid_dt = grid_random(
cost_complexity(c(-9, -1)),
tree_depth(range = c(5, 15)),
min_n(range = c(20, 40)),
size = 3
)
ga_dt_tune_grid = tune_grid(
ga_dt_wf,
resamples = ga_resamples,
grid = grid_dt,
metrics = metric_set(roc_auc),
control = control_grid(verbose = TRUE)
)
## i Slice1: preprocessor 1/1
## ✓ Slice1: preprocessor 1/1
## i Slice1: preprocessor 1/1, model 1/3
## ✓ Slice1: preprocessor 1/1, model 1/3
## i Slice1: preprocessor 1/1, model 1/3 (predictions)
## i Slice1: preprocessor 1/1, model 2/3
## ✓ Slice1: preprocessor 1/1, model 2/3
## i Slice1: preprocessor 1/1, model 2/3 (predictions)
## i Slice1: preprocessor 1/1, model 3/3
## ✓ Slice1: preprocessor 1/1, model 3/3
## i Slice1: preprocessor 1/1, model 3/3 (predictions)
## i Slice2: preprocessor 1/1
## ✓ Slice2: preprocessor 1/1
## i Slice2: preprocessor 1/1, model 1/3
## ✓ Slice2: preprocessor 1/1, model 1/3
## i Slice2: preprocessor 1/1, model 1/3 (predictions)
## i Slice2: preprocessor 1/1, model 2/3
## ✓ Slice2: preprocessor 1/1, model 2/3
## i Slice2: preprocessor 1/1, model 2/3 (predictions)
## i Slice2: preprocessor 1/1, model 3/3
## ✓ Slice2: preprocessor 1/1, model 3/3
## i Slice2: preprocessor 1/1, model 3/3 (predictions)
## i Slice3: preprocessor 1/1
## ✓ Slice3: preprocessor 1/1
## i Slice3: preprocessor 1/1, model 1/3
## ✓ Slice3: preprocessor 1/1, model 1/3
## i Slice3: preprocessor 1/1, model 1/3 (predictions)
## i Slice3: preprocessor 1/1, model 2/3
## ✓ Slice3: preprocessor 1/1, model 2/3
## i Slice3: preprocessor 1/1, model 2/3 (predictions)
## i Slice3: preprocessor 1/1, model 3/3
## ✓ Slice3: preprocessor 1/1, model 3/3
## i Slice3: preprocessor 1/1, model 3/3 (predictions)
## i Slice4: preprocessor 1/1
## ✓ Slice4: preprocessor 1/1
## i Slice4: preprocessor 1/1, model 1/3
## ✓ Slice4: preprocessor 1/1, model 1/3
## i Slice4: preprocessor 1/1, model 1/3 (predictions)
## i Slice4: preprocessor 1/1, model 2/3
## ✓ Slice4: preprocessor 1/1, model 2/3
## i Slice4: preprocessor 1/1, model 2/3 (predictions)
## i Slice4: preprocessor 1/1, model 3/3
## ✓ Slice4: preprocessor 1/1, model 3/3
## i Slice4: preprocessor 1/1, model 3/3 (predictions)
## i Slice5: preprocessor 1/1
## ✓ Slice5: preprocessor 1/1
## i Slice5: preprocessor 1/1, model 1/3
## ✓ Slice5: preprocessor 1/1, model 1/3
## i Slice5: preprocessor 1/1, model 1/3 (predictions)
## i Slice5: preprocessor 1/1, model 2/3
## ✓ Slice5: preprocessor 1/1, model 2/3
## i Slice5: preprocessor 1/1, model 2/3 (predictions)
## i Slice5: preprocessor 1/1, model 3/3
## ✓ Slice5: preprocessor 1/1, model 3/3
## i Slice5: preprocessor 1/1, model 3/3 (predictions)
autoplot(ga_dt_tune_grid)
collect_metrics(ga_dt_tune_grid)
## # A tibble: 3 × 9
## cost_complexity tree_depth min_n .metric .estima…¹ mean n std_err .config
## <dbl> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 0.000000372 13 23 roc_auc binary 0.907 5 0.00320 Prepro…
## 2 0.00000360 11 34 roc_auc binary 0.915 5 0.00340 Prepro…
## 3 0.00000132 6 31 roc_auc binary 0.841 5 0.0187 Prepro…
## # … with abbreviated variable name ¹.estimator
Model performance
ga_dt_best_params = select_best(ga_dt_tune_grid, "roc_auc")
ga_dt_wf = ga_dt_wf |> finalize_workflow(ga_dt_best_params)
ga_dt_last_fit = last_fit(ga_dt_wf, ga_initial_split)
ga_test_preds = collect_predictions(ga_dt_last_fit) |> mutate(modelo = "dt")
ROC AUC
ga_test_preds |>
group_by(modelo) |>
roc_curve(comprou, `.pred_não`) |>
autoplot()
Variable importance
ga_dt_last_fit_model = ga_dt_last_fit$.workflow[[1]]$fit$fit
vip(ga_dt_last_fit_model)
rpart.plot(ga_dt_last_fit_model$fit, faclen = 2)
## Warning: Cannot retrieve the data used to build the model (so cannot determine roundint and is.binary for the variables).
## To silence this warning:
## Call rpart.plot with roundint=FALSE,
## or rebuild the rpart model with model=TRUE.
## Warning in abbreviate(names, minlen): abbreviate used with non-ASCII chars
## Warning in abbreviate(names, minlen): abbreviate used with non-ASCII chars
## Warning in abbreviate(names, minlen): abbreviate used with non-ASCII chars
## Warning: labs do not fit even at cex 0.15, there may be some overplotting
Final Model
ga_final_dt_model = ga_dt_wf |>
fit(ga)
Submission file
ga_test = readr::read_csv("data/ga_test.csv")
## Rows: 133534 Columns: 37
## ── Column specification ────────────────────────────────────────────────────────
## Delimiter: ","
## chr (14): fullVisitorId, last_channel_grouping, last_browser, last_deviceCa...
## dbl (22): last_ses_from_the_period_end, interval_dates, unique_date_num, ma...
## date (1): month
##
## ℹ Use `spec()` to retrieve the full column specification for this data.
## ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
ga_submission = ga_test %>%
mutate(
target = predict(ga_final_dt_model, new_data = . , type = "prob")$.pred_sim
)
ga_submission |>
mutate(fullVisitorId = paste(fullVisitorId, month, sep = "-")) |>
select(fullVisitorId, comprou = target) |>
write_csv("ga_submission.csv")
License
Content is available under the Creative CommonsAttribution-ShareAlike (CC BY-SA) license. You can share and adapt it, but you must attribute the credits to the authors, adding a link to the original content. , and your sharing must also have this same type of license.
More info: Creative Commons
- Posted on:
- 24/8/2022
- Length:
- 11 minute read, 2171 words
- See Also: