-
Notifications
You must be signed in to change notification settings - Fork 5
/
nmf.go
165 lines (135 loc) · 4.95 KB
/
nmf.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
package mlpack
/*
#cgo CFLAGS: -I./capi -Wall
#cgo LDFLAGS: -L. -lmlpack_go_nmf
#include <capi/nmf.h>
#include <stdlib.h>
*/
import "C"
import "gonum.org/v1/gonum/mat"
// NmfOptionalParam holds the optional parameters accepted by Nmf().
// NmfOptions() returns an instance pre-filled with the default values.
type NmfOptionalParam struct {
	// InitialH and InitialW are the initial H and W matrices. Nmf() only
	// forwards them when they are non-nil.
	InitialH, InitialW *mat.Dense
	// MaxIterations is the number of iterations before NMF terminates
	// (0 runs until convergence).
	MaxIterations int
	// MinResidue is the minimum root mean square residue allowed for each
	// iteration, below which the program terminates.
	MinResidue float64
	// Seed is the random seed. If 0, 'std::time(NULL)' is used.
	Seed int
	// UpdateRules selects the update rules for each iteration
	// ( multdist | multdiv | als ).
	UpdateRules string
	// Verbose displays informational messages and the full list of
	// parameters and timers at the end of execution.
	Verbose bool
}
// NmfOptions returns a pointer to a new NmfOptionalParam populated with the
// default value of every optional Nmf() parameter.
func NmfOptions() *NmfOptionalParam {
	// Fields that are not assigned below keep their zero value, which is
	// also their default: nil initial matrices, seed 0 and verbose off.
	defaults := new(NmfOptionalParam)
	defaults.MaxIterations = 10000
	defaults.MinResidue = 1e-05
	defaults.UpdateRules = "multdist"
	return defaults
}
/*
  Nmf performs non-negative matrix factorization on the given dataset and
  returns the resulting decomposed matrices.  For an input dataset V, NMF
  decomposes V into two matrices W and H such that

      V = W * H

  where all elements in W and H are non-negative.  If V is of size (n x m),
  then W will be of size (n x r) and H will be of size (r x m), where r is the
  rank of the factorization (specified by the "rank" parameter).

  Optionally, the desired update rules for each NMF iteration can be chosen
  from the following list:

   - multdist: multiplicative distance-based update rules (Lee and Seung 1999)
   - multdiv: multiplicative divergence-based update rules (Lee and Seung 1999)
   - als: alternating least squares update rules (Paatero and Tapper 1994)

  The maximum number of iterations is specified with "MaxIterations", and the
  minimum residue required for algorithm termination is specified with the
  "MinResidue" parameter.

  For example, to run NMF on the input matrix V using the 'multdist' update
  rules with a rank-10 decomposition and storing the decomposed matrices into
  W and H, the following command could be used:

      // Initialize optional parameters for Nmf().
      param := mlpack.NmfOptions()
      param.UpdateRules = "multdist"

      H, W := mlpack.Nmf(V, 10, param)

  Passing a nil param is equivalent to passing mlpack.NmfOptions(), i.e. every
  optional parameter keeps its default value.

  Input parameters:

   - input (mat.Dense): Input dataset to perform NMF on.
   - rank (int): Rank of the factorization.
   - InitialH (mat.Dense): Initial H matrix.
   - InitialW (mat.Dense): Initial W matrix.
   - MaxIterations (int): Number of iterations before NMF terminates (0
        runs until convergence).  Default value 10000.
   - MinResidue (float64): The minimum root mean square residue allowed
        for each iteration, below which the program terminates.  Default value
        1e-05.
   - Seed (int): Random seed.  If 0, 'std::time(NULL)' is used.  Default
        value 0.
   - UpdateRules (string): Update rules for each iteration; ( multdist |
        multdiv | als ).  Default value 'multdist'.
   - Verbose (bool): Display informational messages and the full list of
        parameters and timers at the end of execution.

  Output parameters (returned in this order):

   - h (mat.Dense): The calculated H matrix.
   - w (mat.Dense): The calculated W matrix.
*/
func Nmf(input *mat.Dense, rank int, param *NmfOptionalParam) (*mat.Dense, *mat.Dense) {
	// A nil param means "no optional parameters": fall back to the defaults
	// instead of dereferencing a nil pointer below.
	if param == nil {
		param = NmfOptions()
	}

	params := getParams("nmf")
	timers := getTimers()
	// Clean memory on every exit path, including a panic in one of the
	// helpers. Deferred calls run in reverse order, so params is cleaned
	// first and timers second.
	defer cleanTimers(timers)
	defer cleanParams(params)

	disableBacktrace()
	disableVerbose()

	// Required parameters are always set and marked as passed.
	gonumToArmaMat(params, "input", input, false)
	setPassed(params, "input")
	setParamInt(params, "rank", rank)
	setPassed(params, "rank")

	// Optional matrices are forwarded only when the caller supplied them.
	if param.InitialH != nil {
		gonumToArmaMat(params, "initial_h", param.InitialH, false)
		setPassed(params, "initial_h")
	}
	if param.InitialW != nil {
		gonumToArmaMat(params, "initial_w", param.InitialW, false)
		setPassed(params, "initial_w")
	}

	// Optional scalars are forwarded only when they differ from the default
	// value that NmfOptions() assigns.
	if param.MaxIterations != 10000 {
		setParamInt(params, "max_iterations", param.MaxIterations)
		setPassed(params, "max_iterations")
	}
	if param.MinResidue != 1e-05 {
		setParamDouble(params, "min_residue", param.MinResidue)
		setPassed(params, "min_residue")
	}
	if param.Seed != 0 {
		setParamInt(params, "seed", param.Seed)
		setPassed(params, "seed")
	}
	if param.UpdateRules != "multdist" {
		setParamString(params, "update_rules", param.UpdateRules)
		setPassed(params, "update_rules")
	}
	if param.Verbose {
		setParamBool(params, "verbose", param.Verbose)
		setPassed(params, "verbose")
		// Undo the disableVerbose() call above for this run.
		enableVerbose()
	}

	// Mark all output options as passed.
	setPassed(params, "h")
	setPassed(params, "w")

	// Call the mlpack program.
	C.mlpackNmf(params.mem, timers.mem)

	// Fetch the outputs before the deferred cleanup runs.
	var hPtr mlpackArma
	h := hPtr.armaToGonumMat(params, "h")
	var wPtr mlpackArma
	w := wPtr.armaToGonumMat(params, "w")

	// Return output(s).
	return h, w
}