Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JakubSroka authored and JakubSroka committed Sep 21, 2018
1 parent 43e51ee commit 345d6cd
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ final class TypedLinearRegression [Inputs] private[ml](
def setLoss(value: LossStrategy): TypedLinearRegression[Inputs] = copy(lr.setLoss(value.sparkValue))
def setEpsilon(value: Double): TypedLinearRegression[Inputs] = copy(lr.setEpsilon(value))

private def copy(newRf: LinearRegression): TypedLinearRegression[Inputs] =
new TypedLinearRegression[Inputs](newRf, labelCol, featuresCol, weightCol)
private def copy(newLr: LinearRegression): TypedLinearRegression[Inputs] =
new TypedLinearRegression[Inputs](newLr, labelCol, featuresCol, weightCol)

}

Expand Down
2 changes: 1 addition & 1 deletion ml/src/test/scala/frameless/ml/Generators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object Generators {
implicit val arbLossStrategy: Arbitrary[LossStrategy] = Arbitrary {
Gen.oneOf(
Gen.const(LossStrategy.SquaredError),
Gen.const(LossStrategy.SquaredError)
Gen.const(LossStrategy.Huber)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,27 @@ class TypedLinearRegressionTests extends FramelessMlSuite with MustMatchers {

test("fit() returns a correct TypedTransformer") {
val prop = forAll { x2: X2[Double, Vector] =>
val rf = TypedLinearRegression[X2[Double, Vector]]
val lr = TypedLinearRegression[X2[Double, Vector]]
val ds = TypedDataset.create(Seq(x2))

val model = rf.fit(ds).run()
val model = lr.fit(ds).run()
val pDs = model.transform(ds).as[X3[Double, Vector, Double]]

pDs.select(pDs.col('a), pDs.col('b)).collect.run() == Seq(x2.a -> x2.b)
}
val prop2 = forAll { x2: X2[Vector, Double] =>
val rf = TypedLinearRegression[X2[Vector, Double]]
val lr = TypedLinearRegression[X2[Vector, Double]]
val ds = TypedDataset.create(Seq(x2))
val model = rf.fit(ds).run()
val model = lr.fit(ds).run()
val pDs = model.transform(ds).as[X3[Vector, Double, Double]]

pDs.select(pDs.col('a), pDs.col('b)).collect.run() == Seq(x2.a -> x2.b)
}

def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] =>
val rf = TypedLinearRegression[X2[Vector, Double]]
val lr = TypedLinearRegression[X2[Vector, Double]]
val ds = TypedDataset.create(Seq(x3))
val model = rf.fit(ds).run()
val model = lr.fit(ds).run()
val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]

pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect.run() == Seq((x3.a, x3.b, x3.c))
Expand Down Expand Up @@ -70,11 +70,9 @@ class TypedLinearRegressionTests extends FramelessMlSuite with MustMatchers {
model.transformer.getAggregationDepth == 10 &&
model.transformer.getElasticNetParam == 0.5 &&
model.transformer.getEpsilon == 4.0 &&
model.transformer.getFitIntercept == true &&
model.transformer.getLoss == lossStrategy.sparkValue &&
model.transformer.getMaxIter == 23 &&
model.transformer.getRegParam == 1.2 &&
model.transformer.getStandardization == true &&
model.transformer.getTol == 2.3 &&
model.transformer.getSolver == solver.sparkValue
}
Expand Down Expand Up @@ -103,12 +101,12 @@ class TypedLinearRegressionTests extends FramelessMlSuite with MustMatchers {
)

val ds2 = Seq(
X3(new DenseVector(Array(1.0)): Vector,2: Float, 1.0),
X3(new DenseVector(Array(2.0)): Vector,2: Float, 2.0),
X3(new DenseVector(Array(3.0)): Vector,2: Float, 3.0),
X3(new DenseVector(Array(4.0)): Vector,2: Float, 4.0),
X3(new DenseVector(Array(5.0)): Vector,2: Float, 5.0),
X3(new DenseVector(Array(6.0)): Vector,2: Float, 6.0)
X3(new DenseVector(Array(1.0)): Vector,2F, 1.0),
X3(new DenseVector(Array(2.0)): Vector,2F, 2.0),
X3(new DenseVector(Array(3.0)): Vector,2F, 3.0),
X3(new DenseVector(Array(4.0)): Vector,2F, 4.0),
X3(new DenseVector(Array(5.0)): Vector,2F, 5.0),
X3(new DenseVector(Array(6.0)): Vector,2F, 6.0)
)

val tds = TypedDataset.create(ds)
Expand Down

0 comments on commit 345d6cd

Please sign in to comment.