Source code for ax.models.discrete.ashr_utils

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

import warnings

import numpy as np
import numpy.typing as npt
from ax.exceptions.core import UserInputError
from scipy.stats import norm


class Ashr:
    r"""
    An empirical Bayes model for estimating effect sizes. Given the observations
    Y_i and their variances Yvar_i, Ashr estimates the underlying effect sizes
    mu_i by placing a mixture-of-Gaussians prior on the effect sizes. The prior
    consists of a point mass at zero and a set of centered Gaussians with
    specified variances. The mixture proportions in the prior, as well as the
    variances of the Gaussians in the mixture, are learned from the observed
    outcome data.

    Methodology is based on the paper "False discovery rates: a new deal" by
    M. Stephens, https://academic.oup.com/biostatistics/article/18/2/275/2557030.
    """

    def __init__(
        self,
        Y: npt.NDArray,
        Yvar: npt.NDArray,
        prior_vars: npt.NDArray | None = None,
        eb_grid_param: float = 2.0,
    ) -> None:
        r"""
        Args:
            Y: A length n array denoting the observed treatment effects.
            Yvar: A length n array denoting the variances of the observed values.
            prior_vars: A length k array denoting the variances of the normal
                distributions in the mixture-of-Gaussians prior. If None, the
                variances are estimated from the data using the provided grid
                parameter.
            eb_grid_param: A grid parameter for estimating the prior variances
                from the data in case none were given.
        """
        self.Y: npt.NDArray = Y
        self.Yvar: npt.NDArray = Yvar
        if prior_vars is None:
            prior_stds = prior_grid(Y=Y, Yvar=Yvar, grid_param=eb_grid_param)
            prior_vars = prior_stds**2
        self.prior_vars: npt.NDArray = prior_vars
        self.ll: npt.NDArray = marginal_densities(
            Y=Y, Yvar=Yvar, prior_vars=prior_vars
        )

    def posterior(self, w: npt.NDArray) -> GaussianMixture:
        r"""
        The posterior for mu_i can be calculated via the following rules. For the
        normal prior mu_i ~ N(0, sigma_k^2) and likelihood hat{mu}_i ~ N(mu_i, s_i^2),
        the posterior for mu_i is N(sigma_1^2 * hat{mu}_i / s_i^2, sigma_1^2), where
        sigma_1^2 = sigma_k^2 * s_i^2 / (sigma_k^2 + s_i^2).

        Args:
            w: (n, k)-dim matrix containing the weights of the posterior mixture.

        Returns:
            Mixture of Gaussians distribution.
        """
        # (n, k) matrix of normal variances:
        # one variance per mixture component and per observation
        normal_vars_by_class = np.divide(
            np.multiply.outer(self.Yvar, self.prior_vars),
            np.add.outer(self.Yvar, self.prior_vars),
        )
        # (n, k) matrix of normal means:
        # one mean per mixture component and per observation
        normal_means_by_class = normal_vars_by_class * (self.Y / self.Yvar)[:, None]
        return GaussianMixture(
            normal_means_by_class=normal_means_by_class,
            normal_vars_by_class=normal_vars_by_class,
            weights=w,
        )

    def fit(
        self,
        lambdas: npt.NDArray | None = None,
        threshold: float = 10e-4,
        nsteps: int = 1000,
    ) -> dict[str, npt.NDArray]:
        r"""
        Fit Ashr to estimate the prior proportions pi, the posterior weights,
        and the lfdr.

        Args:
            lambdas: A length k array of penalty levels corresponding to each of the
                k classes, with an entry equal to one meaning no penalty for the
                corresponding class.
            threshold: The threshold used in the EM stopping rule. If the difference
                between two consecutive estimates of the mixture proportions in the
                EM algorithm is smaller than the threshold, the algorithm stops.
            nsteps: The maximum number of steps in the EM algorithm.

        Returns:
            A dict containing
                - weights: (n, k)-dim estimated weight matrix for computing posteriors,
                - pi: length k vector of estimated prior mixture proportions, and
                - lfdr: length n local False Discovery Rate (FDR) sequence. lfdr equals
                  the posterior probability of the true parameter being zero.
        """
        k = len(self.prior_vars)  # total number of classes
        if lambdas is None:
            lambdas = np.ones(k)  # no penalty
        if len(lambdas) != k:
            raise ValueError(
                "The length of the penalty sequence should be the number of "
                "prior classes."
            )
        lambdas = lambdas.astype(np.float64)
        results = fit_ashr_em(
            ll=self.ll, lambdas=lambdas, threshold=threshold, nsteps=nsteps
        )
        return results
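
A minimal usage sketch for the Ashr class, with purely illustrative values for Y and Yvar:

import numpy as np

Y = np.array([0.5, -0.1, 1.2])        # hypothetical observed treatment effects
Yvar = np.array([0.09, 0.04, 0.25])   # hypothetical observation variances

model = Ashr(Y=Y, Yvar=Yvar)
results = model.fit()                              # EM estimates of pi, weights, and lfdr
posterior = model.posterior(w=results["weights"])
posterior.means                                    # empirical-Bayes shrunk effect estimates
results["lfdr"]                                    # posterior probability that each effect is null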

class GaussianMixture:
    r"""
    A weighted mixture of Gaussian distributions. The class computes the means,
    variances, and tail probabilities of each of n random variables coming from a
    mixture of k Gaussians. The Gaussians in the mixture are allowed to be
    degenerate, i.e., have zero variance. This is used in the Ashr model since one
    of the distributions in the prior mixture is a point mass at zero.
    """

    def __init__(
        self,
        normal_means_by_class: npt.NDArray,
        normal_vars_by_class: npt.NDArray,
        weights: npt.NDArray,
    ) -> None:
        r"""
        Args:
            normal_means_by_class: (n, k)-dim matrix of normal means for each of the
                k classes. Each of the n random variables comes from a mixture of k
                normal distributions.
            normal_vars_by_class: (n, k)-dim matrix of normal variances for each of
                the k classes. The first column is all zeros, the variance of the
                null class.
            weights: (n, k)-dim matrix of weights on the individual distributions
                per prior class.
        """
        self.normal_means_by_class = normal_means_by_class
        self.normal_vars_by_class = normal_vars_by_class
        self.weights = weights

    def tail_probabilities(
        self, left_tail: bool = True, threshold: float = 0.0
    ) -> npt.NDArray:
        r"""
        Args:
            left_tail: An indicator for the tail probability to calculate. Note that
                neither tail includes the null class.
            threshold: For the left tail, the returned value measures the probability
                of the effect being less than the threshold. For the right tail, it
                is the probability of the effect being larger than the threshold.

        Returns:
            Length n array of tail probabilities for each of the n rvs.
        """
        tails_by_class = np.zeros_like(self.normal_means_by_class)
        # normal left tails
        for i in range(tails_by_class.shape[0]):
            for j in range(tails_by_class.shape[1]):
                tails_by_class[i, j] = (
                    norm.cdf(
                        -np.divide(
                            self.normal_means_by_class[i, j] - threshold,
                            np.sqrt(self.normal_vars_by_class[i, j]),
                        )
                    )
                    if self.normal_vars_by_class[i, j] > 0
                    else 0.0
                )
                # correcting the normal tails in case of a right tail
                if left_tail is False and self.normal_vars_by_class[i, j] > 0:
                    tails_by_class[i, j] = 1.0 - tails_by_class[i, j]
        return np.multiply(tails_by_class, self.weights).sum(axis=1)

    @property
    def means(self) -> npt.NDArray:
        r"""
        Returns:
            Length n array of final means for each rv.
        """
        return np.multiply(self.weights, self.normal_means_by_class).sum(axis=1)

    @property
    def vars(self) -> npt.NDArray:
        r"""
        Returns:
            Length n array of final variances for each effect.
        """
        # variances of the mixture distributions
        # https://en.wikipedia.org/wiki/Mixture_distribution#Moments
        return (
            np.multiply(
                self.weights,
                self.normal_means_by_class**2 + self.normal_vars_by_class,
            ).sum(axis=1)
            - self.means**2
        )
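
A small sketch of the mixture moments for a single rv drawn from a two-component mixture (a point mass at zero plus one Gaussian), with illustrative numbers:

import numpy as np

gm = GaussianMixture(
    normal_means_by_class=np.array([[0.0, 1.0]]),
    normal_vars_by_class=np.array([[0.0, 0.25]]),
    weights=np.array([[0.3, 0.7]]),
)
gm.means                                 # array([0.7]) = 0.3 * 0.0 + 0.7 * 1.0
gm.vars                                  # array([0.385]) = 0.7 * (1.0 + 0.25) - 0.7 ** 2
gm.tail_probabilities(left_tail=False)   # P(effect > 0); the null mass contributes 0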

def prior_grid(
    Y: npt.NDArray, Yvar: npt.NDArray, grid_param: float = 2.0
) -> npt.NDArray:
    r"""
    Produces the grid of standard deviations for the Gaussians in the prior
    mixture based on the observed data.

    Args:
        Y: A length n array of the observed treatment effects.
        Yvar: A length n array of observed variances of the above observations.
        grid_param: A grid parameter. The default of 2.0 is recommended in the
            paper to control the number of Gaussians in the mixture.

    Returns:
        An array of the standard deviations of the centered Gaussians in the
        mixture of Gaussians prior, starting with zero for the null class.
    """
    m = np.sqrt(grid_param)
    sigma_lower = np.min(np.sqrt(Yvar)) / 10.0
    sigma_upper = np.max(Y**2 - Yvar)
    sigma_upper = 2 * np.sqrt(sigma_upper) if sigma_upper > 0 else 8 * sigma_lower
    max_power = int(np.ceil(np.log(sigma_upper / sigma_lower) / np.log(m)))
    return np.array([0] + [sigma_lower * (m**power) for power in range(max_power + 1)])
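
An illustrative call with hypothetical inputs: the grid starts at zero for the null class and then grows geometrically by a factor of sqrt(grid_param), from min(sqrt(Yvar)) / 10 up to roughly 2 * sqrt(max(Y**2 - Yvar)).

import numpy as np

Y = np.array([0.5, -0.1, 1.2])
Yvar = np.array([0.09, 0.04, 0.25])
grid = prior_grid(Y=Y, Yvar=Yvar, grid_param=2.0)
# grid[0] == 0.0 (null class); grid[1] == 0.02 == min(sqrt(Yvar)) / 10;
# each subsequent standard deviation is sqrt(2) times the previous one.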

def marginal_densities(
    Y: npt.NDArray,
    Yvar: npt.NDArray,
    prior_vars: npt.NDArray,
) -> npt.NDArray:
    r"""
    Evaluates the marginal density for each observed statistic under each prior class.

    Args:
        Y: A length n array denoting the observed treatment effects.
        Yvar: A length n array denoting the variances of the observed values.
        prior_vars: A length k array denoting the variances of the prior classes.

    Returns:
        An (n, k)-dim matrix ll, where ll_{jk} is the marginal density of the j-th
        statistic evaluated at its observed value, assuming the prior comes from the
        k-th class.
    """
    k = len(prior_vars)  # total number of classes
    n = len(Y)  # total number of observations
    if prior_vars[0] != 0:
        raise UserInputError(
            "Ashr prior consists of a mixture of Gaussians where the "
            "first Gaussian in the prior mixture should be a point mass at zero. "
            "This degenerate Gaussian represents the prior on the effects being null."
        )
    ll = np.zeros((n, k), dtype=np.float64)
    # marginal densities when the prior is a point mass at zero
    ll[:, 0] = norm.pdf(Y, loc=0, scale=np.sqrt(Yvar))
    for i in range(1, k):
        ll[:, i] = norm.pdf(
            Y,
            loc=0.0,
            scale=np.sqrt(prior_vars[i] + Yvar),
        )
    return ll
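
A short sketch with illustrative numbers: under prior class k, the observation Y_i is marginally N(0, prior_vars[k] + Yvar[i]), so the null class reduces to N(0, Yvar[i]).

import numpy as np
from scipy.stats import norm

Y = np.array([0.5, -0.1])
Yvar = np.array([0.09, 0.04])
prior_vars = np.array([0.0, 0.25])
ll = marginal_densities(Y=Y, Yvar=Yvar, prior_vars=prior_vars)
# ll[:, 0] == norm.pdf(Y, scale=np.sqrt(Yvar))          (point mass at zero)
# ll[:, 1] == norm.pdf(Y, scale=np.sqrt(0.25 + Yvar))   (centered Gaussian class)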

def compute_weights(ll: npt.NDArray, pi: npt.NDArray) -> npt.NDArray:
    r"""
    Compute posterior weights based on marginal densities and prior probabilities.

    Args:
        ll: (n, k)-dim matrix of marginal densities evaluated at each observation.
        pi: A length k vector of prior mixture proportions.

    Returns:
        (n, k)-dim matrix of weights.
    """
    # multiply each row of ll with the corresponding element of the pi vector
    w = np.multiply(ll, pi)
    # divide the weights by the sum across each row
    w = w / w.sum(axis=1, keepdims=True)
    return w
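
By Bayes' rule, each weight is the posterior probability that observation i belongs to class j, proportional to pi_j * ll_{ij}. A one-observation sketch with made-up numbers:

import numpy as np

ll = np.array([[0.2, 0.6]])    # densities under the null class and one Gaussian class
pi = np.array([0.5, 0.5])      # prior mixture proportions
compute_weights(ll=ll, pi=pi)  # array([[0.25, 0.75]]); column 0 is the lfdr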

def fit_ashr_em(
    ll: npt.NDArray,
    lambdas: npt.NDArray,
    threshold: float = 10e-4,
    nsteps: int = 1000,
) -> dict[str, npt.NDArray]:
    r"""
    Estimates the mixture proportions and posterior weights via an Expectation
    Maximization (EM) algorithm.

    Args:
        ll: (n, k)-dim matrix of marginal densities evaluated at each observation,
            marginalizing over the true effects.
        lambdas: A length k array of penalty levels.
        threshold: If the difference between two consecutive estimates of the mixture
            proportions is smaller than the threshold, the algorithm stops.
        nsteps: The maximum number of steps in the EM algorithm.

    Returns:
        A dictionary containing:
            - weights: (n, k)-dim matrix of weights,
            - pi: length k vector of estimates of the prior mixture proportions, and
            - lfdr: length n local False Discovery Rate (FDR) sequence. lfdr equals
              the posterior probability of the true parameter being zero.
    """
    n, k = ll.shape
    # initializing the pi vector
    if k - 1 < n:
        pi = np.ones(k) / n
        pi[0] = 1 - (k - 1) / n
    else:
        pi = np.ones(k) / k

    w = np.zeros_like(ll)
    for _ in range(nsteps):
        # E-step: compute the (n, k) weight matrix w
        w = compute_weights(ll=ll, pi=pi)
        # M-step: update pi
        ns = w.sum(axis=0).squeeze() + lambdas - 1.0  # length k
        pi_new = ns / ns.sum()  # length k
        if sum(abs(pi - pi_new)) <= threshold:
            w = compute_weights(ll=ll, pi=pi_new)
            return {"weights": w, "pi": pi_new, "lfdr": w[:, 0]}
        pi = pi_new

    warnings.warn("EM did not converge.", stacklevel=2)
    return {"weights": w, "pi": pi, "lfdr": w[:, 0]}
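
A usage sketch with a hypothetical penalty: in the M-step the null class receives the count sum_i w_i0 + lambda_0 - 1, so setting lambdas[0] > 1 nudges pi[0] upward and makes the fit more conservative, while lambdas of all ones recover the plain maximum-likelihood update.

import numpy as np

Y = np.array([0.5, -0.1, 1.2])
Yvar = np.array([0.09, 0.04, 0.25])
prior_vars = np.array([0.0, 0.25, 1.0])   # the first class is the point mass at zero
ll = marginal_densities(Y=Y, Yvar=Yvar, prior_vars=prior_vars)
lambdas = np.array([10.0, 1.0, 1.0])      # hypothetical penalty favoring the null class
out = fit_ashr_em(ll=ll, lambdas=lambdas)
out["pi"], out["lfdr"]                    # mixture proportions and per-observation lfdr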