Means, Medians, Modes, oh my!

The relationship between measures of central tendency and loss functions in machine learning

Means, medians, and modes provide measures of central tendency. That is, they are single numbers that sit at the center of a collection of values. But how is this "center" defined? This post will provide some intuition into how these measures are related, and how they are intertwined with loss functions in machine learning.

In [1]:
using Plots, LossFunctions, Distributions
plotlyjs();

In [2]:
srand(0xdead);

The Mean and the Squared-Distance Loss (Squared Error)

Perhaps the most familiar measure of central tendency is the arithmetic mean, commonly referred to as the average or simply the mean. The mean is defined as the sum of the observations divided by the number of observations, and it is closely related to the concept of squared error.
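
In symbols, for observations $x_1, \dots, x_n$ the mean is

$$\bar{x} = \frac{1}{n}\sum_{i=1}^{n} x_i,$$

and a one-line calculus argument (anticipating what the plots below show empirically) says it is also the constant $c$ that minimizes the average squared error $\frac{1}{n}\sum_{i=1}^{n}(x_i - c)^2$: setting the derivative $-\frac{2}{n}\sum_{i=1}^{n}(x_i - c)$ to zero gives $c = \bar{x}$.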

In [3]:
#Generate some data from a normal distribution with mean 3 and standard deviation 1
dist = Normal(3,1);
In [4]:
data = rand(dist, 100);

This data has the "bell-curve" shape that we all know and love. Visual inspection suggests that the center of the distribution should be somewhere around 3.

In [5]:
histogram(data)
Out[5]:
In [6]:
#Given a centrality estimate (e.g. the mean), computes the average error between that estimate
#and each data point, as measured by the given loss function.
function average_loss(est::Number, data::Vector{<: Number}, loss::Loss)
    n = length(data)
    #Repeat the estimate so it lines up with the data vector
    ests = fill(est, n)
    value(loss, data, ests, AvgMode.Mean())
end

function plot_average_loss(data::Vector{<:Number}, loss::Loss; minrange=minimum(data), maxrange=maximum(data))
    avgerrorfunc = est -> average_loss(est, data, loss)
    #Central tendency estimator should be between minimum and maximum of data
    p = plot(avgerrorfunc, linspace(minrange, maxrange, 200), xlabel="Estimator Value", ylabel="Average Error")
end

#Also make functions to plot losses

function plot_loss(loss::DistanceLoss; minrange=-2, maxrange=2)
    lossfunc = dist -> value(loss, dist)
    plot(lossfunc, linspace(minrange, maxrange, 200), xlabel="Distance from true value", ylabel="Loss value")
end

function plot_loss(loss::MarginLoss; minrange=-2, maxrange=2)
    lossfunc = agreement -> value(loss, agreement)
    plot(lossfunc, linspace(minrange, maxrange, 200), xlabel="Agreement with true value", ylabel="Loss value")
end
Out[6]:
plot_loss (generic function with 2 methods)

First, let's plot the squared-error loss function.

In [7]:
plot_loss(L2DistLoss())
Out[7]:

Like the loss function itself, the average loss over the dataset looks like a quadratic function.

In [8]:
plot_average_loss(data, L2DistLoss())
Out[8]:

Here we can see that the average error, as we vary the "center", traces out a parabola with its minimum around 3. We defined the normal distribution with mean 3, so let's see what the sample mean actually is.

In [9]:
mean(data)
Out[9]:
3.018440767800988

This is close to 3, but is it the actual minimizer? Let's take a closer look at values around the mean to confirm that the mean of the data minimizes the average squared error to every other point in the data.

In [10]:
#L2DistLoss is just another way to say "Squared Error"
plot_average_loss(data, L2DistLoss(), minrange=mean(data)-0.001, maxrange=mean(data)+0.001)
Out[10]:

The true minimum of the average loss is, in fact, the mean, as shown in this plot of the average loss in a small neighborhood around it.
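
As a quick numerical sanity check, we can compare the average squared loss at the mean with the loss a small step to either side. This is just a sketch reusing the average_loss helper and the data vector defined above; the exact numbers depend on the random seed.

#The middle value (the loss at the mean itself) should be the smallest of the three
m = mean(data)
[average_loss(m + delta, data, L2DistLoss()) for delta in (-0.01, 0.0, 0.01)]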

Outliers, Absolute Value Loss, and the Median

Of course, not all data comes from a perfect normal distribution. Here, we'll see where other loss functions, and hence other measures of central tendency, can be valuable. Outlying values affect different centrality measures differently, and I'll try to show that this is because each measure assumes a different error metric.

In [11]:
data_outliers = rand(dist, 100);

#The first 10 entries are outliers
data_outliers[1:10] = rand(Normal(10, 5), 10);
In [12]:
histogram(data_outliers)
Out[12]:

Now the data still looks roughly like a normal distribution, but there are some clear outliers. The eye can still identify a kind of "center" at 3, but the mean may tell a different story.

In [13]:
function plot_average_loss(data::Vector{<:Number}, losses::Array; minrange=minimum(data), maxrange=maximum(data))
    errfuncs = [(est -> average_loss(est, data, loss)) for loss in losses]
    labels = [string(loss) for loss in losses]
    #Central tendency estimator should be between minimum and maximum of data
    p = plot(errfuncs, linspace(minrange, maxrange, 200), 
        xlabel="Estimator Value", ylabel="Average Error", label=labels)
end

function plot_losses(losses::Array; minrange=-2, maxrange=2)
    lossfuncs = [dist -> value(loss, dist) for loss in losses]
    plot(lossfuncs, linspace(minrange, maxrange, 200), xlabel="Agreement", ylabel="Loss value")
end
Out[13]:
plot_losses (generic function with 1 method)

The absolute value loss function increases linearly on both sides.

In [14]:
plot_loss(L1DistLoss())
Out[14]:

We can plot the average error for the squared-error function (blue) and the absolute-value function (red). It's pretty easy to see that the two are minimized in different places. The value that minimizes the average absolute error is smaller than the one that minimizes the average squared error. In fact, these minimizers are exactly the median and the mean, respectively.

Another thing to notice: the squared error grows much faster than the absolute error as the estimate moves away from the data, so distant outliers contribute disproportionately to the average squared error. This is one explanation of why the mean is more sensitive to outliers than the median.
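
A quick way to see that sensitivity directly is to append a single extreme value and watch how far each statistic moves. This is a small sketch reusing the original data vector from above; the exact numbers depend on the random draw.

data_extreme = [data; 1000.0]                        #one wildly outlying observation
println(mean(data), " -> ", mean(data_extreme))      #the mean jumps by roughly 10
println(median(data), " -> ", median(data_extreme))  #the median barely moves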

In [15]:
p = plot_average_loss(data_outliers, [L2DistLoss(), L1DistLoss()], minrange=0, maxrange=10)
vline!([mean(data_outliers)], label="mean")
vline!([median(data_outliers)], label="median")
Out[15]:

Let's take a closer look at the error curves. First, we can see that the average L2 (squared-error) loss is minimized at the mean, which has been pulled up toward the outliers to about 3.7.

In [16]:
plot_average_loss(data_outliers, L2DistLoss(), 
    minrange=mean(data_outliers)-0.1, maxrange=mean(data_outliers)+0.1)
Out[16]:
In [17]:
mean(data_outliers)
Out[17]:
3.716825605298936

However, the L1 (absolute-value) loss is minimized at the median, much closer to 3.

In [18]:
plot_average_loss(data_outliers, L1DistLoss(), 
    minrange=median(data_outliers)-0.1, maxrange=median(data_outliers)+0.1)
Out[18]:
In [19]:
median(data_outliers)
Out[19]:
3.1703034564759283

Of course, more loss functions exist for regression, such as the epsilon-insensitive loss (used in Support Vector Regression), the Huber loss, and the log-cosh loss (an infinitely differentiable approximation of the Huber loss). While these share the absolute-value loss's robustness to outliers, their average values aren't minimized at exactly the median. Let's take a look... (absolute value in blue, Huber in red, epsilon-insensitive in green, and the logit distance loss, another smooth robust loss, in purple)

In [20]:
plot_average_loss(data_outliers, [L1DistLoss(), HuberLoss(), L1EpsilonInsLoss(0.5), LogitDistLoss()],
    minrange=2, maxrange=5)
vline!([median(data_outliers)], label="median")
Out[20]:

Where are these things relevant?

Statistics other than the mean are frequently used when reporting quantities like income, where a small number of extremely high incomes pulls the mean upwards. Hence, household income is usually reported as a median.

These estimators also represent some of the simplest machine learning models. The mean is equivalent to a least-squares regression model with just an intercept, while the median is the intercept of an L1 (least-absolute-deviations) regression. These are common baselines when fitting models to data, used to check that a model's predictions are better than a constant.
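
To make that concrete, here is a small sketch (reusing the data vector from above) showing that an intercept-only least-squares fit recovers the mean. The L1 analogue, whose intercept is the median, needs an iterative solver and isn't shown.

X = ones(length(data), 1)   #design matrix with only an intercept column
beta = X \ data             #least-squares solution
beta[1], mean(data)         #the fitted intercept equals the sample mean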

But what about other kinds of data?

So far, we've looked at the mean and median as estimates of central tendency in data. These are part of a class of estimators called M-estimators, developed by Huber in 1964 and 1967. An M-estimator is a value that minimizes the average (or sum) of a loss function evaluated between that value and each point in a sample of data. So far, though, the data has been sampled from a continuous distribution, with loss functions that measure the distance between an estimate and a true value. These loss functions are normally used for regression problems, where the data can take on any value in a continuous range (e.g. temperature, blood pressure, power).
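
A crude way to compute an M-estimate numerically is to scan a fine grid of candidate values and keep the one with the smallest average loss. The helper below is only a hypothetical sketch (it reuses the average_loss function defined earlier and does no real optimization), but it makes the definition concrete.

#Hypothetical helper: grid-search M-estimate for a given loss
function m_estimate(data::Vector{<:Number}, loss::Loss; npoints=10001)
    grid = linspace(minimum(data), maximum(data), npoints)
    losses = [average_loss(est, data, loss) for est in grid]
    grid[indmin(losses)]
end

m_estimate(data, L2DistLoss()), mean(data)    #should agree up to the grid spacing
m_estimate(data, L1DistLoss()), median(data)  #close; any value between the two middle observations minimizes the absolute loss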

Often, data is more restricted in the values it can take. For example, it may be binary-valued, ordinal (discrete with distinct levels), or categorical (discrete without any ordering). For the binary case, common loss functions include the logistic loss (used in logistic regression) and the hinge loss (used in Support Vector Machines). The constants that minimize these loss functions are less well known, but they correspond to the bias (intercept) term of a predictor-free logistic regression or SVM, respectively.

In [21]:
#Generate data from a Bernoulli distribution (i.e. a biased coin flip)
#Failure is represented by -1, success is 1
srand(123)
data_bin = 2 .* rand(Bernoulli(0.65), 100) .- 1;
In [22]:
plot_loss(LogitMarginLoss())
Out[22]:

Unlike in the distance-loss case, the plot of the average loss as a function of the estimator does not mirror the shape of the loss function itself. The plot above shows the logistic loss function. It decreases monotonically because it treats the x axis as a measure of confidence: high confidence in a correct answer (large positive agreement) gives low error, while high confidence in a wrong answer (negative agreement) gives high error.
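
For instance, using the same value(loss, agreement) form as in the plotting helpers above:

value(LogitMarginLoss(), 2.0)    #confidently correct: small loss, log(1 + e^-2) ≈ 0.13
value(LogitMarginLoss(), 0.0)    #no confidence either way: log(2) ≈ 0.69
value(LogitMarginLoss(), -2.0)   #confidently wrong: large loss, log(1 + e^2) ≈ 2.13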

In [23]:
plot_average_loss(data_bin, LogitMarginLoss(), minrange=0, maxrange=1.1)
vline!([mean(data_bin)], label="mean")
vline!([median(data_bin)], label="median")
Out[23]:
In [24]:
mean(data_bin)
Out[24]:
0.24
In [25]:
median(data_bin)
Out[25]:
1.0

The M-estimator for the logistic loss is different from both the mean and the median. While the mean is about 0.24 and the median is 1 (because more than half of the labels are positive), the minimum of the average loss is achieved somewhere in between. This minimizer is equivalent to the intercept of a logistic regression model with no predictors.
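
In fact, for a constant estimate the average logistic loss is minimized at the log-odds of the positive class, which we can check against the minimum of the curve above. This is a quick sketch; the exact value depends on the random draw of data_bin.

p = mean(data_bin .== 1)   #fraction of positive labels, about 0.62 here
log(p / (1 - p))           #the log-odds, roughly 0.49, in between the mean and the median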

There are also other loss functions for classification, which generally decrease as agreement increases. The hinge loss and squared hinge loss (used in Support Vector Machines) are two of them. The next plot shows the logistic loss (blue) alongside the hinge loss (red) and squared hinge loss (green).

In [26]:
plot_losses([LogitMarginLoss(), HingeLoss(), L2HingeLoss()])
Out[26]:

Let's also look at the average error as a function of the estimator for these loss functions (logistic loss in blue, hinge loss in red, and squared hinge loss in green).

In [27]:
plot_average_loss(data_bin, [LogitMarginLoss(), HingeLoss(), L2HingeLoss()],
    minrange=0, maxrange=1.1)
vline!([mean(data_bin)], label="mean")
vline!([median(data_bin)], label="median")
Out[27]:

Here, we see that different classification losses are minimized at different values. The hinge loss has its M-estimator at the median, while the squared hinge loss has its M-estimator at the mean. As above, the logistic loss is minimized by a constant in between the two.
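
As a rough numerical check, we can reuse the hypothetical m_estimate grid-search helper sketched earlier (so the values are only accurate up to the grid spacing):

for loss in (LogitMarginLoss(), HingeLoss(), L2HingeLoss())
    println(loss, " => ", m_estimate(data_bin, loss))
end
#Expect roughly 0.49 (between the mean and the median), 1.0 (the median), and 0.24 (the mean)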