Making the {survey} package hundreds of times faster using {Rcpp}

A common objection to using the {survey} package instead of SAS or Stata is the computational time it requires. In this post, I show that we can easily obtain hundred-fold speed improvements in its core functions by using the {Rcpp} and {RcppArmadillo} packages. To illustrate, I show how we can make svytotal() run over 500 times faster. Finally, I offer thoughts about how to incorporate these {Rcpp}-based functions in either the {survey} package or a potential add-on R package.

Categories: statistics, surveys, R, survey package

Published December 14, 2021

In industry, I’ve encountered two main objections to using the open-source {survey} package instead of Stata or SAS. First, many statisticians accustomed to closed-source software consider R packages such as {survey} to be less trustworthy. That often stems from a lack of familiarity with open-source software, although there are some valid reasons for this concern.[1] Second, the {survey} package is often thought to be relatively slow. This can be a problem for large data sets (such as the BRFSS survey) or for computationally intensive tasks such as simulation studies. There have been recent improvements on this front: the package has increasingly supported the use of database backends and parallelization. But as we’ll see in this post, this problem is surprisingly easy to address.

The bottleneck

For most functions in the package, such as svytotal() or svyratio(), the variance calculation is a clear bottleneck. The statistic itself often takes only a few simple steps to compute (e.g. for totals, just multiply the variable by the weights and add it up), but the variance computation is much more involved. For example, raked designs require a QR decomposition, and multistage designs require recursive function calls (which can be computationally expensive in R).
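To make the contrast concrete, here is a minimal base-R sketch on a made-up toy sample (the values of `y`, `wt`, and `psu` are hypothetical). The estimated total is a single weighted sum, while even the first step of a linearization variance estimate already requires grouping the data by sampling unit; the recursive calls mentioned above repeat this kind of grouping for every stage and stratum.

```r
# Toy sample (hypothetical values): outcome, sampling weights, and PSU ids
y   <- c(10, 12, 9, 20, 22, 18)
wt  <- c(10, 10, 10, 5, 5, 5)
psu <- c(1, 1, 2, 2, 3, 3)

# Point estimate of the total: multiply by the weights and add it up
total_hat <- sum(y * wt)
total_hat   # 610

# By contrast, even the first step of a linearization variance estimate
# requires grouping: collapse the weighted values to PSU totals
psu_totals <- rowsum(y * wt, group = psu)
psu_totals  # 220, 190, 200 (one row per PSU)
```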

In the last post on this blog, we walked through the {survey} R functions used to estimate linearization sampling variances for multistage designs. In a nutshell, svyrecvar() does some data pre-processing and then calls survey:::multistage(), which is used recursively to estimate variance contributions from multiple stages of sampling. For any given stage of sampling, the variance contribution is estimated using survey:::onestage(), which makes use of survey:::onestrat() to get the variance contribution from a given stratum. These are the “workhorse” functions of the {survey} package. Whenever sampling variances are estimated using linearization (for totals, means, ratios, GLMs, you name it), these functions are going to be used. For this reason, increasing their speed would improve the speed of almost any variance calculation for non-replicate designs.
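For readers who haven’t seen the previous post, below is a rough base-R sketch of the logic these functions implement. It is written for illustration only and is not the {survey} source: the function names `one_stage_var()` and `two_stage_var()` are hypothetical, lonely-PSU handling and calibration are omitted, and the number of sampled clusters is taken from the data. Within each stratum h, the weighted values are summed to cluster totals z_hi, and the stratum contributes (1 - n_h/N_h) * n_h/(n_h - 1) * sum_i (z_hi - zbar_h)(z_hi - zbar_h)^T. For a second stage, the same calculation is applied within each sampled PSU and scaled by the first-stage sampling fraction n_h/N_h, as in the standard two-stage variance estimator.

```r
# Hypothetical helper (not the {survey} source): variance contribution of one
# stage of stratified cluster sampling, using the textbook formula with an FPC.
#   x        matrix of weighted values, one row per sampled element
#   strata   stratum id of each element at this stage
#   clusters cluster (sampling unit) id of each element at this stage
#   N_pop    population number of clusters in each element's stratum
one_stage_var <- function(x, strata, clusters, N_pop) {
  x <- as.matrix(x)
  v <- matrix(0, ncol(x), ncol(x))
  for (h in unique(strata)) {
    in_h <- strata == h
    z <- rowsum(x[in_h, , drop = FALSE], clusters[in_h])  # cluster totals
    n_h <- nrow(z)
    if (n_h < 2) next                  # lonely PSUs are simply skipped here
    z_c <- sweep(z, 2, colMeans(z))    # center the cluster totals
    v <- v + (1 - n_h / N_pop[in_h][1]) * n_h / (n_h - 1) * crossprod(z_c)
  }
  v
}

# Hypothetical two-stage version: first-stage contribution plus, for every
# sampled PSU, its second-stage contribution scaled by the first-stage
# sampling fraction n_h / N_h
two_stage_var <- function(x, stratum, psu, N_psus, ssu, N_ssus) {
  x <- as.matrix(x)
  v <- one_stage_var(x, stratum, psu, N_psus)
  for (h in unique(stratum)) {
    in_h <- stratum == h
    f1 <- length(unique(psu[in_h])) / N_psus[in_h][1]
    for (i in unique(psu[in_h])) {
      in_hi <- in_h & psu == i
      # Within a PSU, the SSUs form a single (unstratified) group
      v <- v + f1 * one_stage_var(x[in_hi, , drop = FALSE], rep(1, sum(in_hi)),
                                  ssu[in_hi], N_ssus[in_hi])
    }
  }
  v
}

# Toy design: 2 of 4 PSUs sampled, then 2 of 3 SSUs in each sampled PSU,
# so every element has weight (4/2) * (3/2) = 3
y <- c(2, 4, 6, 10)
v <- two_stage_var(x = cbind(y = 3 * y), stratum = rep(1, 4),
                   psu = c(1, 1, 2, 2), N_psus = rep(4, 4),
                   ssu = c(1, 2, 1, 2), N_ssus = rep(3, 4))
v  # about 480: 450 from the first stage plus 30 from the second stage
```

The nested loops make the cost visible: the per-PSU work is repeated once per sampled PSU, which is exactly the kind of work that benefits from being moved to C++.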

In short, if we want to make the {survey} package faster, we need to improve the handful of functions mentioned above (svyrecvar() and its helpers).

Enter Rcpp

The R language makes data analysis easier, but it needs help if you want to make it fast, too. In its early days, R was supplemented by writing routines in C or Fortran and instructing R to call those routines. Nowadays, C++ is normally used instead. R developers looking for a quick performance boost often turn to the {Rcpp} package, an incredible R package that makes it remarkably easy to use C++ inside R functions and R packages. And for calculations involving subsetting and linear algebra, the {RcppArmadillo} package helps out further by allowing users to take advantage of Armadillo, a high-performance and relatively readable C++ linear algebra library.

If you’re curious, below is a simple example (based on the {RcppArmadillo} package documentation) showing how to use {Rcpp} to create an R function which makes use of fast C++ code.

Example of using {Rcpp} to drastically speed up the fitting of a linear model

To create an R function using C++, we can write the C++ code as a string and then supply it to the {Rcpp} package’s handy function Rcpp::cppFunction(), which wraps up the C++ code into an R function.

library(Rcpp)

# Create an R function for fitting a linear model,
# which makes use of a C++ function.
# Note: cppFunction() itself adds 'using namespace Rcpp;' and the
# '// [[Rcpp::export]]' attribute, so the code string should contain
# only the function definition (otherwise Rcpp warns about an
# export attribute with no function after it).

cpp_code <- (
'List fastLm_cpp(const arma::vec & y, const arma::mat & X) {

    int n = X.n_rows, k = X.n_cols;

    arma::colvec coef = arma::solve(X, y);
    arma::colvec resid = y - X*coef;

    double sig2 = arma::as_scalar(arma::trans(resid)*resid/(n-k));
    arma::colvec stderrest =
        arma::sqrt(sig2 * arma::diagvec( arma::inv(arma::trans(X)*X)) );

    return List::create(Named("coefficients") = coef,
                        Named("stderr")       = stderrest);
}')

cppFunction(cpp_code, depends = "RcppArmadillo")
# Use the function
y <- iris$Sepal.Length
X <- as.matrix(iris[,c("Sepal.Width", "Petal.Width")])

fastLm_cpp(y = y,
           X = X)
$coefficients
         [,1]
[1,] 1.392293
[2,] 1.282020

$stderr
           [,1]
[1,] 0.02749755
[2,] 0.05981114

To compare the speed of the Rcpp-based function with base R, we’ll write a similar function in pure R.

# Create a "pure" R function to fit a linear model
fastlm_pureR <- function(y, X) {
  coef <- qr.solve(a = X,
                   b = y)
  resids <- y - (X %*% coef)
  df <- nrow(X) - ncol(X)
  
  sig2 <- as.vector(crossprod(as.matrix(resids)) / df)
  stderr <- sqrt(sig2 * diag(qr.solve(t(X) %*% X)))
  return(list('coefficients' = coef,
              'stderr' = stderr))
}

# Use the function (with the same y and X as before)
fastlm_pureR(y = y,
             X = X)
$coefficients
Sepal.Width Petal.Width 
   1.392293    1.282020 

$stderr
[1] 0.02749755 0.05981114

Using the {microbenchmark} package to record how long each function takes to run, we can see that the Rcpp-based function runs roughly ten times faster than the comparable pure R function (comparing median run times), and both run much faster than base R’s lm() function (which does a lot more and so naturally takes longer to run).

# Compare the time required to run each function
microbenchmark::microbenchmark(
  'fastLm, Rcpp' = fastLm_cpp(y = y,
                              X = X),
  'fastLm, pure R' = fastlm_pureR(y = y,
                                  X = X),
  'base R lm()' = lm(formula = Sepal.Length ~ -1 + Sepal.Width + Petal.Width,
                     data = iris)
)
Unit: microseconds
           expr   min    lq    mean median     uq    max neval cld
   fastLm, Rcpp   6.4   7.9  36.027  15.75  28.50 1529.2   100 a  
 fastLm, pure R 137.3 145.4 267.029 164.35 194.00 9079.3   100  b 
    base R lm() 502.9 535.7 693.551 566.85 644.25 4387.4   100   c

Re-writing survey:::multistage() with {Rcpp} and {RcppArmadillo}

In the script arma_multistage.cpp from the bschneidr/surveycpp GitHub repository, I’ve re-written the workhorse functions multistage() and onestage() in C++, relying heavily on the Armadillo library for memory-efficient subsetting and matrix algebra. To use these C++ functions in an R session, we can simply supply the C++ code to the {Rcpp} function sourceCpp().

library(Rcpp)

cpp_code_file <- "https://raw.githubusercontent.com/bschneidr/surveycpp/main/arma_multistage.cpp"

cpp_code <- readLines(cpp_code_file) |>
  paste(collapse = "\n")

sourceCpp(code = cpp_code)

To illustrate the code’s usage, we’ll generate an example data set named sample_data. This data set contains 6,014 rows, has two numeric outcome variables of interest (api00 and api99), and is a sample drawn using stratified, multistage sampling without replacement. The relevant design variables are the stratum and sampling unit identifiers (stratum, psu_id, ssu_id) and the columns giving the population sizes for the two stages of sampling: the number of PSUs in each stratum (N_psus) and the number of SSUs in each PSU (N_ssus).
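Because the simulation code samples 100 PSUs per stratum and then 5 SSUs per sampled PSU, keeping every element of a sampled SSU, an element’s inclusion probability is the product of the two stage-specific sampling fractions, and its weight is the reciprocal. A quick sketch of that arithmetic, using the population sizes from the first row of sample_data (N_psus = 500, N_ssus = 17); the helper `inclusion_prob()` is hypothetical, and the design object’s $prob element used later to weight the variables should hold these products.

```r
# Stage-specific sample sizes used in the simulation code
n_psus_sampled <- 100  # PSUs sampled without replacement in each stratum
n_ssus_sampled <- 5    # SSUs sampled without replacement in each sampled PSU

# Hypothetical helper: inclusion probability of an element as the product of
# the two sampling fractions (every element of a sampled SSU is kept)
inclusion_prob <- function(N_psus, N_ssus) {
  (n_psus_sampled / N_psus) * (n_ssus_sampled / N_ssus)
}

# Population sizes from the first row of sample_data
p <- inclusion_prob(N_psus = 500, N_ssus = 17)
w <- 1 / p
w  # 17, i.e. (500 / 100) * (17 / 5)
```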

R code used to generate the example data:
# Create an example survey design ----

  set.seed(1999)

  library(survey)
  options("survey.ultimate.cluster" = FALSE)
  data(api, package = 'survey')

##_ Create a fake population to sample from
  population <- MASS::mvrnorm(n = 100000,
                              mu = colMeans(apipop[,c('api00', 'api99')]),
                              Sigma = cov(apipop[,c('api00', 'api99')])) |>
    apply(MARGIN = 2, FUN = round) |> `colnames<-`(c("api00", "api99")) |>
    as.data.frame()

  population <- cbind(population,
                      'stratum' = sample(x = c(1:10),
                                         size = nrow(population),
                                         replace = TRUE),
                      'psu_id' = sample(x = c(1:500),
                                        size = nrow(population),
                                        replace = TRUE),
                      'ssu_id' = sample(x = c(1:50),
                                        size = nrow(population),
                                        replace = TRUE))
  rownames(population) <- 1:nrow(population)

##_ Add columns giving population sizes needed for FPCs
  population <- aggregate(x = population$psu_id,
                          by = population[,'stratum', drop = FALSE],
                          FUN = function(x) length(unique(x))) |>
    setNames(c("stratum", 'N_psus')) |>
    merge(x = population,
          by = c("stratum"))

  population <- aggregate(x = population$ssu_id,
                          by = population[,c('stratum', 'psu_id'), drop = FALSE],
                          FUN = function(x) length(unique(x))) |>
    setNames(c("stratum", 'psu_id' , 'N_ssus')) |>
    merge(x = population,
          by = c("stratum", "psu_id"))
  
    population <- population[,c('stratum', 'psu_id', 'ssu_id',
                                'N_psus', 'N_ssus',
                                'api00', 'api99')]

##_ Draw stratified multistage sample
  population$is_sampled <- FALSE

  for (stratum_h in unique(population$stratum)) {
    stratum_psus <- population |>
      subset(stratum == stratum_h) |>
      getElement("psu_id") |>
      unique()

    sampled_psus <- sample(stratum_psus, size = 100, replace = FALSE)

    for (h_psu in sampled_psus) {

      ssus_in_psu_of_stratum_h <- population |>
        subset(stratum == stratum_h & psu_id == h_psu) |>
        getElement("ssu_id") |>
        unique()

      sampled_ssus <- sample(ssus_in_psu_of_stratum_h, size = 5, replace = FALSE)

      sample_indices <- which(population$stratum == stratum_h &
                              population$psu_id == h_psu &
                              population$ssu_id %in% sampled_ssus)

      population[['is_sampled']][sample_indices] <- TRUE

    }
  }

  sample_data <- population[population$is_sampled,]
  sample_data[['is_sampled']] <- NULL

Based on this sample data, we’ll create a survey design object.

head(sample_data)
    stratum psu_id ssu_id N_psus N_ssus api00 api99
321       1    114     29    500     17   567   502
322       1    114     17    500     17   828   766
326       1    114     26    500     17   911   910
327       1    114     48    500     17   620   601
334       1    114     17    500     17   686   649
340       1    114     33    500     17   541   528
# Declare survey design ----
  library(survey)
  
  multistage_design <- svydesign(strata = ~ stratum,
                                 id = ~ psu_id + ssu_id,
                                 fpc = ~ N_psus + N_ssus,
                                 nest = TRUE,
                                 data = sample_data)

  print(multistage_design)
Stratified 2 - level Cluster Sampling design
With (1000, 5000) clusters.
svydesign(strata = ~stratum, id = ~psu_id + ssu_id, fpc = ~N_psus + 
    N_ssus, nest = TRUE, data = sample_data)
# Tell R to estimate variances from multiple stages of sampling
# and not just the first
  options("survey.ultimate.cluster" = FALSE)

Now suppose we want to estimate the totals for the variables api00 and api99 as well as their sampling variances.

estimated_totals <- svytotal(~ api00 + api99,
                             design = multistage_design)
vcov(estimated_totals)
             api00        api99
api00 288436238871 275908990143
api99 275908990143 265410510692
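The standard errors that svytotal() prints are the square roots of the diagonal of this variance-covariance matrix. As a quick check using the numbers printed above, they round to the SE column shown for svytotal() later in this post (537063 and 515180):

```r
# Variance-covariance matrix of the two estimated totals, copied from the output
V <- matrix(c(288436238871, 275908990143,
              275908990143, 265410510692),
            nrow = 2,
            dimnames = list(c("api00", "api99"), c("api00", "api99")))

# Standard errors are the square roots of the diagonal elements
se <- sqrt(diag(V))
round(se)  # api00: 537063, api99: 515180
```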

This takes around three seconds to run on my machine.

system.time(
  svytotal(~ api00 + api99,
           design = multistage_design)
)
   user  system elapsed 
   1.82    0.06    2.80 

Now let’s look at the variance estimation being done here by survey:::multistage(). To use survey:::multistage(), we first have to extract the necessary matrices of inputs from the survey design object.

# Extract inputs ----
  Y <- as.matrix(multistage_design$variables[,c('api00', 'api99')])

  ##_ Weight the variables of interest
  Y_wtd <- Y |> apply(MARGIN = 2,
                      function(x) x/multistage_design$prob)
  
  ##_ Strata, clusters, and population size information
  ##_ need to be represented as numeric matrices
  strata <- lapply(multistage_design$strata, FUN = as.numeric) |>
    do.call(what = cbind)
  
  clusters <- multistage_design$cluster |>
    lapply(as.numeric) |> Reduce(f = cbind)
  
  strata_samp_sizes <- as.matrix(multistage_design$fpc$sampsize)
  strata_pop_sizes <- as.matrix(multistage_design$fpc$popsize)
survey:::multistage(x = Y_wtd,
                    clusters = clusters,
                    stratas = strata,
                    nPSUs = strata_samp_sizes,
                    fpcs = strata_pop_sizes,
                    lonely.psu = 'adjust',
                    cal = NULL)
             api00        api99
api00 288436238871 275908990143
api99 275908990143 265410510692

For comparison, here’s the {Rcpp}-based function, named arma_multistage() due to its heavy use of {RcppArmadillo}.

arma_multistage(Y = Y_wtd,
                samp_unit_ids = clusters,
                strata_ids = strata,
                strata_samp_sizes = strata_samp_sizes,
                strata_pop_sizes = strata_pop_sizes,
                singleton_method = 'adjust',
                use_only_first_stage = FALSE)
             [,1]         [,2]
[1,] 288436238871 275908990143
[2,] 275908990143 265410510692

Let’s compare the run-time of the two functions.

  microbenchmark::microbenchmark(
    'arma_multistage' = arma_multistage(Y = Y_wtd,
                                        samp_unit_ids = clusters,
                                        strata_ids = strata,
                                        strata_samp_sizes = strata_samp_sizes,
                                        strata_pop_sizes = strata_pop_sizes,
                                        singleton_method = 'adjust',
                                        use_only_first_stage = FALSE),
    'survey:::multistage' =   survey:::multistage(x = Y_wtd,
                                                  clusters = clusters,
                                                  stratas = strata,
                                                  nPSUs = strata_samp_sizes,
                                                  fpcs = strata_pop_sizes,
                                                  lonely.psu = 'adjust',
                                                  one.stage = FALSE,
                                                  cal = NULL)
  )
Unit: milliseconds
                expr      min       lq       mean   median       uq      max
     arma_multistage   3.3362   3.4556   3.642773   3.5289   3.6835   5.5544
 survey:::multistage 382.9088 402.4978 417.453081 414.1363 429.8123 582.9900
 neval cld
   100  a 
   100   b

From this output, we can see that arma_multistage() runs more than a hundred times faster than survey:::multistage() (about 117 times faster, comparing median run times).

The payoff

Let’s see what kind of impact this would have on svytotal(), one of the most commonly used functions in the {survey} package.

First, we’ll update the svyrecvar() function to make use of the arma_multistage() function instead of survey:::multistage(). To do this, I created a copy of the {survey} package, renamed {fastsurvey}, and set it up to use the C++ based version of multistage() inside of svyrecvar().

# Install the R package from GitHub
remotes::install_github("bschneidr/fastsurvey")

Let’s make sure we’re getting the same results. If the results differ by more than a tiny numerical tolerance, testthat::expect_equal() will throw an error; otherwise, it will be silent.

survey::svytotal(~ api00 + api99,
                 multistage_design)
         total     SE
api00 65801411 537063
api99 62512510 515180
fastsurvey::svytotal(~ api00 + api99,
                     multistage_design)
         total     SE
api00 65801411 537063
api99 62512510 515180
# An error will be thrown if these aren't equal
testthat::expect_equal(
  object = survey::svytotal(~ api00 + api99,
                            multistage_design),
  expected = fastsurvey::svytotal(~ api00 + api99,
                                  multistage_design)
)
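One caveat: expect_equal() does not demand bit-for-bit equality. Like base R’s all.equal(), it compares numbers up to a small tolerance (about 1.5e-8 by default), which is usually what we want when comparing C++ and R arithmetic; testthat::expect_identical() is the stricter alternative. The difference is easy to see with base R alone:

```r
x <- 1
y <- 1 + 1e-10

isTRUE(all.equal(x, y))      # TRUE: the difference is below the default tolerance
identical(x, y)              # FALSE: the two doubles are not bit-for-bit equal
isTRUE(all.equal(x, 1.001))  # FALSE: a relative difference of 0.001 is flagged
```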

And here’s our payoff: the version of svytotal() that makes use of {Rcpp} is about 500 times faster here, comparing median run times! All we had to do was speed up multistage() using {Rcpp}.

microbenchmark::microbenchmark(
  'survey::svytotal' = survey::svytotal(~ api00 + api99,
                                        multistage_design),
  'fastsurvey::svytotal' = fastsurvey::svytotal(~ api00 + api99,
                                                multistage_design)
)
Unit: milliseconds
                 expr       min        lq        mean    median        uq
     survey::svytotal 2608.1698 2663.8019 2729.233917 2699.4612 2800.3312
 fastsurvey::svytotal    4.8675    5.1495    5.679916    5.3138    5.8946
       max neval cld
 2979.5283   100   b
   14.3231   100  a 
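The speed-up factors quoted in this post are simple ratios of benchmark timings. As a quick check, here is the arithmetic using the median run times (in milliseconds) copied from the two microbenchmark outputs above; the helper `speedup()` is just for illustration.

```r
# Median run times in milliseconds, copied from the benchmark outputs
multistage_ms <- c(survey = 414.1363, arma = 3.5289)
svytotal_ms   <- c(survey = 2699.4612, fastsurvey = 5.3138)

# Ratio of the slower time to the faster time
speedup <- function(times) unname(times[1] / times[2])

round(speedup(multistage_ms))  # 117
round(speedup(svytotal_ms))    # 508
```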

Incorporating Rcpp into the {survey} package (or an add-on package)

The {fastsurvey} fork of the {survey} package illustrates the minimal steps needed for the {survey} package to use {Rcpp} and a faster version of multistage(), named multistage_rcpp(). Before incorporating this into any version of the {survey} package that would appear on CRAN, it would be necessary to check thoroughly that the package gives correct results with the update across a range of designs. This would be easiest to carry out using automated unit tests such as those facilitated by the {testthat} package, which in general would be a helpful addition to the package. As for statistical details, the C++ functions would need some further small changes to handle G-calibration using stage-specific cluster information.
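To illustrate what such automated checks could look like, here is a minimal base-R sketch that generates many random stratified cluster samples and confirms that two independent implementations of the one-stage variance formula agree. Both functions (`var_loop()` and `var_cov()`) are hypothetical stand-ins written for this example; a real test suite would instead compare survey:::multistage() with its C++ counterpart, typically inside testthat::test_that() using expect_equal().

```r
# Two hypothetical implementations of the same one-stage variance formula:
# sum over strata of (1 - n_h/N_h) * n_h/(n_h - 1) * crossprod(centered cluster totals)
var_loop <- function(x, strata, clusters, N_pop) {
  v <- matrix(0, ncol(x), ncol(x))
  for (h in unique(strata)) {
    in_h <- strata == h
    z <- rowsum(x[in_h, , drop = FALSE], clusters[in_h])
    n_h <- nrow(z)
    z_c <- sweep(z, 2, colMeans(z))
    v <- v + (1 - n_h / N_pop[in_h][1]) * n_h / (n_h - 1) * crossprod(z_c)
  }
  v
}

# Same formula via cov(): n/(n - 1) * crossprod(centered z) equals n * cov(z)
var_cov <- function(x, strata, clusters, N_pop) {
  parts <- lapply(split(seq_len(nrow(x)), strata), function(idx) {
    z <- rowsum(x[idx, , drop = FALSE], clusters[idx])
    (1 - nrow(z) / N_pop[idx][1]) * nrow(z) * cov(z)
  })
  Reduce(`+`, parts)
}

# Compare the two implementations on many randomly generated designs
set.seed(2021)
n_checked <- 0
for (trial in 1:100) {
  n_strata <- sample(1:4, 1)
  n_psus   <- sample(2:6, 1)                    # sampled PSUs per stratum
  strata   <- rep(seq_len(n_strata), each = 3 * n_psus)
  clusters <- rep(rep(seq_len(n_psus), each = 3), times = n_strata)
  N_pop    <- sample(n_psus:50, n_strata, replace = TRUE)[strata]
  x        <- matrix(rnorm(2 * length(strata)), ncol = 2)
  stopifnot(isTRUE(all.equal(var_loop(x, strata, clusters, N_pop),
                             var_cov(x, strata, clusters, N_pop),
                             check.attributes = FALSE)))
  n_checked <- n_checked + 1
}
n_checked  # 100 designs checked without error
```

Randomizing the number of strata, the number of sampled PSUs, and the FPCs exercises more of the code paths than a single hand-picked design would.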

If it’s not a good fit to incorporate the C++ functions into the {survey} package, they could instead be used in another R package that extends {survey}. The downside of that approach is that it would involve lots of duplicated effort, since a new class of design object would have to be created (e.g. "survey.design2.rcpp") and new methods for all the functions svytotal(), svymean(), etc. would have to be implemented to work with this new class. One upside of a new package, though, is that it could have multiple maintainers (Dr. Lumley, the {srvyr} team, etc.) and would have more flexibility to use package infrastructure tools that have become available in recent years (e.g. pkgdown or testthat). Plus, it would offer another way for R users and organizations to use the {survey} package in production code with stable, thoroughly unit-tested functions, without impeding Dr. Lumley’s ability to regularly add helpful statistical innovations to the {survey} package (e.g. his recent work on PPS or crossed designs).

Footnotes

  1. Concerns about the package’s reliability are the biggest obstacle to using the package in production, particularly in the official statistics setting, where bugs might have meaningful ramifications for public policy decisions that affect people’s lives. The fact that the {survey} package is a large, complex package maintained entirely by one person with infrequent financial support means that there is a shortage of unit testing, and it can take weeks or months for serious bugs to be fixed. In contrast, funded teams of software developers, or collaborative projects such as {srvyr}, can more quickly address bugs and have more capacity for the tedious work of writing unit tests.