library(dplyr)
library(ggplot2)
library(tidyr)
library(purrr)
library(modelr)
library(gapminder)
library(broom)
In this challenge you’re going to explore three powerful ideas that help you work with large numbers of models with ease:
Using many simple models to better understand complex datasets.
Using list-columns to store arbitrary data structures in a data frame. For example, this will allow you to have a column that contains linear models.
Using the broom package, by David Robinson, to turn models into tidy data. This is a powerful technique for working with large numbers of models because once you have tidy data, you can apply all of the data-manipulation techniques you’ve already learned.
To motivate the power of many simple models, we’re going to look into the gapminder data. This data was popularised by Hans Rosling, a Swedish doctor and statistician. The gapminder data summarises the progression of countries over time, looking at statistics like life expectancy and GDP. The data is easy to access in R, thanks to Jenny Bryan, who created the gapminder package.
gapminder
## # A tibble: 1,704 x 6
## country continent year lifeExp pop gdpPercap
## <fct> <fct> <int> <dbl> <int> <dbl>
## 1 Afghanistan Asia 1952 28.8 8425333 779.
## 2 Afghanistan Asia 1957 30.3 9240934 821.
## 3 Afghanistan Asia 1962 32.0 10267083 853.
## 4 Afghanistan Asia 1967 34.0 11537966 836.
## 5 Afghanistan Asia 1972 36.1 13079460 740.
## 6 Afghanistan Asia 1977 38.4 14880372 786.
## 7 Afghanistan Asia 1982 39.9 12881816 978.
## 8 Afghanistan Asia 1987 40.8 13867957 852.
## 9 Afghanistan Asia 1992 41.7 16317921 649.
## 10 Afghanistan Asia 1997 41.8 22227415 635.
## # … with 1,694 more rows
In this case study, we’re going to focus on just three variables (country, year and lifeExp) to answer the question:
How does life expectancy change over time for each country?
A good place to start is with a plot:
ggplot(gapminder, aes(year, lifeExp)) +
geom_line(aes(group = country), alpha = 1/3) +
theme_classic()
Notice the group aesthetic. Many geoms, like geom_line() and geom_smooth(), use a single geometric object to display multiple rows of data. For these geoms, you can set the group aesthetic to a categorical variable to draw multiple objects: ggplot2 will draw a separate object for each unique value of the grouping variable.
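To see what the group aesthetic changes, here is a minimal sketch that builds the same line plot for three countries (chosen arbitrarily, purely for illustration) with and without group = country, then uses ggplot2's layer_data() to count how many groups ggplot2 computed for each plot:

```r
library(dplyr)
library(ggplot2)
library(gapminder)

# Three countries picked only for illustration
three <- filter(gapminder, country %in% c("Italy", "Japan", "Rwanda"))

# Without a group aesthetic (and no discrete aesthetics), every row
# belongs to one group, so geom_line() joins them into a single zig-zag
p_single <- ggplot(three, aes(year, lifeExp)) +
  geom_line()

# With group = country, ggplot2 draws one line per country
p_grouped <- ggplot(three, aes(year, lifeExp)) +
  geom_line(aes(group = country))

# layer_data() returns the computed data for a layer, including its group column
n_groups_single <- length(unique(layer_data(p_single)$group))
n_groups_grouped <- length(unique(layer_data(p_grouped)$group))
```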
This is a small dataset: it only has ~1,700 observations and 3 variables. But it’s still hard to see what’s going on! Overall, it looks like life expectancy has been steadily improving. However, if you look closely, you might notice some countries that don’t follow this pattern. How can we make those countries easier to see?
There’s a strong signal (overall linear growth) that makes it hard to see subtler trends. We’ll tease these factors apart by fitting a model with a linear trend. The model captures steady growth over time, and the residuals will show what’s left. You already know how to do that for a single country:
it <- filter(gapminder, country == "Italy")
it %>%
ggplot(aes(year, lifeExp)) +
geom_line() +
ggtitle("Full data")
it_mod <- lm(lifeExp ~ year, data = it)
it %>%
add_predictions(it_mod) %>%
ggplot(aes(year, pred)) +
geom_line() +
ggtitle("Linear trend")
it %>%
add_residuals(it_mod) %>%
ggplot(aes(year, resid)) +
geom_hline(yintercept = 0, colour = "white", size = 3) +
geom_line() +
ggtitle("Remaining pattern")
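The residuals plotted above are simply the observed life expectancy minus the model’s prediction. A quick sanity check for the Italy model, relying on modelr’s default column names pred and resid; because the model includes an intercept, the least-squares residuals should also sum to (numerically) zero:

```r
library(dplyr)
library(modelr)
library(gapminder)

it <- filter(gapminder, country == "Italy")
it_mod <- lm(lifeExp ~ year, data = it)

# add_predictions() appends a pred column; add_residuals() appends resid
it_aug <- it %>%
  add_predictions(it_mod) %>%
  add_residuals(it_mod)

# Residuals are observed minus predicted values
max_gap <- max(abs(it_aug$resid - (it_aug$lifeExp - it_aug$pred)))

# With an intercept in the model, OLS residuals sum to (almost exactly) zero
resid_total <- sum(it_aug$resid)
```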
How can we easily fit that model to every country? We can extract the common code into a function and repeat it with a map function from purrr. To do that, we need a new data structure: the nested data frame. To create a nested data frame, we start with a grouped data frame and nest it:
by_country <- gapminder %>%
group_by(country, continent) %>%
nest()
by_country
## # A tibble: 142 x 3
## country continent data
## <fct> <fct> <list>
## 1 Afghanistan Asia <tibble [12 × 4]>
## 2 Albania Europe <tibble [12 × 4]>
## 3 Algeria Africa <tibble [12 × 4]>
## 4 Angola Africa <tibble [12 × 4]>
## 5 Argentina Americas <tibble [12 × 4]>
## 6 Australia Oceania <tibble [12 × 4]>
## 7 Austria Europe <tibble [12 × 4]>
## 8 Bahrain Asia <tibble [12 × 4]>
## 9 Bangladesh Asia <tibble [12 × 4]>
## 10 Belgium Europe <tibble [12 × 4]>
## # … with 132 more rows
by_country$data[[1]]
## # A tibble: 12 x 4
## year lifeExp pop gdpPercap
## <int> <dbl> <int> <dbl>
## 1 1952 28.8 8425333 779.
## 2 1957 30.3 9240934 821.
## 3 1962 32.0 10267083 853.
## 4 1967 34.0 11537966 836.
## 5 1972 36.1 13079460 740.
## 6 1977 38.4 14880372 786.
## 7 1982 39.9 12881816 978.
## 8 1987 40.8 13867957 852.
## 9 1992 41.7 16317921 649.
## 10 1997 41.8 22227415 635.
## 11 2002 42.1 25268405 727.
## 12 2007 43.8 31889923 975.
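Nesting reorganises the rows rather than dropping any: every country’s 12 observations end up in its own tibble. A minimal check, assuming tidyr 1.0.0 or later (where nest() on a grouped data frame nests every non-grouping column); the ungroup() call is only there to return a plain tibble:

```r
library(dplyr)
library(tidyr)
library(purrr)
library(gapminder)

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup()

# One row per country; each element of data holds that country's rows
rows_per_country <- map_int(by_country$data, nrow)

# The grouping columns stay outside; the remaining columns go inside
inner_columns <- names(by_country$data[[1]])
```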
Now that we have our nested data frame, we’re in a good position to fit some models. We have a model-fitting function:
country_model <- function(df) {
lm(lifeExp ~ year, data = df)
}
And we want to apply it to every data frame. The data frames are in a list, so we can use purrr::map() to apply country_model to each element. We’re going to create a new variable in the by_country data frame with dplyr::mutate():
by_country <- by_country %>%
mutate(model = map(data, country_model))
by_country
## # A tibble: 142 x 4
## country continent data model
## <fct> <fct> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 4]> <S3: lm>
## 2 Albania Europe <tibble [12 × 4]> <S3: lm>
## 3 Algeria Africa <tibble [12 × 4]> <S3: lm>
## 4 Angola Africa <tibble [12 × 4]> <S3: lm>
## 5 Argentina Americas <tibble [12 × 4]> <S3: lm>
## 6 Australia Oceania <tibble [12 × 4]> <S3: lm>
## 7 Austria Europe <tibble [12 × 4]> <S3: lm>
## 8 Bahrain Asia <tibble [12 × 4]> <S3: lm>
## 9 Bangladesh Asia <tibble [12 × 4]> <S3: lm>
## 10 Belgium Europe <tibble [12 × 4]> <S3: lm>
## # … with 132 more rows
by_country$model[[1]]
##
## Call:
## lm(formula = lifeExp ~ year, data = df)
##
## Coefficients:
## (Intercept) year
## -507.5343 0.2753
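Because model is an ordinary list-column, the typed map variants can pull a single number out of every model at once. A sketch that extracts each country’s slope, the estimated change in life expectancy per year; the slope column name is our own choice:

```r
library(dplyr)
library(tidyr)
library(purrr)
library(gapminder)

country_model <- function(df) {
  lm(lifeExp ~ year, data = df)
}

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup() %>%
  mutate(model = map(data, country_model))

# map_dbl() returns a plain double vector: here, the year coefficient of each model
by_country <- by_country %>%
  mutate(slope = map_dbl(model, ~ coef(.x)[["year"]]))
```

For Afghanistan this recovers the year coefficient of about 0.2753 printed above.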
This has a big advantage: because all the related objects are stored together, you don’t need to manually keep them in sync when you filter or arrange. The semantics of the data frame takes care of that for you:
by_country %>%
filter(continent == "Europe")
## # A tibble: 30 x 4
## country continent data model
## <fct> <fct> <list> <list>
## 1 Albania Europe <tibble [12 × 4]> <S3: lm>
## 2 Austria Europe <tibble [12 × 4]> <S3: lm>
## 3 Belgium Europe <tibble [12 × 4]> <S3: lm>
## 4 Bosnia and Herzegovina Europe <tibble [12 × 4]> <S3: lm>
## 5 Bulgaria Europe <tibble [12 × 4]> <S3: lm>
## 6 Croatia Europe <tibble [12 × 4]> <S3: lm>
## 7 Czech Republic Europe <tibble [12 × 4]> <S3: lm>
## 8 Denmark Europe <tibble [12 × 4]> <S3: lm>
## 9 Finland Europe <tibble [12 × 4]> <S3: lm>
## 10 France Europe <tibble [12 × 4]> <S3: lm>
## # … with 20 more rows
by_country %>%
arrange(continent, country)
## # A tibble: 142 x 4
## country continent data model
## <fct> <fct> <list> <list>
## 1 Algeria Africa <tibble [12 × 4]> <S3: lm>
## 2 Angola Africa <tibble [12 × 4]> <S3: lm>
## 3 Benin Africa <tibble [12 × 4]> <S3: lm>
## 4 Botswana Africa <tibble [12 × 4]> <S3: lm>
## 5 Burkina Faso Africa <tibble [12 × 4]> <S3: lm>
## 6 Burundi Africa <tibble [12 × 4]> <S3: lm>
## 7 Cameroon Africa <tibble [12 × 4]> <S3: lm>
## 8 Central African Republic Africa <tibble [12 × 4]> <S3: lm>
## 9 Chad Africa <tibble [12 × 4]> <S3: lm>
## 10 Comoros Africa <tibble [12 × 4]> <S3: lm>
## # … with 132 more rows
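One way to convince yourself that filter() and arrange() keep each model next to its data: every lm object stores the model frame it was fitted on in its model element, so after an arbitrary reordering we can compare that frame with the neighbouring data tibble. A minimal sketch; the reversed alphabetical order is just an example:

```r
library(dplyr)
library(tidyr)
library(purrr)
library(gapminder)

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup() %>%
  mutate(model = map(data, ~ lm(lifeExp ~ year, data = .x)))

# Filter, then reorder the rows
reordered <- by_country %>%
  filter(continent == "Europe") %>%
  arrange(desc(country))

# Compare each data tibble with the model frame stored inside its neighbouring lm
in_sync <- map2_lgl(
  reordered$data, reordered$model,
  ~ isTRUE(all.equal(.x$lifeExp, .y$model$lifeExp, check.attributes = FALSE))
)
```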
Previously we computed the residuals of a single model with a single dataset. Now we have 142 data frames and 142 models. To compute the residuals, we need to call add_residuals() with each model-data pair, using purrr::map2(), which maps a function of two arguments (like add_residuals()) over two lists in parallel:
# add residuals
by_country <- by_country %>%
mutate(data = map2(data, model, add_residuals))
by_country
## # A tibble: 142 x 4
## country continent data model
## <fct> <fct> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 5]> <S3: lm>
## 2 Albania Europe <tibble [12 × 5]> <S3: lm>
## 3 Algeria Africa <tibble [12 × 5]> <S3: lm>
## 4 Angola Africa <tibble [12 × 5]> <S3: lm>
## 5 Argentina Americas <tibble [12 × 5]> <S3: lm>
## 6 Australia Oceania <tibble [12 × 5]> <S3: lm>
## 7 Austria Europe <tibble [12 × 5]> <S3: lm>
## 8 Bahrain Asia <tibble [12 × 5]> <S3: lm>
## 9 Bangladesh Asia <tibble [12 × 5]> <S3: lm>
## 10 Belgium Europe <tibble [12 × 5]> <S3: lm>
## # … with 132 more rows
by_country$data[[1]]
## # A tibble: 12 x 5
## year lifeExp pop gdpPercap resid
## <int> <dbl> <int> <dbl> <dbl>
## 1 1952 28.8 8425333 779. -1.11
## 2 1957 30.3 9240934 821. -0.952
## 3 1962 32.0 10267083 853. -0.664
## 4 1967 34.0 11537966 836. -0.0172
## 5 1972 36.1 13079460 740. 0.674
## 6 1977 38.4 14880372 786. 1.65
## 7 1982 39.9 12881816 978. 1.69
## 8 1987 40.8 13867957 852. 1.28
## 9 1992 41.7 16317921 649. 0.754
## 10 1997 41.8 22227415 635. -0.534
## 11 2002 42.1 25268405 727. -1.54
## 12 2007 43.8 31889923 975. -1.22
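map2() pairs the two list-columns element by element, calling add_residuals(data[[1]], model[[1]]), then add_residuals(data[[2]], model[[2]]), and so on. As a sketch, the resid column it produces can be checked against the residuals stored in each fitted lm object:

```r
library(dplyr)
library(tidyr)
library(purrr)
library(modelr)
library(gapminder)

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup() %>%
  mutate(
    model = map(data, ~ lm(lifeExp ~ year, data = .x)),
    # mutate() evaluates in order, so this line already sees the new model column
    data = map2(data, model, add_residuals)
  )

# Compare the new resid column with residuals() of the matching model
matches <- map2_lgl(
  by_country$data, by_country$model,
  ~ isTRUE(all.equal(unname(.x$resid), unname(residuals(.y))))
)
```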
But how can you plot a list of data frames? Previously we used nest() to turn a regular data frame into a nested data frame, and now we do the opposite with unnest():
# unnest the data column; drop the model list-column first, because the
# .drop argument of unnest() is deprecated as of tidyr 1.0.0
(resids <- by_country %>%
  select(-model) %>%
  unnest(data))
## # A tibble: 1,704 x 7
## country continent year lifeExp pop gdpPercap resid
## <fct> <fct> <int> <dbl> <int> <dbl> <dbl>
## 1 Afghanistan Asia 1952 28.8 8425333 779. -1.11
## 2 Afghanistan Asia 1957 30.3 9240934 821. -0.952
## 3 Afghanistan Asia 1962 32.0 10267083 853. -0.664
## 4 Afghanistan Asia 1967 34.0 11537966 836. -0.0172
## 5 Afghanistan Asia 1972 36.1 13079460 740. 0.674
## 6 Afghanistan Asia 1977 38.4 14880372 786. 1.65
## 7 Afghanistan Asia 1982 39.9 12881816 978. 1.69
## 8 Afghanistan Asia 1987 40.8 13867957 852. 1.28
## 9 Afghanistan Asia 1992 41.7 16317921 649. 0.754
## 10 Afghanistan Asia 1997 41.8 22227415 635. -0.534
## # … with 1,694 more rows
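Here nest() and unnest() undo each other. A small round-trip sketch, checking only the row count, the column names and a column total rather than full equality:

```r
library(dplyr)
library(tidyr)
library(gapminder)

# Nest by country, then unnest straight back out
round_trip <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup() %>%
  unnest(data)
```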
Now that we have a regular data frame, we can plot the residuals:
resids %>%
ggplot(aes(year, resid)) +
geom_line(aes(group = country), alpha = 1 / 3) +
geom_smooth(se = FALSE) +
theme_classic()
Facetting by continent is particularly revealing:
resids %>%
ggplot(aes(year, resid)) +
geom_line(aes(group = country), alpha = 1 / 3) +
facet_wrap(~continent) +
theme_classic()
There’s something interesting going on in Africa: we see some very large residuals, which suggests our model isn’t fitting so well there. To investigate further, we can add model-level summary statistics with the broom::glance() function:
# add statistics
by_country <- by_country %>%
mutate(glance = map(model, glance))
by_country
## # A tibble: 142 x 5
## country continent data model glance
## <fct> <fct> <list> <list> <list>
## 1 Afghanistan Asia <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 2 Albania Europe <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 3 Algeria Africa <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 4 Angola Africa <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 5 Argentina Americas <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 6 Australia Oceania <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 7 Austria Europe <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 8 Bahrain Asia <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 9 Bangladesh Asia <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 10 Belgium Europe <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## # … with 132 more rows
by_country$glance[[1]]
## # A tibble: 1 x 11
## r.squared adj.r.squared sigma statistic p.value df logLik AIC BIC
## <dbl> <dbl> <dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.948 0.942 1.22 181. 9.84e-8 2 -18.3 42.7 44.1
## # … with 2 more variables: deviance <dbl>, df.residual <int>
# unnest the glance statistics; keep only the columns we need first, because the
# .drop argument of unnest() is deprecated as of tidyr 1.0.0
(glance <- by_country %>%
  select(country, continent, glance) %>%
  unnest(glance))
## # A tibble: 142 x 13
## country continent r.squared adj.r.squared sigma statistic p.value df
## <fct> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
## 1 Afghan… Asia 0.948 0.942 1.22 181. 9.84e- 8 2
## 2 Albania Europe 0.911 0.902 1.98 102. 1.46e- 6 2
## 3 Algeria Africa 0.985 0.984 1.32 662. 1.81e-10 2
## 4 Angola Africa 0.888 0.877 1.41 79.1 4.59e- 6 2
## 5 Argent… Americas 0.996 0.995 0.292 2246. 4.22e-13 2
## 6 Austra… Oceania 0.980 0.978 0.621 481. 8.67e-10 2
## 7 Austria Europe 0.992 0.991 0.407 1261. 7.44e-12 2
## 8 Bahrain Asia 0.967 0.963 1.64 291. 1.02e- 8 2
## 9 Bangla… Asia 0.989 0.988 0.977 930. 3.37e-11 2
## 10 Belgium Europe 0.995 0.994 0.293 1822. 1.20e-12 2
## # … with 132 more rows, and 5 more variables: logLik <dbl>, AIC <dbl>,
## # BIC <dbl>, deviance <dbl>, df.residual <int>
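glance() returns every model-level statistic at once. When only one is needed, map_dbl() over the model list-column is a lighter alternative. This sketch checks that both routes give the same \(R^2\) values; it calls broom::glance explicitly and stores the result in fit_stats, a name chosen here so that the glance() function is not masked:

```r
library(dplyr)
library(tidyr)
library(purrr)
library(broom)
library(gapminder)

by_country <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup() %>%
  mutate(model = map(data, ~ lm(lifeExp ~ year, data = .x)))

# Route 1: one-row glance() tibble per model, unnested into columns
fit_stats <- by_country %>%
  mutate(glance = map(model, broom::glance)) %>%
  select(country, continent, glance) %>%
  unnest(glance)

# Route 2: extract just R^2 from summary() of each model
r2_direct <- map_dbl(by_country$model, ~ summary(.x)$r.squared)
```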
With this data frame in hand, we can start to look for models that don’t fit well:
glance %>%
arrange(r.squared)
## # A tibble: 142 x 13
## country continent r.squared adj.r.squared sigma statistic p.value df
## <fct> <fct> <dbl> <dbl> <dbl> <dbl> <dbl> <int>
## 1 Rwanda Africa 0.0172 -0.0811 6.56 0.175 0.685 2
## 2 Botswa… Africa 0.0340 -0.0626 6.11 0.352 0.566 2
## 3 Zimbab… Africa 0.0562 -0.0381 7.21 0.596 0.458 2
## 4 Zambia Africa 0.0598 -0.0342 4.53 0.636 0.444 2
## 5 Swazil… Africa 0.0682 -0.0250 6.64 0.732 0.412 2
## 6 Lesotho Africa 0.0849 -0.00666 5.93 0.927 0.358 2
## 7 Cote d… Africa 0.283 0.212 3.93 3.95 0.0748 2
## 8 South … Africa 0.312 0.244 4.74 4.54 0.0588 2
## 9 Uganda Africa 0.342 0.276 3.19 5.20 0.0457 2
## 10 Congo,… Africa 0.348 0.283 2.43 5.34 0.0434 2
## # … with 132 more rows, and 5 more variables: logLik <dbl>, AIC <dbl>,
## # BIC <dbl>, deviance <dbl>, df.residual <int>
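For a linear model with an intercept, \(R^2 = 1 - SS_{res}/SS_{tot}\): the share of the variation in life expectancy that the linear trend explains. A sketch computing it by hand for Rwanda, the worst-fitting country in the sorted output above:

```r
library(dplyr)
library(gapminder)

rwanda <- filter(gapminder, country == "Rwanda")
rwanda_mod <- lm(lifeExp ~ year, data = rwanda)

# Residual sum of squares: variation the linear trend leaves unexplained
ss_res <- sum(residuals(rwanda_mod)^2)

# Total sum of squares: variation around the mean life expectancy
ss_tot <- sum((rwanda$lifeExp - mean(rwanda$lifeExp))^2)

r2_by_hand <- 1 - ss_res / ss_tot
```

A value this close to zero means a straight line explains almost none of Rwanda’s year-to-year variation.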
The worst models all appear to be in Africa. Let’s double check that with a plot:
glance %>%
ggplot(aes(continent, r.squared)) +
geom_point(alpha = 1/3)
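The same impression can be checked numerically by tallying the poorly fitting models per continent. The 0.25 cut-off is the same threshold the text uses for bad_fit and is otherwise arbitrary; glance_df is a hypothetical name chosen to avoid masking broom::glance():

```r
library(dplyr)
library(tidyr)
library(purrr)
library(broom)
library(gapminder)

glance_df <- gapminder %>%
  group_by(country, continent) %>%
  nest() %>%
  ungroup() %>%
  mutate(
    model = map(data, ~ lm(lifeExp ~ year, data = .x)),
    glance = map(model, broom::glance)
  ) %>%
  select(country, continent, glance) %>%
  unnest(glance)

# Count models with R^2 below the cut-off in each continent
# (count() drops continents with no such models)
bad_by_continent <- glance_df %>%
  filter(r.squared < 0.25) %>%
  count(continent)
```

Consistent with the sorted output above, all six models below this threshold belong to African countries.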
We could pull out the countries with particularly bad \(R^2\) and plot the data:
bad_fit <- filter(glance, r.squared < 0.25)
gapminder %>%
semi_join(bad_fit, by = "country") %>%
ggplot(aes(year, lifeExp, colour = country)) +
geom_line() +
theme_classic()
We see two main effects here: the tragedies of the HIV/AIDS epidemic and the Rwandan genocide. Finally, for contrast, we plot life expectancy for the countries whose linear models fit well (\(R^2 > 0.95\)), along with a smoothed overall trend:
good_fit <- filter(glance, r.squared > 0.95)
gapminder %>%
semi_join(good_fit, by = "country") %>%
ggplot(aes(year, lifeExp)) +
geom_line(aes(group = country), alpha = 1/3) +
geom_smooth(se = FALSE) +
theme_classic()