Illustration of Adaptive Shrinkage

The goal here is to illustrate the “adaptive” nature of the adaptive shrinkage. The shrinkage is adaptive in two senses. First, the amount of shrinkage depends on the distribution g of the true effects, which is learned from the data: when g is very peaked about zero then ash learns this and deduces that signals should be more strongly shrunk towards zero than when g is less peaked about zero. Second, the amount of shrinkage of each observation depends on its standard error: the smaller the standard error, the more informative the data, and so the less shrinkage that occurs. From an Empirical Bayesian perspective both of these points are entirely natural: the posterior depends on both the prior and the likelihood; the prior, g, is learned from the data, and the likelihood incorporates the standard error of each observation.
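Both senses can be seen in the simplest normal-means calculation: with prior g = N(0, sigma^2) and likelihood betahat ~ N(beta, s^2), the posterior mean is betahat times the shrinkage factor sigma^2/(sigma^2 + s^2), which is smaller (more shrinkage) when g is peaked about zero (small sigma) and when the observation is noisy (large s). Here is a minimal sketch of this factor, with purely illustrative values of sigma and s:

# Shrinkage factor for the posterior mean under a N(0,sigma^2) prior
# and a N(beta,s^2) likelihood.
shrink_factor = function (sigma, s) sigma^2/(sigma^2 + s^2)
shrink_factor(sigma = 0.25,s = 1) # peaked prior: strong shrinkage
shrink_factor(sigma = 4,s = 1)    # diffuse prior: little shrinkage
shrink_factor(sigma = 1,s = 0.25) # precise observation: little shrinkage
shrink_factor(sigma = 1,s = 4)    # noisy observation: strong shrinkage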

First, we load the necessary libraries.

library(ashr)
library(ggplot2)

We simulate from two scenarios: in the first, the effects are more peaked about zero (sim.spiky); in the second, they are less peaked about zero (sim.bignormal). A summary of the two data sets is printed at the end of this chunk.

set.seed(100)

# Simulates a data set of true effects (beta) and noisy estimates
# (betahat) from a mixture-of-normals prior g, for the experiments below.
rnormmix_datamaker = function (args) {
  
  # Generate the proportion of true nulls randomly.
  pi0 = runif(1,args$min_pi0,args$max_pi0) 
  k   = ncomp(args$g)
  
  # Randomly draw a mixture component for each sample.
  comp   = sample(1:k,args$nsamp,replace = TRUE,prob = mixprop(args$g))
  isnull = (runif(args$nsamp,0,1) < pi0)
  beta   = ifelse(isnull,0,rnorm(args$nsamp,comp_mean(args$g)[comp],
                                 comp_sd(args$g)[comp]))

  # Add normal noise with the supplied standard errors.
  sebetahat = args$betahatsd
  betahat   = beta + rnorm(args$nsamp,0,sebetahat)
  meta      = list(g1 = args$g,beta = beta,pi0 = pi0)
  input     = list(betahat = betahat,sebetahat = sebetahat,df = NULL)
  return(list(meta = meta,input = input))
}

NSAMP = 1000
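# Standard errors vary across observations (drawn from an inverse-gamma).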
s     = 1/rgamma(NSAMP,5,5)

sim.spiky =
  rnormmix_datamaker(args = list(g         = normalmix(c(0.4,0.2,0.2,0.2),
                                                       c(0,0,0,0),
                                                       c(0.25,0.5,1,2)),
                                 min_pi0   = 0,
                                 max_pi0   = 0,
                                 nsamp     = NSAMP,
                                 betahatsd = s))

sim.bignormal =
  rnormmix_datamaker(args = list(g         = normalmix(1,0,4),
                                 min_pi0   = 0,
                                 max_pi0   = 0,
                                 nsamp     = NSAMP,
                                 betahatsd = s))

cat("Summary of observed beta-hats:\n")
## Summary of observed beta-hats:
print(rbind(spiky     = quantile(sim.spiky$input$betahat,seq(0,1,0.1)),
            bignormal = quantile(sim.bignormal$input$betahat,seq(0,1,0.1))),
      digits = 3)
##               0%   10%   20%    30%    40%     50%   60%   70%  80%  90% 100%
## spiky      -6.86 -1.99 -1.19 -0.717 -0.326  0.0380 0.385 0.728 1.23 2.04 15.2
## bignormal -14.26 -5.03 -3.48 -1.989 -0.984 -0.0941 1.026 2.141 3.63 5.61 13.4

Now we run ash on both data sets.

beta.spiky.ash     = ash(sim.spiky$input$betahat,s)
beta.bignormal.ash = ash(sim.bignormal$input$betahat,s)
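Because g is estimated from the data, we can inspect what each fit learned; get_fitted_g is the ashr accessor for the fitted mixture. The fitted g for the spiky data should be visibly more concentrated near zero:

# Inspect the fitted unimodal mixtures; the spiky fit should place its
# mass on components concentrated near zero.
print(get_fitted_g(beta.spiky.ash))
print(get_fitted_g(beta.bignormal.ash))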

Next we plot the shrunken estimates against the observed values, colored according to the (square root of) precision: precise estimates are colored red, and less precise estimates blue. Two key features of the plots illustrate the ideas of adaptive shrinkage: (i) the estimates under the spiky scenario are shrunk more strongly, illustrating that shrinkage adapts to the underlying distribution of beta; (ii) in both cases, estimates with large standard error (blue) are shrunk more than estimates with small standard error (red), illustrating that shrinkage adapts to measurement precision.

# Combine the simulated data and the ash results for the two scenarios
# into a single data frame for plotting.
make_df_for_ashplot =
  function (sim1, sim2, ash1, ash2, name1 = "spiky", name2 = "big-normal") {
    n = length(sim1$input$betahat)
    x = c(get_lfsr(ash1),get_lfsr(ash2))
    return(data.frame(betahat  = c(sim1$input$betahat,sim2$input$betahat),
                      beta_est = c(get_pm(ash1),get_pm(ash2)),
                      lfsr     = x,
                      s        = c(sim1$input$sebetahat,sim2$input$sebetahat),
                      scenario = c(rep(name1,n),rep(name2,n)),
                      signif   = x < 0.05))
  }

ashplot = function (df,xlab = "Observed beta-hat",ylab = "Shrunken beta estimate")
  ggplot(df,aes(x = betahat,y = beta_est,color = 1/s)) +
    xlab(xlab) + ylab(ylab) + geom_point() +
    facet_grid(.~scenario) +
    geom_abline(intercept = 0,slope = 1,linetype = "dotted") +
    scale_colour_gradient2(midpoint = median(1/df$s),low = "blue",
                           mid = "white",high = "red",space = "Lab") +
    coord_fixed(ratio = 1)

df = make_df_for_ashplot(sim.spiky,sim.bignormal,beta.spiky.ash,
                         beta.bignormal.ash)
print(ashplot(df))
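Both features can also be checked numerically. A quick summary, splitting each scenario at the median standard error: the mean absolute shrinkage should be larger in the spiky scenario and larger for the noisier observations.

# Mean absolute shrinkage |betahat - posterior mean|, by scenario and by
# whether the standard error is above its median.
print(with(df,tapply(abs(betahat - beta_est),
                     list(scenario = scenario,noisy = s > median(s)),
                     mean)),
      digits = 3)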

A related consequence is that the significance of each observation, as measured by its lfsr, is no longer monotonic in its p value.

pval_plot = function (df)
  ggplot(df,aes(x = pnorm(-abs(betahat/s)),y = lfsr,color = log(s))) +
  geom_point() + facet_grid(.~scenario) + xlim(c(0,0.025)) +
  xlab("p value") + ylab("lfsr") +
  scale_colour_gradient2(midpoint = 0,low = "red",
                         mid = "white",high = "blue")

print(pval_plot(df))
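To confirm the non-monotonicity numerically, we can rank-correlate the p values with the lfsr within a scenario; a Spearman correlation below 1 means the lfsr is not a monotonic function of the p value. A minimal check using the df built above:

# Spearman rank correlation between p value and lfsr in the spiky
# scenario; a value below 1 confirms the two orderings differ.
print(with(subset(df,scenario == "spiky"),
           cor(pnorm(-abs(betahat/s)),lfsr,method = "spearman")))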

Let’s see how these results change if we alter the modelling assumptions so that the standardized effects, beta/s, are exchangeable (rather than the effects beta themselves).

beta.bignormal.ash.ET =
  ash(sim.bignormal$input$betahat,s,alpha = 1,mixcompdist = "normal")
beta.spiky.ash.ET =
  ash(sim.spiky$input$betahat,s,alpha = 1,mixcompdist = "normal")
df.ET = make_df_for_ashplot(sim.spiky,sim.bignormal,beta.spiky.ash.ET,
                            beta.bignormal.ash.ET)
print(ashplot(df.ET,ylab = "Shrunken beta estimate (ET model)"))

print(pval_plot(df.ET))
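To compare the two sets of assumptions directly, we can cross-tabulate the significance calls (lfsr < 0.05); the rows of df and df.ET correspond to the same observations by construction.

# Cross-tabulate EE-model and ET-model significance calls per scenario.
print(table(EE = df$signif,ET = df.ET$signif,scenario = df$scenario))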

This is a “volcano plot” showing effect size against p value. The blue points are “significant” in that they have lfsr < 0.05.

print(ggplot(df,aes(x = betahat,y = -log10(2*pnorm(-abs(betahat/s))),
                    col = signif)) +
  geom_point(alpha = 1,size = 1.75) + facet_grid(.~scenario) +
  theme(legend.position = "none") + xlim(c(-10,10)) + ylim(c(0,15)) +
  xlab("Effect (beta)") + ylab("-log10 p-value"))

In this case, calling significance at lfsr < 0.05 is not quite the same as cutting off at a given p value (the decision boundary is not exactly a horizontal line), but it is not very different either, presumably because the standard errors, although they vary across observations, do not vary greatly.
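One way to quantify the departure from a fixed p-value cutoff: within each scenario, compare the largest p value among significant points to the smallest p value among non-significant ones; any overlap means no horizontal line reproduces the lfsr calls. A rough check:

# Largest p value called significant vs smallest p value not called
# significant, per scenario; overlap implies no single p-value cutoff.
pval = with(df,2*pnorm(-abs(betahat/s)))
print(rbind(max_p_signif = tapply(pval[df$signif],df$scenario[df$signif],max),
            min_p_nonsig = tapply(pval[!df$signif],df$scenario[!df$signif],min)),
      digits = 3)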

Session information.

print(sessionInfo())
## R version 4.4.2 (2024-10-31)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.1 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/openblas-pthread/libblas.so.3 
## LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.26.so;  LAPACK version 3.12.0
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=C              
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## time zone: Etc/UTC
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] ggplot2_3.5.1 ashr_2.2-66  
## 
## loaded via a namespace (and not attached):
##  [1] Matrix_1.7-1      gtable_0.3.6      jsonlite_1.8.9    compiler_4.4.2   
##  [5] Rcpp_1.0.13-1     jquerylib_0.1.4   scales_1.3.0      yaml_2.3.10      
##  [9] fastmap_1.2.0     lattice_0.22-6    R6_2.5.1          labeling_0.4.3   
## [13] mixsqp_0.3-54     knitr_1.49        tibble_3.2.1      maketools_1.3.1  
## [17] munsell_0.5.1     bslib_0.8.0       pillar_1.9.0      rlang_1.1.4      
## [21] utf8_1.2.4        cachem_1.1.0      SQUAREM_2021.1    xfun_0.49        
## [25] sass_0.4.9        sys_3.4.3         truncnorm_1.0-9   invgamma_1.1     
## [29] cli_3.6.3         withr_3.0.2       magrittr_2.0.3    digest_0.6.37    
## [33] grid_4.4.2        irlba_2.3.5.1     lifecycle_1.0.4   vctrs_0.6.5      
## [37] evaluate_1.0.1    glue_1.8.0        farver_2.1.2      buildtools_1.0.0 
## [41] fansi_1.0.6       colorspace_2.1-1  rmarkdown_2.29    pkgconfig_2.0.3  
## [45] tools_4.4.2       htmltools_0.5.8.1