case_when(): define categorical variables based on existing variables

The case_when function is useful for vectorizing conditional statements. It is similar to ifelse but can handle any number of conditions and output values, whereas a single ifelse call tests one condition and chooses between two values. Here is an example that splits numbers into negative, positive, and zero:

x <- c(-2, -1, 0, 1, 2)
case_when(x < 0 ~ "Negative", x > 0 ~ "Positive", TRUE ~ "Zero")
#> [1] "Negative" "Negative" "Zero"     "Positive" "Positive"
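
For comparison, here is a minimal sketch of the same three-way split written with nested ifelse calls. The variable names with_case_when and with_ifelse are made up for this illustration. Each additional branch needs another nested ifelse, while case_when only needs another line:

```r
library(dplyr)

x <- c(-2, -1, 0, 1, 2)

# One case_when() call with three branches
with_case_when <- case_when(x < 0 ~ "Negative", x > 0 ~ "Positive", TRUE ~ "Zero")

# The same logic with ifelse() requires nesting one call inside another
with_ifelse <- ifelse(x < 0, "Negative", ifelse(x > 0, "Positive", "Zero"))

# Both approaches give the same character vector
identical(with_case_when, with_ifelse)
#> [1] TRUE
```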

A common use for this function is to define categorical variables based on existing variables. For example, suppose we want to compare the murder rates in four groups of states: New England, West Coast, South, and other. For each state, we first ask whether it is in New England; if not, we ask whether it is on the West Coast; if not, we ask whether it is in the South; and if not, we assign other. Here is how we use case_when to do this:

data(murders)
murders %>% 
  mutate(group = case_when(
    abb %in% c("ME", "NH", "VT", "MA", "RI", "CT") ~ "New England",
    abb %in% c("WA", "OR", "CA") ~ "West Coast",
    region == "South" ~ "South",
    TRUE ~ "other")) %>%
  group_by(group) %>%
  summarize(rate = sum(total) / sum(population) * 10^5) %>%
  arrange(rate)
#> # A tibble: 4 x 2
#>   group        rate
#>   <chr>       <dbl>
#> 1 New England  1.72
#> 2 other        2.71
#> 3 West Coast   2.90
#> 4 South        3.63
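
Because the conditions are checked in order and the first one that is TRUE decides the value, the order of the lines inside case_when matters whenever conditions overlap. The sketch below illustrates this with a small hand-typed data frame called states (a hypothetical stand-in, not the murders table) and two made-up result names, narrow_first and broad_first:

```r
library(dplyr)

# Hypothetical toy data, typed by hand; not the dslabs murders table
states <- data.frame(
  abb    = c("CA", "NV", "TX", "ME"),
  region = c("West", "West", "South", "Northeast")
)

# Narrow condition first: CA is caught by the West Coast test
narrow_first <- states %>%
  mutate(group = case_when(
    abb %in% c("WA", "OR", "CA") ~ "West Coast",
    region == "West" ~ "West",
    TRUE ~ "other"))

# Broad condition first: CA already matches region == "West",
# so the West Coast test is never reached for it
broad_first <- states %>%
  mutate(group = case_when(
    region == "West" ~ "West",
    abb %in% c("WA", "OR", "CA") ~ "West Coast",
    TRUE ~ "other"))

narrow_first$group
#> [1] "West Coast" "West"       "other"      "other"
broad_first$group
#> [1] "West"  "West"  "other" "other"
```

This is why the murders example lists the specific state abbreviations before the broader region test and puts the catch-all TRUE line last.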

Instruction

Run the sample code to see how the case_when() function works.

library(dplyr)
library(dslabs)
data(heights)
data(murders)
murders <- murders %>% mutate(rate = total/population*100000)
us_murder_rate <- murders %>% summarize(rate = sum(total) / sum(population) * 100000)

# Adding categorical column called group based on the region
data(murders)
murders %>% 
  mutate(group = case_when(
    abb %in% c("ME", "NH", "VT", "MA", "RI", "CT") ~ "New England",
    abb %in% c("WA", "OR", "CA") ~ "West Coast",
    region == "South" ~ "South",
    TRUE ~ "other")) %>%
  group_by(group) %>%
  summarize(rate = sum(total) / sum(population) * 10^5) %>%
  arrange(rate)
