library(tidyverse)
library(modelr)
library(gapminder)
library(broom)
In this challenge you’re going to explore three powerful ideas that help you to work with large numbers of models with ease:
Using many simple models to better understand complex datasets.
Using list-columns to store arbitrary data structures in a data frame. For example, this will allow you to have a column that contains linear models.
Using the broom package, by David Robinson, to turn models into tidy data.
To motivate the power of many simple models, we’re going to look into the gapminder data. This data was popularised by Hans Rosling, a Swedish doctor and statistician. The gapminder data summarises the progression of countries over time, looking at statistics like life expectancy and GDP. The data is easy to access in R, thanks to Jenny Bryan who created the gapminder package.
gapminder
## # A tibble: 1,704 × 6
## country continent year lifeExp pop gdpPercap
## <fct> <fct> <int> <dbl> <int> <dbl>
## 1 Afghanistan Asia 1952 28.8 8425333 779.
## 2 Afghanistan Asia 1957 30.3 9240934 821.
## 3 Afghanistan Asia 1962 32.0 10267083 853.
## 4 Afghanistan Asia 1967 34.0 11537966 836.
## 5 Afghanistan Asia 1972 36.1 13079460 740.
## 6 Afghanistan Asia 1977 38.4 14880372 786.
## 7 Afghanistan Asia 1982 39.9 12881816 978.
## 8 Afghanistan Asia 1987 40.8 13867957 852.
## 9 Afghanistan Asia 1992 41.7 16317921 649.
## 10 Afghanistan Asia 1997 41.8 22227415 635.
## # … with 1,694 more rows
In this case study, we’re going to focus on just three variables to answer the question:
How does life expectancy change over time for each country?
A good place to start is with a plot:
ggplot(gapminder, aes(year, lifeExp)) +
  geom_line(aes(group = country), alpha = 1/3) +
  theme_classic()
# or
ggplot(gapminder, aes(year, lifeExp, color = country)) +
  geom_line(alpha = 1/2, show.legend = FALSE) +
  scale_colour_manual(values = country_colors) +
  theme_classic()
Notice the group aesthetic. Many geoms, like geom_line() and geom_smooth(), use a single geometric object to display multiple rows of data. For these geoms, you can set the group aesthetic to a categorical variable to draw multiple objects. ggplot2 will draw a separate object for each unique value of the grouping variable.
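As a rough check of what the group aesthetic does, the sketch below builds the same line plot with and without group = country and counts how many separate paths ggplot2 would draw. The object names (p_grouped, p_single, n_grouped, n_single) are made up for this illustration; layer_data() is the ggplot2 helper that returns the data computed for a layer.

```r
library(ggplot2)
library(gapminder)

# With group = country, geom_line() draws one path per country.
p_grouped <- ggplot(gapminder, aes(year, lifeExp)) +
  geom_line(aes(group = country))

# Without a grouping variable, all rows are joined into a single path.
p_single <- ggplot(gapminder, aes(year, lifeExp)) +
  geom_line()

# layer_data() exposes the group that each row was assigned to.
n_grouped <- length(unique(layer_data(p_grouped)$group))
n_single <- length(unique(layer_data(p_single)$group))

n_grouped
n_single
```

Mapping colour to country, as in the second plot above, has the same grouping effect, because ggplot2 groups by any discrete variable mapped in the plot.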
This is a small dataset: it has only ~1,700 observations, and we are looking at just 3 variables. But it’s still hard to see what’s going on!
Overall, it looks like life expectancy has been steadily improving. However, if you look closely, you might notice some countries that don’t follow this pattern. How can we make those countries easier to see?
There’s a strong signal (overall linear growth) that makes it hard to see subtler trends. We’ll tease these factors apart by fitting a model with a linear trend. The model captures steady growth over time, and the residuals will show what’s left. You already know how to do that if we had a single country:
# Full data for a single country
it <- filter(gapminder, country == "Italy")
it %>%
  ggplot(aes(year, lifeExp)) +
  geom_line() +
  ggtitle("Full data")
# Linear trend: predictions from a simple linear model
it_mod <- lm(lifeExp ~ year, data = it)
it <- add_predictions(it, it_mod)
it %>%
  ggplot(aes(year, pred)) +
  geom_line() +
  ggtitle("Linear trend")
# Remaining pattern: residuals from the same model
it <- add_residuals(it, it_mod)
it %>%
  ggplot(aes(year, resid)) +
  geom_hline(yintercept = 0, colour = "white", linewidth = 3) +
  geom_line() +
  ggtitle("Remaining pattern")
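The three plots rest on a simple identity: for every row, the observed life expectancy equals the linear trend plus the residual. A minimal sketch that checks this for Italy follows; the name max_gap is introduced here only for the illustration.

```r
library(dplyr)
library(modelr)
library(gapminder)

it <- filter(gapminder, country == "Italy")
it_mod <- lm(lifeExp ~ year, data = it)

it <- it %>%
  add_predictions(it_mod) %>%  # adds the column `pred`
  add_residuals(it_mod)        # adds the column `resid`

# Observed value = linear trend + remaining pattern, row by row.
# The largest absolute gap should be zero up to floating-point error.
max_gap <- max(abs(it$lifeExp - (it$pred + it$resid)))
max_gap
```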
How can we easily fit that model to every country? Extract out the common code with a function and repeat using a map function from purrr. To do that, we need a new data structure: the nested data frame. To create a nested data frame we start with a grouped data frame, and nest it:
by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest()
by_country
## # A tibble: 142 × 3
## # Groups: country, continent [142]
## country continent data
## <fct> <fct> <list>
## 1 Afghanistan Asia <tibble [12 × 4]>
## 2 Albania Europe <tibble [12 × 4]>
## 3 Algeria Africa <tibble [12 × 4]>
## 4 Angola Africa <tibble [12 × 4]>
## 5 Argentina Americas <tibble [12 × 4]>
## 6 Australia Oceania <tibble [12 × 4]>
## 7 Austria Europe <tibble [12 × 4]>
## 8 Bahrain Asia <tibble [12 × 4]>
## 9 Bangladesh Asia <tibble [12 × 4]>
## 10 Belgium Europe <tibble [12 × 4]>
## # … with 132 more rows
by_country$data[[1]]
## # A tibble: 12 × 4
## year lifeExp pop gdpPercap
## <int> <dbl> <int> <dbl>
## 1 1952 28.8 8425333 779.
## 2 1957 30.3 9240934 821.
## 3 1962 32.0 10267083 853.
## 4 1967 34.0 11537966 836.
## 5 1972 36.1 13079460 740.
## 6 1977 38.4 14880372 786.
## 7 1982 39.9 12881816 978.
## 8 1987 40.8 13867957 852.
## 9 1992 41.7 16317921 649.
## 10 1997 41.8 22227415 635.
## 11 2002 42.1 25268405 727.
## 12 2007 43.8 31889923 975.
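To see the mechanics of nest() in isolation, here is a small sketch on a made-up three-row tibble (toy, g and x are hypothetical names, not part of the gapminder data). Each group collapses to a single row, and the remaining columns move into a list-column of tibbles called data.

```r
library(dplyr)
library(tidyr)

# Hypothetical toy data: two groups, three rows in total.
toy <- tibble(
  g = c("a", "a", "b"),
  x = c(1, 2, 3)
)

nested <- toy %>%
  group_by(g) %>%
  nest()

nrow(nested)            # one row per group
is.list(nested$data)    # `data` is a list-column
nested$data[[1]]        # the rows that belonged to group "a"
```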
Now that we have our nested data frame, we’re in a good position to fit some models. We have a model-fitting function:
country_model <- function(df) {
  lm(lifeExp ~ year, data = df)
}
And we want to apply it to every data frame. The data frames are in a list, so we can use purrr::map() to apply country_model to each element. We’re going to create a new variable in the by_country data frame with dplyr::mutate():
by_country <- by_country %>%
  mutate(model = map(data, country_model))
by_country
## # A tibble: 142 × 4
## # Groups: country, continent [142]
## country continent data model
## <fct> <fct> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 4]> <lm>
## 2 Albania Europe <tibble [12 × 4]> <lm>
## 3 Algeria Africa <tibble [12 × 4]> <lm>
## 4 Angola Africa <tibble [12 × 4]> <lm>
## 5 Argentina Americas <tibble [12 × 4]> <lm>
## 6 Australia Oceania <tibble [12 × 4]> <lm>
## 7 Austria Europe <tibble [12 × 4]> <lm>
## 8 Bahrain Asia <tibble [12 × 4]> <lm>
## 9 Bangladesh Asia <tibble [12 × 4]> <lm>
## 10 Belgium Europe <tibble [12 × 4]> <lm>
## # … with 132 more rows
by_country$model[[1]]
##
## Call:
## lm(formula = lifeExp ~ year, data = df)
##
## Coefficients:
## (Intercept) year
## -507.5343 0.2753
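Once the models live in a list-column, the same map pattern can pull a single number out of each one. The sketch below, which is an addition and not part of the original walkthrough, uses purrr::map_dbl() to extract the year coefficient (the estimated gain in life expectancy per year) for every country; the names slopes and slope are chosen here for illustration.

```r
library(dplyr)
library(tidyr)
library(purrr)
library(gapminder)

country_model <- function(df) {
  lm(lifeExp ~ year, data = df)
}

slopes <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  mutate(
    model = map(data, country_model),
    # map_dbl() returns a plain numeric vector instead of a list
    slope = map_dbl(model, ~ coef(.x)[["year"]])
  ) %>%
  ungroup() %>%
  select(country, continent, slope)

slopes
```

The slope for Afghanistan should agree with the year coefficient printed above, and each slope sits in the same row as the country it was estimated from.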
This has a big advantage: because all the related objects are stored together, you don’t need to manually keep them in sync when you filter or arrange. The semantics of the data frame takes care of that for you:
by_country %>%
  filter(continent == "Europe")
## # A tibble: 30 × 4
## # Groups: country, continent [30]
## country continent data model
## <fct> <fct> <list> <list>
## 1 Albania Europe <tibble [12 × 4]> <lm>
## 2 Austria Europe <tibble [12 × 4]> <lm>
## 3 Belgium Europe <tibble [12 × 4]> <lm>
## 4 Bosnia and Herzegovina Europe <tibble [12 × 4]> <lm>
## 5 Bulgaria Europe <tibble [12 × 4]> <lm>
## 6 Croatia Europe <tibble [12 × 4]> <lm>
## 7 Czech Republic Europe <tibble [12 × 4]> <lm>
## 8 Denmark Europe <tibble [12 × 4]> <lm>
## 9 Finland Europe <tibble [12 × 4]> <lm>
## 10 France Europe <tibble [12 × 4]> <lm>
## # … with 20 more rows
by_country %>%
  arrange(continent, country)
## # A tibble: 142 × 4
## # Groups: country, continent [142]
## country continent data model
## <fct> <fct> <list> <list>
## 1 Algeria Africa <tibble [12 × 4]> <lm>
## 2 Angola Africa <tibble [12 × 4]> <lm>
## 3 Benin Africa <tibble [12 × 4]> <lm>
## 4 Botswana Africa <tibble [12 × 4]> <lm>
## 5 Burkina Faso Africa <tibble [12 × 4]> <lm>
## 6 Burundi Africa <tibble [12 × 4]> <lm>
## 7 Cameroon Africa <tibble [12 × 4]> <lm>
## 8 Central African Republic Africa <tibble [12 × 4]> <lm>
## 9 Chad Africa <tibble [12 × 4]> <lm>
## 10 Comoros Africa <tibble [12 × 4]> <lm>
## # … with 132 more rows
Previously we computed the residuals of a single model with a single dataset. Now we have 142 data frames and 142 models. To compute the residuals, we need to call add_residuals() with each model-data pair using the purrr::map2() function (which iterates over two inputs in parallel, so it works for two-argument functions like add_residuals()):
# add residuals
by_country <- by_country %>%
  mutate(data = map2(data, model, add_residuals))
by_country
## # A tibble: 142 × 4
## # Groups: country, continent [142]
## country continent data model
## <fct> <fct> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 5]> <lm>
## 2 Albania Europe <tibble [12 × 5]> <lm>
## 3 Algeria Africa <tibble [12 × 5]> <lm>
## 4 Angola Africa <tibble [12 × 5]> <lm>
## 5 Argentina Americas <tibble [12 × 5]> <lm>
## 6 Australia Oceania <tibble [12 × 5]> <lm>
## 7 Austria Europe <tibble [12 × 5]> <lm>
## 8 Bahrain Asia <tibble [12 × 5]> <lm>
## 9 Bangladesh Asia <tibble [12 × 5]> <lm>
## 10 Belgium Europe <tibble [12 × 5]> <lm>
## # … with 132 more rows
by_country$data[[1]]
## # A tibble: 12 × 5
## year lifeExp pop gdpPercap resid
## <int> <dbl> <int> <dbl> <dbl>
## 1 1952 28.8 8425333 779. -1.11
## 2 1957 30.3 9240934 821. -0.952
## 3 1962 32.0 10267083 853. -0.664
## 4 1967 34.0 11537966 836. -0.0172
## 5 1972 36.1 13079460 740. 0.674
## 6 1977 38.4 14880372 786. 1.65
## 7 1982 39.9 12881816 978. 1.69
## 8 1987 40.8 13867957 852. 1.28
## 9 1992 41.7 16317921 649. 0.754
## 10 1997 41.8 22227415 635. -0.534
## 11 2002 42.1 25268405 727. -1.54
## 12 2007 43.8 31889923 975. -1.22
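If map2() is unfamiliar, it may help to see it next to the explicit loop it replaces. This sketch uses two made-up lists (xs and ys are hypothetical) and shows that map2() pairs up elements by position, which is exactly how each data frame was matched with its own model above.

```r
library(purrr)

# Hypothetical inputs: two lists of the same length.
xs <- list(1, 2, 3)
ys <- list(10, 20, 30)

# map2() calls the function with the i-th element of each list.
sums <- map2(xs, ys, function(x, y) x + y)

# The same computation written as a loop over positions.
sums_loop <- vector("list", length(xs))
for (i in seq_along(xs)) {
  sums_loop[[i]] <- xs[[i]] + ys[[i]]
}

identical(sums, sums_loop)
```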
But how can you plot a list of data frames? Previously we used nest() to turn a regular data frame into a nested data frame, and now we do the opposite with unnest():
# unnest residuals
resids <- by_country %>%
  unnest(data) %>%
  select(-model)
resids
## # A tibble: 1,704 × 7
## # Groups: country, continent [142]
## country continent year lifeExp pop gdpPercap resid
## <fct> <fct> <int> <dbl> <int> <dbl> <dbl>
## 1 Afghanistan Asia 1952 28.8 8425333 779. -1.11
## 2 Afghanistan Asia 1957 30.3 9240934 821. -0.952
## 3 Afghanistan Asia 1962 32.0 10267083 853. -0.664
## 4 Afghanistan Asia 1967 34.0 11537966 836. -0.0172
## 5 Afghanistan Asia 1972 36.1 13079460 740. 0.674
## 6 Afghanistan Asia 1977 38.4 14880372 786. 1.65
## 7 Afghanistan Asia 1982 39.9 12881816 978. 1.69
## 8 Afghanistan Asia 1987 40.8 13867957 852. 1.28
## 9 Afghanistan Asia 1992 41.7 16317921 649. 0.754
## 10 Afghanistan Asia 1997 41.8 22227415 635. -0.534
## # ℹ 1,694 more rows
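A quick way to convince yourself that unnest() really is the inverse of nest() is a round trip on the original data. The sketch below nests gapminder by country and continent, unnests it again, and compares the result with the starting point; round_trip is a name introduced only for this check.

```r
library(dplyr)
library(tidyr)
library(gapminder)

round_trip <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  unnest(data) %>%
  ungroup()

# Same number of rows and the same set of columns as the original.
dim(round_trip)
setequal(names(round_trip), names(gapminder))
```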
Now that we have a regular data frame, we can plot the residuals:
resids %>%
  ggplot(aes(year, resid)) +
  geom_line(aes(group = country), alpha = 1/3) +
  geom_smooth(se = FALSE) +
  theme_classic()
Facetting by continent is particularly revealing:
resids %>%
  ggplot(aes(year, resid)) +
  geom_line(aes(group = country), alpha = 1/3) +
  facet_wrap(~continent) +
  theme_classic()
There’s something interesting going on in Africa: we see some very large residuals, which suggests our model isn’t fitting so well there. To investigate further, we can add model-level summary statistics with the broom::glance() function:
# add statistics
by_country <- by_country %>%
  mutate(glance = map(model, broom::glance))
by_country
## # A tibble: 142 × 5
## # Groups: country, continent [142]
## country continent data model glance
## <fct> <fct> <list> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 2 Albania Europe <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 3 Algeria Africa <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 4 Angola Africa <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 5 Argentina Americas <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 6 Australia Oceania <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 7 Austria Europe <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 8 Bahrain Asia <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 9 Bangladesh Asia <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## 10 Belgium Europe <tibble [12 × 5]> <lm> <tibble [1 × 12]>
## # … with 132 more rows
by_country$glance[[1]]
## # A tibble: 1 × 12
## r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 0.948 0.942 1.22 181. 0.0000000984 1 -18.3 42.7 44.1
## # … with 3 more variables: deviance <dbl>, df.residual <int>, nobs <int>
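The r.squared column is the statistic used in the rest of this case study, so it is worth recalling what it measures: one minus the ratio of the residual sum of squares to the total sum of squares. The sketch below computes it by hand for a single country and compares it with the value from glance(). Rwanda is picked here only as an example, and rss, tss, r2_manual and r2_glance are names introduced for the illustration.

```r
library(dplyr)
library(broom)
library(gapminder)

rw <- filter(gapminder, country == "Rwanda")
rw_mod <- lm(lifeExp ~ year, data = rw)

# R^2 as reported by broom::glance()
r2_glance <- glance(rw_mod)$r.squared

# R^2 by hand: 1 - (residual sum of squares / total sum of squares)
rss <- sum(residuals(rw_mod)^2)
tss <- sum((rw$lifeExp - mean(rw$lifeExp))^2)
r2_manual <- 1 - rss / tss

c(glance = r2_glance, manual = r2_manual)
```

A value near 0 means the linear trend explains almost none of the variation in life expectancy for that country; a value near 1 means it explains almost all of it.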
# unnest statistics
(glance = unnest(by_country, glance))
## # A tibble: 142 × 16
## # Groups: country, continent [142]
## country continent data model r.squared adj.r.squared sigma statistic
## <fct> <fct> <list> <lis> <dbl> <dbl> <dbl> <dbl>
## 1 Afghanistan Asia <tibble … <lm> 0.948 0.942 1.22 181.
## 2 Albania Europe <tibble … <lm> 0.911 0.902 1.98 102.
## 3 Algeria Africa <tibble … <lm> 0.985 0.984 1.32 662.
## 4 Angola Africa <tibble … <lm> 0.888 0.877 1.41 79.1
## 5 Argentina Americas <tibble … <lm> 0.996 0.995 0.292 2246.
## 6 Australia Oceania <tibble … <lm> 0.980 0.978 0.621 481.
## 7 Austria Europe <tibble … <lm> 0.992 0.991 0.407 1261.
## 8 Bahrain Asia <tibble … <lm> 0.967 0.963 1.64 291.
## 9 Bangladesh Asia <tibble … <lm> 0.989 0.988 0.977 930.
## 10 Belgium Europe <tibble … <lm> 0.995 0.994 0.293 1822.
## # … with 132 more rows, and 8 more variables: p.value <dbl>, df <dbl>,
## # logLik <dbl>, AIC <dbl>, BIC <dbl>, deviance <dbl>, df.residual <int>,
## # nobs <int>
With this data frame in hand, we can start to look for models that don’t fit well:
glance %>%
  arrange(r.squared)
## # A tibble: 142 × 16
## # Groups: country, continent [142]
## country continent data model r.squared adj.r.squared sigma statistic p.value
## <fct> <fct> <lis> <lis> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Rwanda Africa <tib… <lm> 0.0172 -0.0811 6.56 0.175 0.685
## 2 Botswa… Africa <tib… <lm> 0.0340 -0.0626 6.11 0.352 0.566
## 3 Zimbab… Africa <tib… <lm> 0.0562 -0.0381 7.21 0.596 0.458
## 4 Zambia Africa <tib… <lm> 0.0598 -0.0342 4.53 0.636 0.444
## 5 Swazil… Africa <tib… <lm> 0.0682 -0.0250 6.64 0.732 0.412
## 6 Lesotho Africa <tib… <lm> 0.0849 -0.00666 5.93 0.927 0.358
## 7 Cote d… Africa <tib… <lm> 0.283 0.212 3.93 3.95 0.0748
## 8 South … Africa <tib… <lm> 0.312 0.244 4.74 4.54 0.0588
## 9 Uganda Africa <tib… <lm> 0.342 0.276 3.19 5.20 0.0457
## 10 Congo,… Africa <tib… <lm> 0.348 0.283 2.43 5.34 0.0434
## # … with 132 more rows, and 7 more variables: df <dbl>, logLik <dbl>,
## # AIC <dbl>, BIC <dbl>, deviance <dbl>, df.residual <int>, nobs <int>
glance %>%
  filter(continent == "Africa") %>%
  select(country, r.squared) %>%
  arrange(r.squared)
## # A tibble: 52 × 3
## # Groups: country, continent [52]
## continent country r.squared
## <fct> <fct> <dbl>
## 1 Africa Rwanda 0.0172
## 2 Africa Botswana 0.0340
## 3 Africa Zimbabwe 0.0562
## 4 Africa Zambia 0.0598
## 5 Africa Swaziland 0.0682
## 6 Africa Lesotho 0.0849
## 7 Africa Cote d'Ivoire 0.283
## 8 Africa South Africa 0.312
## 9 Africa Uganda 0.342
## 10 Africa Congo, Dem. Rep. 0.348
## # … with 42 more rows
glance %>%
  filter(continent == "Asia") %>%
  select(country, r.squared) %>%
  arrange(r.squared)
## # A tibble: 33 × 3
## # Groups: country, continent [33]
## continent country r.squared
## <fct> <fct> <dbl>
## 1 Asia Iraq 0.546
## 2 Asia Cambodia 0.639
## 3 Asia Korea, Dem. Rep. 0.703
## 4 Asia China 0.871
## 5 Asia Myanmar 0.879
## 6 Asia Lebanon 0.942
## 7 Asia Malaysia 0.947
## 8 Asia Afghanistan 0.948
## 9 Asia Sri Lanka 0.948
## 10 Asia Kuwait 0.952
## # … with 23 more rows
glance %>%
  filter(continent == "Americas") %>%
  select(country, r.squared) %>%
  arrange(r.squared)
## # A tibble: 25 × 3
## # Groups: country, continent [25]
## continent country r.squared
## <fct> <fct> <dbl>
## 1 Americas Trinidad and Tobago 0.798
## 2 Americas Jamaica 0.806
## 3 Americas Puerto Rico 0.908
## 4 Americas Cuba 0.924
## 5 Americas Venezuela 0.947
## 6 Americas Panama 0.951
## 7 Americas El Salvador 0.956
## 8 Americas Costa Rica 0.962
## 9 Americas Colombia 0.968
## 10 Americas Dominican Republic 0.971
## # … with 15 more rows
glance %>%
  filter(continent == "Europe") %>%
  select(country, r.squared) %>%
  arrange(r.squared)
## # A tibble: 30 × 3
## # Groups: country, continent [30]
## continent country r.squared
## <fct> <fct> <dbl>
## 1 Europe Bulgaria 0.547
## 2 Europe Slovak Republic 0.792
## 3 Europe Hungary 0.795
## 4 Europe Montenegro 0.802
## 5 Europe Romania 0.806
## 6 Europe Poland 0.840
## 7 Europe Serbia 0.879
## 8 Europe Bosnia and Herzegovina 0.896
## 9 Europe Albania 0.911
## 10 Europe Czech Republic 0.917
## # … with 20 more rows
The worst models all appear to be in Africa. Let’s double check that with a plot:
glance %>%
  ggplot(aes(continent, r.squared)) +
  geom_point(alpha = 1/3)
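The same comparison can be made numerically instead of by eye. The following sketch, an addition to the original analysis, refits the per-country models from scratch and summarises \(R^2\) by continent; by_continent, min_r2 and median_r2 are names chosen here for illustration.

```r
library(dplyr)
library(tidyr)
library(purrr)
library(broom)
library(gapminder)

by_continent <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  mutate(
    model = map(data, ~ lm(lifeExp ~ year, data = .x)),
    r.squared = map_dbl(model, ~ glance(.x)$r.squared)
  ) %>%
  ungroup() %>%
  # Regroup by continent only, then summarise the fit quality.
  group_by(continent) %>%
  summarise(
    n = n(),
    min_r2 = min(r.squared),
    median_r2 = median(r.squared)
  ) %>%
  arrange(min_r2)

by_continent
```

Sorting by the minimum puts the continent with the single worst-fitting model first, which complements the per-continent tables above.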
We could pull out the countries with particularly bad \(R^2\) and plot the data:
# countries where the linear model explains less than 25% of the variance
bad_fit <- filter(glance, r.squared < 0.25)
gapminder %>%
  semi_join(bad_fit, join_by(country)) %>%
  ggplot(aes(year, lifeExp, colour = country)) +
  geom_line() +
  theme_classic()
# a looser threshold: less than 50% of the variance explained
bad_fit <- filter(glance, r.squared < 0.50)
gapminder %>%
  semi_join(bad_fit, join_by(country)) %>%
  ggplot(aes(year, lifeExp, colour = country)) +
  geom_line() +
  theme_classic()
We see two main effects here: the tragedies of the HIV/AIDS epidemic and the Rwandan genocide. Finally, we plot life expectancy for the countries where the linear model fits well, together with a smoothed overall trend:
good_fit <- filter(glance, r.squared > 0.95)
gapminder %>%
  semi_join(good_fit, join_by(country)) %>%
  ggplot(aes(year, lifeExp)) +
  geom_line(aes(group = country), alpha = 1/3) +
  geom_smooth(se = FALSE) +
  theme_classic()