library(dplyr)
library(ggplot2)
library(tidyr)
library(purrr)
library(modelr)
library(gapminder)
library(broom)

In this challenge you’re going to explore three powerful ideas that help you to work with large numbers of models with ease:

  1. Using many simple models to better understand complex datasets.

  2. Using list-columns to store arbitrary data structures in a data frame. For example, this will allow you to have a column that contains linear models.

  3. Using the broom package, by David Robinson, to turn models into tidy data. This is a powerful technique for working with large numbers of models because once you have tidy data, you can apply all of the techniques that you’ve learned about earlier in the book.

To motivate the power of many simple models, we’re going to look into the gapminder data. This data was popularised by Hans Rosling, a Swedish doctor and statistician. The gapminder data summarises the progression of countries over time, looking at statistics like life expectancy and GDP. The data is easy to access in R, thanks to Jenny Bryan who created the gapminder package.

gapminder
## # A tibble: 1,704 x 6
##    country     continent  year lifeExp      pop gdpPercap
##    <fct>       <fct>     <int>   <dbl>    <int>     <dbl>
##  1 Afghanistan Asia       1952    28.8  8425333      779.
##  2 Afghanistan Asia       1957    30.3  9240934      821.
##  3 Afghanistan Asia       1962    32.0 10267083      853.
##  4 Afghanistan Asia       1967    34.0 11537966      836.
##  5 Afghanistan Asia       1972    36.1 13079460      740.
##  6 Afghanistan Asia       1977    38.4 14880372      786.
##  7 Afghanistan Asia       1982    39.9 12881816      978.
##  8 Afghanistan Asia       1987    40.8 13867957      852.
##  9 Afghanistan Asia       1992    41.7 16317921      649.
## 10 Afghanistan Asia       1997    41.8 22227415      635.
## # … with 1,694 more rows

In this case study, we’re going to focus on just three variables to answer the question:

How does life expectancy change over time for each country?

A good place to start is with a plot:

ggplot(gapminder, aes(year, lifeExp)) +
    geom_line(aes(group = country), alpha = 1/3) +
    theme_classic()

Notice the group aesthetic. Many geoms, like geom_line() and geom_smooth(), use a single geometric object to display multiple rows of data. For these geoms, you can set the group aesthetic to a categorical variable to draw multiple objects. ggplot2 will draw a separate object for each unique value of the grouping variable.
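A minimal sketch of what the group aesthetic changes, using a made-up data frame (the names `toy`, `unit`, `t` and `y` are illustrative and not part of the gapminder data). It builds the same plot with and without grouping and counts how many line objects ggplot2 would draw in each case.

```r
library(ggplot2)

# Toy data (hypothetical): three units, each measured at four time points
toy <- data.frame(
  unit = rep(c("a", "b", "c"), each = 4),
  t    = rep(1:4, times = 3),
  y    = c(1, 2, 3, 4,  2, 2, 3, 5,  4, 3, 2, 1)
)

# Without group: all rows are joined into a single line
p_single <- ggplot(toy, aes(t, y)) + geom_line()

# With group: one line per unit
p_grouped <- ggplot(toy, aes(t, y)) + geom_line(aes(group = unit))

# layer_data() returns the data ggplot2 computed for a layer,
# including the group each row was assigned to
n_single  <- length(unique(layer_data(p_single)$group))
n_grouped <- length(unique(layer_data(p_grouped)$group))

c(single = n_single, grouped = n_grouped)
```

In the gapminder plot above, group = country plays the same role as group = unit does here.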

The gapminder data is small: it has only ~1,700 observations, and we are using just 3 of its variables. But it’s still hard to see what’s going on! Overall, it looks like life expectancy has been steadily improving. However, if you look closely, you might notice some countries that don’t follow this pattern. How can we make those countries easier to see?

There’s a strong signal (overall linear growth) that makes it hard to see subtler trends. We’ll tease these factors apart by fitting a model with a linear trend. The model captures steady growth over time, and the residuals will show what’s left. You already know how to do that for a single country:

it <- filter(gapminder, country == "Italy")
it %>% 
  ggplot(aes(year, lifeExp)) + 
  geom_line() + 
  ggtitle("Full data")

it_mod <- lm(lifeExp ~ year, data = it)
it %>% 
  add_predictions(it_mod) %>%
  ggplot(aes(year, pred)) + 
  geom_line() + 
  ggtitle("Linear trend")

it %>% 
  add_residuals(it_mod) %>% 
  ggplot(aes(year, resid)) + 
  # zero reference line; from ggplot2 3.4.0 on, use linewidth = 3 instead of size = 3
  geom_hline(yintercept = 0, colour = "white", size = 3) + 
  geom_line() + 
  ggtitle("Remaining pattern")
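The three plots rest on a simple identity: the observed values are the linear trend plus the residuals. A small check of that identity, using made-up numbers rather than the Italian data (`toy`, `toy_mod`, `trend` and `leftover` are hypothetical names):

```r
# Toy series (hypothetical values): a steady rise interrupted by a dip
toy <- data.frame(
  year    = seq(1952, 2007, by = 5),
  lifeExp = c(40, 42, 44, 46, 48, 50, 52, 49, 45, 46, 50, 54)
)

toy_mod <- lm(lifeExp ~ year, data = toy)

trend    <- unname(fitted(toy_mod))  # what the linear model captures
leftover <- unname(resid(toy_mod))   # what it misses

# The two pieces add back up to the observed values
isTRUE(all.equal(trend + leftover, toy$lifeExp))
```

Because the model includes an intercept, the residuals also sum to zero (up to floating-point error), which is why the residual plot is centred on the zero line.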

How can we easily fit the same model to every country? Extract the common code into a function and repeat it using a map function from purrr. To do that, we need a new data structure: the nested data frame. To create a nested data frame, we start with a grouped data frame and nest it:

by_country <- gapminder %>% 
  group_by(country, continent) %>% 
  nest()

by_country
## # A tibble: 142 x 3
##    country     continent data             
##    <fct>       <fct>     <list>           
##  1 Afghanistan Asia      <tibble [12 × 4]>
##  2 Albania     Europe    <tibble [12 × 4]>
##  3 Algeria     Africa    <tibble [12 × 4]>
##  4 Angola      Africa    <tibble [12 × 4]>
##  5 Argentina   Americas  <tibble [12 × 4]>
##  6 Australia   Oceania   <tibble [12 × 4]>
##  7 Austria     Europe    <tibble [12 × 4]>
##  8 Bahrain     Asia      <tibble [12 × 4]>
##  9 Bangladesh  Asia      <tibble [12 × 4]>
## 10 Belgium     Europe    <tibble [12 × 4]>
## # … with 132 more rows
by_country$data[[1]]
## # A tibble: 12 x 4
##     year lifeExp      pop gdpPercap
##    <int>   <dbl>    <int>     <dbl>
##  1  1952    28.8  8425333      779.
##  2  1957    30.3  9240934      821.
##  3  1962    32.0 10267083      853.
##  4  1967    34.0 11537966      836.
##  5  1972    36.1 13079460      740.
##  6  1977    38.4 14880372      786.
##  7  1982    39.9 12881816      978.
##  8  1987    40.8 13867957      852.
##  9  1992    41.7 16317921      649.
## 10  1997    41.8 22227415      635.
## 11  2002    42.1 25268405      727.
## 12  2007    43.8 31889923      975.
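The same structure is easier to inspect on a tiny example. The sketch below nests a made-up data frame (`toy`, `g`, `x` and `y` are hypothetical names) and checks three properties of the result: one row per group, a list-column of data frames, and grouping variables that stay outside the nested data.

```r
library(dplyr)
library(tidyr)

# Toy data (hypothetical): two groups with three rows each
toy <- tibble(
  g = rep(c("a", "b"), each = 3),
  x = 1:6,
  y = c(2, 4, 6, 1, 1, 2)
)

toy_nested <- toy %>%
  group_by(g) %>%
  nest()

nrow(toy_nested)              # one row per group
is.list(toy_nested$data)      # the data column is a list
names(toy_nested$data[[1]])   # only x and y: g is kept outside
```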

Now that we have our nested data frame, we’re in a good position to fit some models. We have a model-fitting function:

country_model <- function(df) {
  lm(lifeExp ~ year, data = df)
}

And we want to apply it to every data frame. The data frames are in a list, so we can use purrr::map() to apply country_model to each element. We’re going to create a new variable in the by_country data frame with dplyr::mutate():

by_country <- by_country %>% 
  mutate(model = map(data, country_model))
by_country
## # A tibble: 142 x 4
##    country     continent data              model   
##    <fct>       <fct>     <list>            <list>  
##  1 Afghanistan Asia      <tibble [12 × 4]> <S3: lm>
##  2 Albania     Europe    <tibble [12 × 4]> <S3: lm>
##  3 Algeria     Africa    <tibble [12 × 4]> <S3: lm>
##  4 Angola      Africa    <tibble [12 × 4]> <S3: lm>
##  5 Argentina   Americas  <tibble [12 × 4]> <S3: lm>
##  6 Australia   Oceania   <tibble [12 × 4]> <S3: lm>
##  7 Austria     Europe    <tibble [12 × 4]> <S3: lm>
##  8 Bahrain     Asia      <tibble [12 × 4]> <S3: lm>
##  9 Bangladesh  Asia      <tibble [12 × 4]> <S3: lm>
## 10 Belgium     Europe    <tibble [12 × 4]> <S3: lm>
## # … with 132 more rows
by_country$model[[1]]
## 
## Call:
## lm(formula = lifeExp ~ year, data = df)
## 
## Coefficients:
## (Intercept)         year  
##   -507.5343       0.2753
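Because the fitted models live in an ordinary list, the other purrr mappers apply to them as well. As a sketch that goes slightly beyond the original walkthrough, the code below fits a trend to each element of a made-up list of data frames and then pulls out the slope on year with map_dbl() (`toy_data`, `fit_trend` and `slopes` are hypothetical names).

```r
library(purrr)

# Toy list of data frames (hypothetical), standing in for by_country$data
toy_data <- list(
  up   = data.frame(year = 1:5, lifeExp = c(1, 2, 3, 4, 5)),
  down = data.frame(year = 1:5, lifeExp = c(10, 8, 6, 4, 2))
)

fit_trend <- function(df) lm(lifeExp ~ year, data = df)

# map() returns a list with one fitted model per data frame
toy_models <- map(toy_data, fit_trend)

# map_dbl() returns a numeric vector: the slope on year for each model
slopes <- map_dbl(toy_models, ~ coef(.x)[["year"]])
slopes
```

In the nested workflow the same call could sit inside mutate(), for example mutate(slope = map_dbl(model, ~ coef(.x)[["year"]])), so that each slope is stored in the data frame next to the model it came from.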

Storing the models in the data frame has a big advantage: because all the related objects are kept together, you don’t need to manually keep them in sync when you filter or arrange. The semantics of the data frame take care of that for you:

by_country %>% 
  filter(continent == "Europe")
## # A tibble: 30 x 4
##    country                continent data              model   
##    <fct>                  <fct>     <list>            <list>  
##  1 Albania                Europe    <tibble [12 × 4]> <S3: lm>
##  2 Austria                Europe    <tibble [12 × 4]> <S3: lm>
##  3 Belgium                Europe    <tibble [12 × 4]> <S3: lm>
##  4 Bosnia and Herzegovina Europe    <tibble [12 × 4]> <S3: lm>
##  5 Bulgaria               Europe    <tibble [12 × 4]> <S3: lm>
##  6 Croatia                Europe    <tibble [12 × 4]> <S3: lm>
##  7 Czech Republic         Europe    <tibble [12 × 4]> <S3: lm>
##  8 Denmark                Europe    <tibble [12 × 4]> <S3: lm>
##  9 Finland                Europe    <tibble [12 × 4]> <S3: lm>
## 10 France                 Europe    <tibble [12 × 4]> <S3: lm>
## # … with 20 more rows
by_country %>% 
  arrange(continent, country)
## # A tibble: 142 x 4
##    country                  continent data              model   
##    <fct>                    <fct>     <list>            <list>  
##  1 Algeria                  Africa    <tibble [12 × 4]> <S3: lm>
##  2 Angola                   Africa    <tibble [12 × 4]> <S3: lm>
##  3 Benin                    Africa    <tibble [12 × 4]> <S3: lm>
##  4 Botswana                 Africa    <tibble [12 × 4]> <S3: lm>
##  5 Burkina Faso             Africa    <tibble [12 × 4]> <S3: lm>
##  6 Burundi                  Africa    <tibble [12 × 4]> <S3: lm>
##  7 Cameroon                 Africa    <tibble [12 × 4]> <S3: lm>
##  8 Central African Republic Africa    <tibble [12 × 4]> <S3: lm>
##  9 Chad                     Africa    <tibble [12 × 4]> <S3: lm>
## 10 Comoros                  Africa    <tibble [12 × 4]> <S3: lm>
## # … with 132 more rows

Previously we computed the residuals of a single model with a single dataset. Now we have 142 data frames and 142 models. To compute the residuals, we need to call add_residuals() on each model-data pair. purrr::map2() does exactly that: it iterates over two lists in parallel and calls a two-argument function, such as add_residuals(), on each pair of elements:

# add residuals
by_country <- by_country %>% 
  mutate(data = map2(data, model, add_residuals))

by_country
## # A tibble: 142 x 4
##    country     continent data              model   
##    <fct>       <fct>     <list>            <list>  
##  1 Afghanistan Asia      <tibble [12 × 5]> <S3: lm>
##  2 Albania     Europe    <tibble [12 × 5]> <S3: lm>
##  3 Algeria     Africa    <tibble [12 × 5]> <S3: lm>
##  4 Angola      Africa    <tibble [12 × 5]> <S3: lm>
##  5 Argentina   Americas  <tibble [12 × 5]> <S3: lm>
##  6 Australia   Oceania   <tibble [12 × 5]> <S3: lm>
##  7 Austria     Europe    <tibble [12 × 5]> <S3: lm>
##  8 Bahrain     Asia      <tibble [12 × 5]> <S3: lm>
##  9 Bangladesh  Asia      <tibble [12 × 5]> <S3: lm>
## 10 Belgium     Europe    <tibble [12 × 5]> <S3: lm>
## # … with 132 more rows
by_country$data[[1]]
## # A tibble: 12 x 5
##     year lifeExp      pop gdpPercap   resid
##    <int>   <dbl>    <int>     <dbl>   <dbl>
##  1  1952    28.8  8425333      779. -1.11  
##  2  1957    30.3  9240934      821. -0.952 
##  3  1962    32.0 10267083      853. -0.664 
##  4  1967    34.0 11537966      836. -0.0172
##  5  1972    36.1 13079460      740.  0.674 
##  6  1977    38.4 14880372      786.  1.65  
##  7  1982    39.9 12881816      978.  1.69  
##  8  1987    40.8 13867957      852.  1.28  
##  9  1992    41.7 16317921      649.  0.754 
## 10  1997    41.8 22227415      635. -0.534 
## 11  2002    42.1 25268405      727. -1.54  
## 12  2007    43.8 31889923      975. -1.22
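To make the pairing explicit, here is a small sketch on made-up data (`toy_data`, `toy_models` and the other names are hypothetical). It runs map2() over two parallel lists and compares the first result with the equivalent call written out by hand.

```r
library(purrr)
library(modelr)

# Two toy data frames (hypothetical) and one model fitted to each
toy_data <- list(
  data.frame(x = 1:4, y = c(1, 3, 2, 5)),
  data.frame(x = 1:4, y = c(4, 4, 1, 0))
)
toy_models <- map(toy_data, ~ lm(y ~ x, data = .x))

# map2() walks both lists in step:
# the i-th data frame is passed together with the i-th model
toy_resid <- map2(toy_data, toy_models, add_residuals)

# The first element, written out by hand
first_by_hand <- add_residuals(toy_data[[1]], toy_models[[1]])

isTRUE(all.equal(toy_resid[[1]], first_by_hand))
```

The order of the arguments matters: add_residuals() takes the data first and the model second, which is why data comes before model in the map2() call.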

But how can you plot a list of data frames? Previously we used nest() to turn a regular data frame into a nested data frame; now we do the opposite with unnest():

# unnest the data column; .drop = TRUE discards the other list-column (model)
# note: .drop is deprecated as of tidyr 1.0.0, where list-columns are kept unless removed first
(resids <- unnest(by_country, data, .drop = TRUE))
## # A tibble: 1,704 x 7
##    country     continent  year lifeExp      pop gdpPercap   resid
##    <fct>       <fct>     <int>   <dbl>    <int>     <dbl>   <dbl>
##  1 Afghanistan Asia       1952    28.8  8425333      779. -1.11  
##  2 Afghanistan Asia       1957    30.3  9240934      821. -0.952 
##  3 Afghanistan Asia       1962    32.0 10267083      853. -0.664 
##  4 Afghanistan Asia       1967    34.0 11537966      836. -0.0172
##  5 Afghanistan Asia       1972    36.1 13079460      740.  0.674 
##  6 Afghanistan Asia       1977    38.4 14880372      786.  1.65  
##  7 Afghanistan Asia       1982    39.9 12881816      978.  1.69  
##  8 Afghanistan Asia       1987    40.8 13867957      852.  1.28  
##  9 Afghanistan Asia       1992    41.7 16317921      649.  0.754 
## 10 Afghanistan Asia       1997    41.8 22227415      635. -0.534 
## # … with 1,694 more rows
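A quick way to see that unnest() undoes nest() is a round trip on a tiny made-up data frame (`toy`, `g` and `x` are hypothetical names). The sketch avoids the .drop argument so that it does not depend on the tidyr version.

```r
library(dplyr)
library(tidyr)

# Toy data (hypothetical): two groups with two rows each
toy <- tibble(
  g = rep(c("a", "b"), each = 2),
  x = 1:4
)

toy_nested <- toy %>%
  group_by(g) %>%
  nest()

# unnest() expands the list-column back into ordinary rows and columns
toy_flat <- toy_nested %>%
  unnest(data) %>%
  ungroup()

isTRUE(all.equal(as.data.frame(toy_flat), as.data.frame(toy)))
```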

Now that we have a regular data frame, we can plot the residuals:

resids %>% 
  ggplot(aes(year, resid)) +
    geom_line(aes(group = country), alpha = 1 / 3) + 
    geom_smooth(se = FALSE) +
    theme_classic()

Faceting by continent is particularly revealing:

resids %>% 
  ggplot(aes(year, resid)) +
    geom_line(aes(group = country), alpha = 1 / 3) + 
    facet_wrap(~continent)  +
    theme_classic()

There’s something interesting going on in Africa: we see some very large residuals, which suggests our model isn’t fitting so well there. To investigate further, we can add model-quality statistics with the broom::glance() function, which summarises each model as a one-row data frame:

# add statistics
by_country <- by_country %>% 
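  # broom::glance is spelled out so the call still finds the function
  # after an object named glance has been created in the workspace
  mutate(glance = map(model, broom::glance))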

by_country
## # A tibble: 142 x 5
##    country     continent data              model    glance           
##    <fct>       <fct>     <list>            <list>   <list>           
##  1 Afghanistan Asia      <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  2 Albania     Europe    <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  3 Algeria     Africa    <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  4 Angola      Africa    <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  5 Argentina   Americas  <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  6 Australia   Oceania   <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  7 Austria     Europe    <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  8 Bahrain     Asia      <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
##  9 Bangladesh  Asia      <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## 10 Belgium     Europe    <tibble [12 × 5]> <S3: lm> <tibble [1 × 11]>
## # … with 132 more rows
by_country$glance[[1]]
## # A tibble: 1 x 11
##   r.squared adj.r.squared sigma statistic p.value    df logLik   AIC   BIC
##       <dbl>         <dbl> <dbl>     <dbl>   <dbl> <int>  <dbl> <dbl> <dbl>
## 1     0.948         0.942  1.22      181. 9.84e-8     2  -18.3  42.7  44.1
## # … with 2 more variables: deviance <dbl>, df.residual <int>
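# unnest the glance column; .drop = TRUE discards the other list-columns (data, model)
# note: .drop is deprecated as of tidyr 1.0.0, where list-columns are kept unless removed first
(glance <- unnest(by_country, glance, .drop = TRUE))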
## # A tibble: 142 x 13
##    country continent r.squared adj.r.squared sigma statistic  p.value    df
##    <fct>   <fct>         <dbl>         <dbl> <dbl>     <dbl>    <dbl> <int>
##  1 Afghan… Asia          0.948         0.942 1.22      181.  9.84e- 8     2
##  2 Albania Europe        0.911         0.902 1.98      102.  1.46e- 6     2
##  3 Algeria Africa        0.985         0.984 1.32      662.  1.81e-10     2
##  4 Angola  Africa        0.888         0.877 1.41       79.1 4.59e- 6     2
##  5 Argent… Americas      0.996         0.995 0.292    2246.  4.22e-13     2
##  6 Austra… Oceania       0.980         0.978 0.621     481.  8.67e-10     2
##  7 Austria Europe        0.992         0.991 0.407    1261.  7.44e-12     2
##  8 Bahrain Asia          0.967         0.963 1.64      291.  1.02e- 8     2
##  9 Bangla… Asia          0.989         0.988 0.977     930.  3.37e-11     2
## 10 Belgium Europe        0.995         0.994 0.293    1822.  1.20e-12     2
## # … with 132 more rows, and 5 more variables: logLik <dbl>, AIC <dbl>,
## #   BIC <dbl>, deviance <dbl>, df.residual <int>
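For orientation, the sketch below applies glance() to a single model fitted to made-up data and contrasts it with broom's tidy(), which the walkthrough above does not use (`toy` and `toy_mod` are hypothetical names). glance() returns one row per model, while tidy() returns one row per coefficient.

```r
library(broom)

# Toy series (hypothetical values)
toy <- data.frame(
  year    = seq(1952, 2007, by = 5),
  lifeExp = c(40, 42, 44, 46, 48, 50, 52, 49, 45, 46, 50, 54)
)
toy_mod <- lm(lifeExp ~ year, data = toy)

model_level <- glance(toy_mod)  # one row: r.squared, sigma, AIC, ...
coef_level  <- tidy(toy_mod)    # one row per term: estimate, std.error, ...

nrow(model_level)
coef_level$term
```

That one-row shape is what makes glance() convenient here: mapped over the model column and unnested, it gives exactly one row of statistics per country.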

With one row of model statistics per country in hand, we can start to look for models that don’t fit well:

glance %>% 
  arrange(r.squared)
## # A tibble: 142 x 13
##    country continent r.squared adj.r.squared sigma statistic p.value    df
##    <fct>   <fct>         <dbl>         <dbl> <dbl>     <dbl>   <dbl> <int>
##  1 Rwanda  Africa       0.0172      -0.0811   6.56     0.175  0.685      2
##  2 Botswa… Africa       0.0340      -0.0626   6.11     0.352  0.566      2
##  3 Zimbab… Africa       0.0562      -0.0381   7.21     0.596  0.458      2
##  4 Zambia  Africa       0.0598      -0.0342   4.53     0.636  0.444      2
##  5 Swazil… Africa       0.0682      -0.0250   6.64     0.732  0.412      2
##  6 Lesotho Africa       0.0849      -0.00666  5.93     0.927  0.358      2
##  7 Cote d… Africa       0.283        0.212    3.93     3.95   0.0748     2
##  8 South … Africa       0.312        0.244    4.74     4.54   0.0588     2
##  9 Uganda  Africa       0.342        0.276    3.19     5.20   0.0457     2
## 10 Congo,… Africa       0.348        0.283    2.43     5.34   0.0434     2
## # … with 132 more rows, and 5 more variables: logLik <dbl>, AIC <dbl>,
## #   BIC <dbl>, deviance <dbl>, df.residual <int>

The worst models all appear to be in Africa. Let’s double check that with a plot:

glance %>% 
  ggplot(aes(continent, r.squared)) + 
  geom_point(alpha = 1/3)
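The r.squared column used above is the share of the variation in life expectancy that the linear trend accounts for. A minimal sketch of that calculation on made-up data (`toy`, `ss_resid`, `ss_total` and `r2_by_hand` are hypothetical names), checked against the value that summary() reports for the same model:

```r
# Toy series (hypothetical values)
toy <- data.frame(
  year    = seq(1952, 2007, by = 5),
  lifeExp = c(40, 42, 44, 46, 48, 50, 52, 49, 45, 46, 50, 54)
)
toy_mod <- lm(lifeExp ~ year, data = toy)

ss_resid <- sum(resid(toy_mod)^2)                     # variation left after the trend
ss_total <- sum((toy$lifeExp - mean(toy$lifeExp))^2)  # variation around the mean

r2_by_hand <- 1 - ss_resid / ss_total

isTRUE(all.equal(r2_by_hand, summary(toy_mod)$r.squared))
```

A value near 0 therefore means the straight line explains almost none of the variation, which is what happens when life expectancy rises and then collapses.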

We could pull out the countries with particularly bad \(R^2\) and plot the data:

bad_fit <- filter(glance, r.squared < 0.25)

gapminder %>% 
  semi_join(bad_fit, by = "country") %>% 
  ggplot(aes(year, lifeExp, colour = country)) +
    geom_line() +
    theme_classic()

We see two main effects here: the tragedies of the HIV/AIDS epidemic and the Rwandan genocide. Finally, we plot life expectancy for the countries where the linear model fits well (\(R^2 > 0.95\)), with a smoothed overall trend drawn on top:

good_fit <- filter(glance, r.squared > 0.95)

gapminder %>% 
  semi_join(good_fit, by = "country") %>% 
  ggplot(aes(year, lifeExp)) +
  geom_line(aes(group = country), alpha = 1/3) +
  geom_smooth(se = FALSE) +
  theme_classic()
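Both the bad-fit and good-fit plots rely on semi_join() to go from a table of selected countries back to the full observations. A small sketch of its behaviour on made-up tables (`toy_obs`, `toy_flagged` and their columns are hypothetical): it keeps the rows of the first table that have a match in the second, without adding any columns from the second.

```r
library(dplyr)

# Toy observations (hypothetical): three countries, two years each
toy_obs <- tibble(
  country = c("A", "A", "B", "B", "C", "C"),
  year    = rep(c(2000, 2005), times = 3),
  value   = c(1, 2, 3, 4, 5, 6)
)

# Toy selection (hypothetical): the countries we want to keep
toy_flagged <- tibble(
  country = c("A", "C"),
  score   = c(0.1, 0.2)
)

kept <- semi_join(toy_obs, toy_flagged, by = "country")

unique(kept$country)  # only the flagged countries remain
names(kept)           # same columns as toy_obs; score is not added
```

An inner_join() would keep the same rows here but would also carry the score column along, which is not needed when the second table only acts as a filter.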