Skip to content

Commit

Permalink
Merge pull request #193 from BrainJS/192-time-series-baseline
Browse files Browse the repository at this point in the history
Add Time Step Recurrent Neural Network
  • Loading branch information
robertleeplummerjr authored Apr 21, 2018
2 parents 1d643a3 + eae7e22 commit 865c7c9
Show file tree
Hide file tree
Showing 25 changed files with 900 additions and 31 deletions.
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- [Training](#training)
+ [Data format](#data-format)
+ [For training with NeuralNetwork](#for-training-with-neuralnetwork)
+ [For training with `RNNTimeStep`, `LSTMTimeStep` and `GRUTimeStep`](#for-training-with-rnntimestep-lstmtimestep-and-gputimestep)
+ [For training with `RNN`, `LSTM` and `GRU`](#for-training-with-rnn-lstm-and-gpu)
+ [Training Options](#training-options)
+ [Async Training](#async-training)
Expand Down Expand Up @@ -128,6 +129,38 @@ net.train([{input: { r: 0.03, g: 0.7 }, output: { black: 1 }},

var output = net.run({ r: 1, g: 0.4, b: 0 }); // { white: 0.81, black: 0.18 }
```

#### For training with `RNNTimeStep`, `LSTMTimeStep` and `GRUTimeStep`
Eeach training pattern can either:
* Be an array of numbers
* Be an array of arrays of numbers

Example using an array of numbers:
```javascript
var net = new brain.recurrent.LSTMTimeStep();

net.train([
1,
2,
3,
]);

var output = net.run([1, 2]); // 3
```

Example using an array of arrays of numbers:
```javascript
var net = new brain.recurrent.LSTMTimeStep();

net.train([
[1, 3],
[2, 2],
[3, 1],
]);

var output = net.run([[1, 3], [2, 2]]); // [3, 1]
```

#### For training with `RNN`, `LSTM` and `GRU`
Each training pattern can either:
* Be an array of values
Expand Down Expand Up @@ -321,13 +354,17 @@ Likely example see: [simple letter detection](./examples/which-letter-simple.js)

# Neural Network Types
* [`brain.NeuralNetwork`](src/neural-network.js) - [Feedforward Neural Network](https://en.wikipedia.org/wiki/Feedforward_neural_network) with backpropagation
* [`brain.recurrent.RNNTimeStep`](src/recurrent/rnn-time-step.js) - [Time Step Recurrent Neural Network or "RNN"](https://en.wikipedia.org/wiki/Recurrent_neural_network)
* [`brain.recurrent.LSTMTimeStep`](src/recurrent/lstm-time-step.js) - [Time Step Long Short Term Memory Neural Network or "LSTM"](https://en.wikipedia.org/wiki/Long_short-term_memory)
* [`brain.recurrent.GRUTimeStep`](src/recurrent/gru-time-step.js) - [Time Step Gated Recurrent Unit or "GRU"](https://en.wikipedia.org/wiki/Gated_recurrent_unit)
* [`brain.recurrent.RNN`](src/recurrent/rnn.js) - [Recurrent Neural Network or "RNN"](https://en.wikipedia.org/wiki/Recurrent_neural_network)
* [`brain.recurrent.LSTM`](src/recurrent/lstm.js) - [Long Short Term Memory Neural Network or "LSTM"](https://en.wikipedia.org/wiki/Long_short-term_memory)
* [`brain.recurrent.GRU`](src/recurrent/gru.js) - [Gated Recurrent Unit or "GRU"](https://en.wikipedia.org/wiki/Gated_recurrent_unit)

### Why different Neural Network Types?
Different neural nets do different things well. For example:
* A Feedforward Neural Network can classify simple things very well, but it has no memory of previous actions and has infinite variation of results.
* A Time Step Recurrent Neural Network _remembers_, and can predict future values.
* A Recurrent Neural Network _remembers_, and has a finite set of results.

# Get Involved!
Expand Down
51 changes: 46 additions & 5 deletions browser.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* license: MIT (http://opensource.org/licenses/MIT)
* author: Heather Arthur <[email protected]>
* homepage: https://github.com/brainjs/brain.js#readme
* version: 1.1.3
* version: 1.2.0
*
* acorn:
* license: MIT (http://opensource.org/licenses/MIT)
Expand Down Expand Up @@ -2463,6 +2463,7 @@ var Equation = function () {
_classCallCheck(this, Equation);

this.inputRow = 0;
this.inputValue = null;
this.states = [];
}

Expand Down Expand Up @@ -2611,6 +2612,25 @@ var Equation = function () {
return product;
}

/**
* copy a matrix
* @param {Matrix} input
* @returns {Matrix}
*/

}, {
key: 'input',
value: function input(_input) {
var self = this;
this.states.push({
product: _input,
forwardFn: function forwardFn() {
_input.weights = self.inputValue;
}
});
return _input;
}

/**
* connects a matrix via a row
* @param {Matrix} m
Expand Down Expand Up @@ -2722,6 +2742,27 @@ var Equation = function () {
* @output {Matrix}
*/

}, {
key: 'runInput',
value: function runInput(inputValue) {
this.inputValue = inputValue;
var state = void 0;
for (var i = 0, max = this.states.length; i < max; i++) {
state = this.states[i];
if (!state.hasOwnProperty('forwardFn')) {
continue;
}
state.forwardFn(state.product, state.left, state.right);
}

return state.product;
}

/**
* @patam {Number} [rowIndex]
* @output {Matrix}
*/

}, {
key: 'runBackpropagate',
value: function runBackpropagate() {
Expand Down Expand Up @@ -3476,7 +3517,7 @@ var RNN = function () {

_classCallCheck(this, RNN);

var defaults = RNN.defaults;
var defaults = this.constructor.defaults;

for (var p in defaults) {
if (!defaults.hasOwnProperty(p)) continue;
Expand Down Expand Up @@ -3868,7 +3909,7 @@ var RNN = function () {
value: function train(data) {
var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {};

options = Object.assign({}, RNN.trainDefaults, options);
options = Object.assign({}, this.constructor.trainDefaults, options);
var iterations = options.iterations;
var errorThresh = options.errorThresh;
var log = options.log === true ? console.log : options.log;
Expand Down Expand Up @@ -3935,7 +3976,7 @@ var RNN = function () {
}, {
key: 'toJSON',
value: function toJSON() {
var defaults = RNN.defaults;
var defaults = this.constructor.defaults;
var model = this.model;
var options = {};
for (var p in defaults) {
Expand Down Expand Up @@ -3966,7 +4007,7 @@ var RNN = function () {
key: 'fromJSON',
value: function fromJSON(json) {
this.json = json;
var defaults = RNN.defaults;
var defaults = this.constructor.defaults;
var model = this.model;
var options = json.options;
var allMatrices = model.allMatrices;
Expand Down
6 changes: 3 additions & 3 deletions browser.min.js

Large diffs are not rendered by default.

17 changes: 16 additions & 1 deletion dist/index.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion dist/index.js.map

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 64 additions & 0 deletions dist/recurrent/gru-time-step.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions dist/recurrent/gru-time-step.js.map

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 64 additions & 0 deletions dist/recurrent/lstm-time-step.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions dist/recurrent/lstm-time-step.js.map

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 865c7c9

Please sign in to comment.