-
Notifications
You must be signed in to change notification settings - Fork 5
/
hsicTestGamma.m
102 lines (73 loc) · 2.61 KB
/
hsicTestGamma.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
function [thresh,testStat,params] = hsicTestGamma(X,Y,alpha,params)
%HSICTESTGAMMA HSIC independence test with a Gamma approximation to the
%null distribution.
%
%   [THRESH,TESTSTAT,PARAMS] = HSICTESTGAMMA(X,Y,ALPHA,PARAMS) computes the
%   statistic m*HSICb (m times the biased HSIC estimate) for the paired
%   samples X and Y. THRESH is the (1-ALPHA) quantile of a Gamma
%   distribution whose mean and variance are matched to estimates of the
%   statistic's mean and variance under H0 (independence). TESTSTAT > THRESH
%   is evidence against independence at level ALPHA.
%
%   Inputs:
%     X       m-by-dx matrix; each row is an i.i.d. sample.
%     Y       m-by-dy matrix with the same number of rows m as X; each row
%             is an i.i.d. sample. m must be at least 6.
%     ALPHA   level of the test.
%     PARAMS  struct with fields
%               sigx  kernel size for X (median heuristic if -1)
%               sigy  kernel size for Y (median heuristic if -1)
%             PARAMS may be omitted or empty; missing fields are treated
%             as -1.
%
%   Outputs:
%     THRESH    test threshold for a level-ALPHA test, on the same scale as
%               TESTSTAT.
%     TESTSTAT  test statistic m*HSICb.
%     PARAMS    PARAMS with any -1 kernel sizes replaced by the median
%               heuristic values that were actually used.
%
%   Kernels are computed with rbf_dot (defined elsewhere in this project).
%   icdf('gam',...) requires MATLAB's Statistics and Machine Learning
%   Toolbox (or Octave's statistics package).
%
%   Arthur Gretton, 03/06/07
%   11/01/08: used new expression for beta independent of m, and m*HSICb
%             as test statistic.

m = size(X,1);
if size(Y,1) ~= m
  error('hsicTestGamma:sizeMismatch', ...
        'X and Y must have the same number of rows (samples).');
end
% The null-variance factor 72*(m-4)*(m-5)/(m*(m-1)*(m-2)*(m-3)) used below
% is zero for m = 4,5 and divides by zero for m <= 3.
if m < 6
  error('hsicTestGamma:tooFewSamples', ...
        'At least 6 samples are required; got %d.', m);
end

% Missing kernel sizes default to -1, i.e. "use the median heuristic".
if nargin < 4 || isempty(params)
  params = struct();
end
if ~isfield(params,'sigx'), params.sigx = -1; end
if ~isfield(params,'sigy'), params.sigy = -1; end

% Set kernel size to the median distance between points if none was given.
% Use at most 100 points (the median is only a heuristic, and 100 points
% are sufficient for a robust estimate).
if params.sigx == -1
  params.sigx = medianKernelSize(X, 100);
end
if params.sigy == -1
  params.sigy = medianKernelSize(Y, 100);
end

K = rbf_dot(X,X,params.sigx);
L = rbf_dot(Y,Y,params.sigy);
% Centred Gram matrices H*K*H and H*L*H with H = eye(m)-ones(m)/m.
% Note: these are slightly biased estimates of centred Gram matrices.
Kc = centreGram(K);
Lc = centreGram(L);

% TEST STATISTIC: trace(Kc*Lc)/m = m*HSICb (the Gamma is fitted to this).
testStat = 1/m * sum(sum(Kc'.*Lc));

% Variance of HSICb under H0: mean of the squared entries of Kc.*Lc/6
% over off-diagonal pairs (diagonal subtracted), then scaled by a
% finite-sample factor.
varHSIC = (1/6 * Kc.*Lc).^2;
varHSIC = 1/m/(m-1) * ( sum(sum(varHSIC)) - sum(diag(varHSIC)) );
varHSIC = 72*(m-4)*(m-5)/m/(m-1)/(m-2)/(m-3) * varHSIC;

% Mean of HSICb under H0, built from the off-diagonal means of K and L.
muX = (sum(K(:)) - sum(diag(K))) / (m*(m-1));
muY = (sum(L(:)) - sum(diag(L))) / (m*(m-1));
mHSIC = 1/m * ( 1 + muX*muY - muX - muY );

% Moment-matched Gamma for m*HSICb, whose mean is m*mHSIC and variance is
% m^2*varHSIC: shape = mean^2/variance, scale = variance/mean.
al = mHSIC^2 / varHSIC;
bet = varHSIC*m / mHSIC;
% TEST THRESHOLD: (1-alpha) quantile of Gamma(shape al, scale bet).
thresh = icdf('gam',1-alpha,al,bet);


function sig = medianKernelSize(Z, maxPoints)
%MEDIANKERNELSIZE Kernel size from the median pairwise squared distance.
%   Uses the first min(size(Z,1),MAXPOINTS) rows of Z and returns
%   sqrt(0.5*median(d)), where d are the positive squared Euclidean
%   distances over distinct pairs of those rows. The 0.5 is there because
%   rbf_dot has a factor of two in the kernel. Returns NaN if no pair has a
%   positive distance (e.g. all used rows are identical).
n = min(size(Z,1), maxPoints);
Zmed = Z(1:n,:);
G = sum(Zmed.*Zmed, 2);
dists = bsxfun(@plus, G, G') - 2*Zmed*Zmed';   % squared distances
dists = triu(dists, 1);                        % keep each distinct pair once
dists = dists(:);
sig = sqrt(0.5*median(dists(dists>0)));


function Mc = centreGram(M)
%CENTREGRAM Return H*M*H, with H = eye(m) - ones(m)/m, for an m-by-m M.
%   Computed in O(m^2) by subtracting column means and row means and adding
%   back the grand mean, instead of forming H and two O(m^3) products.
Mc = bsxfun(@minus, bsxfun(@minus, M, mean(M,1)), mean(M,2)) + mean(M(:));