How to Select a Univariate Distribution
In this example, we show how to select a distribution for a univariate target variable. We use the California housing dataset and select a distribution for the target variable median_house_value.
Imports
from xgboostlss.distributions import *
from xgboostlss.distributions.distribution_utils import DistributionClass
from sklearn import datasets
from sklearn.model_selection import train_test_split
Data
housing_data = datasets.fetch_california_housing()
X, y = housing_data["data"], housing_data["target"]
feature_names = housing_data["feature_names"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=123)
Select Distribution
In the following, we specify a list of candidate distributions. The function dist_select fits each candidate to the target variable and returns the resulting negative log-likelihoods (NLL), ranked from lowest to highest. The distribution with the lowest negative log-likelihood is selected. With plot=True, the function also plots the density of the target variable together with the fitted density of the best-fitting candidate.
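To make the ranking criterion concrete, the following is a minimal sketch of comparing candidates by negative log-likelihood. It is not how dist_select is implemented: it uses closed-form maximum-likelihood estimates for three families, synthetic data, and only the Python standard library. The helper names gaussian_nll, lognormal_nll, and laplace_nll are hypothetical and not part of xgboostlss.

```python
import math
import random
import statistics


def gaussian_nll(y):
    """NLL of a Gaussian evaluated at its MLE (sample mean, MLE variance)."""
    n = len(y)
    mu = statistics.fmean(y)
    var = sum((v - mu) ** 2 for v in y) / n
    # At the MLE the quadratic term sums to n / 2.
    return 0.5 * n * (math.log(2 * math.pi * var) + 1)


def lognormal_nll(y):
    """NLL of a LogNormal: Gaussian NLL of log(y) plus the Jacobian term sum(log y)."""
    logs = [math.log(v) for v in y]
    return gaussian_nll(logs) + sum(logs)


def laplace_nll(y):
    """NLL of a Laplace evaluated at its MLE (median, mean absolute deviation)."""
    n = len(y)
    med = statistics.median(y)
    b = sum(abs(v - med) for v in y) / n
    return n * (math.log(2 * b) + 1)


# Synthetic, strictly positive, right-skewed sample (an assumption for illustration).
random.seed(123)
y = [random.lognormvariate(0.5, 0.6) for _ in range(2000)]

nll = {
    "Gaussian": gaussian_nll(y),
    "LogNormal": lognormal_nll(y),
    "Laplace": laplace_nll(y),
}

# Rank candidates from lowest (best) to highest NLL.
ranking = sorted(nll.items(), key=lambda kv: kv[1])
for rank, (name, value) in enumerate(ranking, start=1):
    print(rank, name, round(value, 2))
```

Because all candidates are scored on the same sample, their NLL values are directly comparable, and a lower value indicates a better fit to the observed data.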
Note that the candidate distributions should match the support of the target variable. For example, if the target is a count variable, the candidates should include the Poisson and Negative Binomial distributions. Similarly, if the target is on the positive real scale, the candidates should include continuous distributions with positive support, such as Gamma, LogNormal, or Weibull.
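The rule of thumb above could be sketched as a small check on the observed support of the target. The function suggest_candidates is a hypothetical helper, not part of xgboostlss, and it returns plain family names as strings; mapping them to the corresponding distribution classes is left to the reader. It is only a heuristic: continuous data that happen to take integer values would be classified as counts.

```python
def suggest_candidates(y):
    """Hypothetical helper: map the observed support of y to candidate family names."""
    values = list(y)
    if not values:
        raise ValueError("target is empty")

    # Non-negative integers only: treat the target as a count variable.
    if all(v >= 0 and float(v).is_integer() for v in values):
        return ["Poisson", "NegativeBinomial"]

    # Strictly positive values: families with positive support.
    if all(v > 0 for v in values):
        return ["Gamma", "LogNormal", "Weibull"]

    # Otherwise fall back to families supported on the whole real line.
    return ["Gaussian", "StudentT", "Laplace", "Cauchy", "Gumbel"]


print(suggest_candidates([0, 3, 1, 7]))       # count-like target
print(suggest_candidates([0.8, 2.4, 5.1]))    # positive continuous target
print(suggest_candidates([-1.2, 0.3, 4.0]))   # real-valued target
```

A check like this only narrows the list. The final choice should still come from comparing the fitted candidates, since real-line families such as the Gaussian can also be evaluated on a positive target, as the candidate list in this example does.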
xgblss_dist_class = DistributionClass()
candidate_distributions = [Gaussian, StudentT, Gamma, Cauchy, LogNormal, Weibull, Gumbel, Laplace]
dist_nll = xgblss_dist_class.dist_select(target=y_train, candidate_distributions=candidate_distributions, max_iter=50, plot=True, figure_size=(8, 5))
dist_nll
Fitting of candidate distributions completed: 100%|██████████| 8/8 [00:15<00:00, 1.94s/it]
rank | nll | distribution |
---|---|---|
1 | 23596.791908 | LogNormal |
2 | 23632.597656 | Gamma |
3 | 23899.039920 | Gumbel |
4 | 24083.658916 | Weibull |
5 | 25690.867630 | StudentT |
6 | 25796.219456 | Gaussian |
7 | 25925.138312 | Laplace |
8 | 27559.623077 | Cauchy |