diff --git a/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala b/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala
new file mode 100644
index 00000000..995a3f96
--- /dev/null
+++ b/ml/src/main/scala/frameless/ml/internals/LinearInputsChecker.scala
@@ -0,0 +1,70 @@
+package frameless
+package ml
+package internals
+
+import org.apache.spark.ml.linalg._
+import shapeless.ops.hlist.Length
+import shapeless.{HList, LabelledGeneric, Nat, Witness}
+
+import scala.annotation.implicitNotFound
+
+/**
+ * Compile-time evidence that the input type is valid for linear regression algorithms:
+ * it provides the label, features and optional weight column names.
+ */
+@implicitNotFound(
+ msg = "Cannot prove that ${Inputs} is a valid input type. " +
+    "Input type must only contain a field of type Double (the label), a field of type " +
+    "org.apache.spark.ml.linalg.Vector (the features), and an optional field of type Float (the weight)."
+)
+trait LinearInputsChecker[Inputs] {
+ val featuresCol: String
+ val labelCol: String
+ val weightCol: Option[String]
+}
+
+object LinearInputsChecker {
+
+ implicit def checkLinearInputs[
+ Inputs,
+ InputsRec <: HList,
+ LabelK <: Symbol,
+ FeaturesK <: Symbol](
+ implicit
+ i0: LabelledGeneric.Aux[Inputs, InputsRec],
+ i1: Length.Aux[InputsRec, Nat._2],
+ i2: SelectorByValue.Aux[InputsRec, Double, LabelK],
+ i3: Witness.Aux[LabelK],
+ i4: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
+ i5: Witness.Aux[FeaturesK]
+ ): LinearInputsChecker[Inputs] = {
+ new LinearInputsChecker[Inputs] {
+ val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name
+ val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name
+ val weightCol: Option[String] = None
+ }
+ }
+
+ implicit def checkLinearInputs2[
+ Inputs,
+ InputsRec <: HList,
+ LabelK <: Symbol,
+ FeaturesK <: Symbol,
+ WeightK <: Symbol](
+ implicit
+ i0: LabelledGeneric.Aux[Inputs, InputsRec],
+ i1: Length.Aux[InputsRec, Nat._3],
+ i2: SelectorByValue.Aux[InputsRec, Vector, FeaturesK],
+ i3: Witness.Aux[FeaturesK],
+ i4: SelectorByValue.Aux[InputsRec, Double, LabelK],
+ i5: Witness.Aux[LabelK],
+ i6: SelectorByValue.Aux[InputsRec, Float, WeightK],
+ i7: Witness.Aux[WeightK]
+ ): LinearInputsChecker[Inputs] = {
+ new LinearInputsChecker[Inputs] {
+ val labelCol: String = implicitly[Witness.Aux[LabelK]].value.name
+ val featuresCol: String = implicitly[Witness.Aux[FeaturesK]].value.name
+ val weightCol: Option[String] = Some(implicitly[Witness.Aux[WeightK]].value.name)
+ }
+ }
+
+}
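
For reference, a minimal sketch of what the derived instance provides for a hypothetical user input type (the `HousePrice` case class and its field names below are illustrative only, not part of this change):

```scala
import org.apache.spark.ml.linalg.Vector
import frameless.ml.internals.LinearInputsChecker

// Hypothetical input: one Vector field (features) and one Double field (label).
case class HousePrice(area: Vector, price: Double)

// Resolving the implicit reads the column names off the case class at compile time:
// featuresCol == "area", labelCol == "price", weightCol == None.
val checker = implicitly[LinearInputsChecker[HousePrice]]
```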
diff --git a/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala b/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala
new file mode 100644
index 00000000..4b9ca6d4
--- /dev/null
+++ b/ml/src/main/scala/frameless/ml/params/linears/LossStrategy.scala
@@ -0,0 +1,16 @@
+package frameless
+package ml
+package params
+package linears
+/**
+ * SquaredError measures the average of the squares of the errors, that is,
+ * the average squared difference between the estimated values and the actual values.
+ *
+ * Huber loss is less sensitive to outliers in the data than the squared error loss.
+ */
+sealed abstract class LossStrategy private[ml](val sparkValue: String)
+object LossStrategy {
+ case object SquaredError extends LossStrategy("squaredError")
+ case object Huber extends LossStrategy("huber")
+}
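
As a rough illustration of the difference between the two strategies, here is a simplified sketch of the loss shapes (not Spark's exact objective, which additionally rescales residuals by an estimated sigma):

```scala
// Residual r = prediction - label.
// Squared error grows quadratically, so a few large outliers can dominate the fit.
def squaredError(r: Double): Double = r * r / 2

// Huber loss is quadratic near zero but linear beyond the epsilon threshold,
// limiting the influence of outliers (see setEpsilon on the estimator).
def huber(r: Double, epsilon: Double): Double =
  if (math.abs(r) <= epsilon) r * r / 2
  else epsilon * (math.abs(r) - epsilon / 2)
```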
diff --git a/ml/src/main/scala/frameless/ml/params/linears/Solver.scala b/ml/src/main/scala/frameless/ml/params/linears/Solver.scala
new file mode 100644
index 00000000..277e06e7
--- /dev/null
+++ b/ml/src/main/scala/frameless/ml/params/linears/Solver.scala
@@ -0,0 +1,25 @@
+package frameless
+package ml
+package params
+package linears
+
+/**
+ * Solver algorithm used for optimization.
+ * - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
+ * optimization method.
+ * - "normal" denotes using Normal Equation as an analytical solution to the linear regression
+ * problem. This solver is limited to `LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER`.
+ * - "auto" (default) means that the solver algorithm is selected automatically.
+ * The Normal Equations solver will be used when possible, but this will automatically fall
+ * back to iterative optimization methods when needed.
+ *
+ * The values and descriptions mirror the `solver` param of Spark's `LinearRegression`.
+ */
+sealed abstract class Solver private[ml](val sparkValue: String)
+object Solver {
+ case object LBFGS extends Solver("l-bfgs")
+ case object Auto extends Solver("auto")
+ case object Normal extends Solver("normal")
+}
+
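
A short usage sketch (assuming the `TypedLinearRegression` estimator added later in this diff and the frameless `X2` test helper; any two-field case class with a Vector features field and a Double label field works the same way):

```scala
import frameless.ml.regression.TypedLinearRegression
import org.apache.spark.ml.linalg.Vector

val lr = TypedLinearRegression[X2[Vector, Double]]
  .setSolver(Solver.Normal) // analytical normal-equation solve instead of l-bfgs
```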
diff --git a/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala b/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala
new file mode 100644
index 00000000..3b320862
--- /dev/null
+++ b/ml/src/main/scala/frameless/ml/regression/TypedLinearRegression.scala
@@ -0,0 +1,52 @@
+package frameless
+package ml
+package regression
+
+import frameless.ml.internals.LinearInputsChecker
+import frameless.ml.params.linears.{LossStrategy, Solver}
+import frameless.ml.{AppendTransformer, TypedEstimator}
+import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
+
+/**
+ * Linear regression is a linear approach to modelling the relationship
+ * between a scalar response (or dependent variable) and one or more explanatory variables.
+ */
+final class TypedLinearRegression[Inputs] private[ml](
+ lr: LinearRegression,
+ labelCol: String,
+ featuresCol: String,
+ weightCol: Option[String]
+) extends TypedEstimator[Inputs, TypedLinearRegression.Outputs, LinearRegressionModel] {
+
+  val estimatorWithoutWeight: LinearRegression = lr
+ .setLabelCol(labelCol)
+ .setFeaturesCol(featuresCol)
+ .setPredictionCol(AppendTransformer.tempColumnName)
+
+  val estimator: LinearRegression =
+    weightCol.fold(estimatorWithoutWeight)(estimatorWithoutWeight.setWeightCol)
+
+ def setRegParam(value: Double): TypedLinearRegression[Inputs] = copy(lr.setRegParam(value))
+ def setFitIntercept(value: Boolean): TypedLinearRegression[Inputs] = copy(lr.setFitIntercept(value))
+ def setStandardization(value: Boolean): TypedLinearRegression[Inputs] = copy(lr.setStandardization(value))
+ def setElasticNetParam(value: Double): TypedLinearRegression[Inputs] = copy(lr.setElasticNetParam(value))
+ def setMaxIter(value: Int): TypedLinearRegression[Inputs] = copy(lr.setMaxIter(value))
+ def setTol(value: Double): TypedLinearRegression[Inputs] = copy(lr.setTol(value))
+ def setSolver(value: Solver): TypedLinearRegression[Inputs] = copy(lr.setSolver(value.sparkValue))
+ def setAggregationDepth(value: Int): TypedLinearRegression[Inputs] = copy(lr.setAggregationDepth(value))
+ def setLoss(value: LossStrategy): TypedLinearRegression[Inputs] = copy(lr.setLoss(value.sparkValue))
+ def setEpsilon(value: Double): TypedLinearRegression[Inputs] = copy(lr.setEpsilon(value))
+
+ private def copy(newLr: LinearRegression): TypedLinearRegression[Inputs] =
+ new TypedLinearRegression[Inputs](newLr, labelCol, featuresCol, weightCol)
+
+}
+
+object TypedLinearRegression {
+ case class Outputs(prediction: Double)
+ case class Weight(weight: Double)
+
+
+ def apply[Inputs](implicit inputsChecker: LinearInputsChecker[Inputs]): TypedLinearRegression[Inputs] = {
+ new TypedLinearRegression(new LinearRegression(), inputsChecker.labelCol, inputsChecker.featuresCol, inputsChecker.weightCol)
+ }
+}
\ No newline at end of file
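
An end-to-end sketch of the new estimator (assumes a SparkSession-backed frameless setup and the `X2`/`X3` helper case classes used by the test suite below):

```scala
import frameless.TypedDataset
import frameless.ml.regression.TypedLinearRegression
import org.apache.spark.ml.linalg.{Vector, Vectors}

val training = TypedDataset.create(Seq[X2[Vector, Double]](
  X2(Vectors.dense(1.0), 1.0),
  X2(Vectors.dense(2.0), 2.0),
  X2(Vectors.dense(3.0), 3.0)
))

// The label and features columns are inferred from the field types of X2[Vector, Double].
val lr = TypedLinearRegression[X2[Vector, Double]].setMaxIter(10)

// fit() returns a frameless Job here; run() executes it and yields an AppendTransformer
// that appends the prediction column.
val model = lr.fit(training).run()
val predictions = model.transform(training).as[X3[Vector, Double, Double]]
```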
diff --git a/ml/src/test/scala/frameless/ml/Generators.scala b/ml/src/test/scala/frameless/ml/Generators.scala
index 502bcec1..9a109e15 100644
--- a/ml/src/test/scala/frameless/ml/Generators.scala
+++ b/ml/src/test/scala/frameless/ml/Generators.scala
@@ -1,6 +1,7 @@
package frameless
package ml
+import frameless.ml.params.linears.{LossStrategy, Solver}
import frameless.ml.params.trees.FeatureSubsetStrategy
import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
import org.scalacheck.{Arbitrary, Gen}
@@ -41,4 +42,16 @@ object Generators {
)
}
+ implicit val arbLossStrategy: Arbitrary[LossStrategy] = Arbitrary {
+ Gen.const(LossStrategy.SquaredError)
+ }
+
+ implicit val arbSolver: Arbitrary[Solver] = Arbitrary {
+ Gen.oneOf(
+ Gen.const(Solver.LBFGS),
+ Gen.const(Solver.Auto),
+ Gen.const(Solver.Normal)
+ )
+ }
+
}
diff --git a/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala b/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala
new file mode 100644
index 00000000..cee7b366
--- /dev/null
+++ b/ml/src/test/scala/frameless/ml/regression/TypedLinearRegressionTests.scala
@@ -0,0 +1,127 @@
+package frameless
+package ml
+package regression
+
+import frameless.ml.params.linears.{LossStrategy, Solver}
+import org.apache.spark.ml.linalg._
+import org.scalacheck.Arbitrary
+import org.scalacheck.Prop._
+import org.scalatest.Matchers._
+import org.scalatest.MustMatchers
+import shapeless.test.illTyped
+
+class TypedLinearRegressionTests extends FramelessMlSuite with MustMatchers {
+
+  implicit val arbVectorNonEmpty: Arbitrary[Vector] =
+    Arbitrary(Generators.arbVector.arbitrary suchThat (_.size > 0)) // feature vectors must be non-empty for LinearRegression
+
+ test("fit() returns a correct TypedTransformer") {
+ val prop = forAll { x2: X2[Double, Vector] =>
+ val lr = TypedLinearRegression[X2[Double, Vector]]
+ val ds = TypedDataset.create(Seq(x2))
+
+ val model = lr.fit(ds).run()
+ val pDs = model.transform(ds).as[X3[Double, Vector, Double]]
+
+ pDs.select(pDs.col('a), pDs.col('b)).collect.run() == Seq(x2.a -> x2.b)
+ }
+ val prop2 = forAll { x2: X2[Vector, Double] =>
+ val lr = TypedLinearRegression[X2[Vector, Double]]
+ val ds = TypedDataset.create(Seq(x2))
+ val model = lr.fit(ds).run()
+ val pDs = model.transform(ds).as[X3[Vector, Double, Double]]
+
+ pDs.select(pDs.col('a), pDs.col('b)).collect.run() == Seq(x2.a -> x2.b)
+ }
+
+ def prop3[A: TypedEncoder: Arbitrary] = forAll { x3: X3[Vector, Double, A] =>
+ val lr = TypedLinearRegression[X2[Vector, Double]]
+ val ds = TypedDataset.create(Seq(x3))
+ val model = lr.fit(ds).run()
+ val pDs = model.transform(ds).as[X4[Vector, Double, A, Double]]
+
+ pDs.select(pDs.col('a), pDs.col('b), pDs.col('c)).collect.run() == Seq((x3.a, x3.b, x3.c))
+ }
+
+ check(prop)
+ check(prop2)
+ check(prop3[String])
+ check(prop3[Double])
+ }
+
+ test("param setting is retained") {
+ import Generators.{arbLossStrategy, arbSolver}
+
+ val prop = forAll { (lossStrategy: LossStrategy, solver: Solver) =>
+ val lr = TypedLinearRegression[X2[Double, Vector]]
+ .setAggregationDepth(10)
+ .setEpsilon(4)
+ .setFitIntercept(true)
+ .setLoss(lossStrategy)
+ .setMaxIter(23)
+ .setRegParam(1.2)
+ .setStandardization(true)
+ .setTol(2.3)
+ .setSolver(solver)
+
+ val ds = TypedDataset.create(Seq(X2(0D, Vectors.dense(0D))))
+ val model = lr.fit(ds).run()
+
+ model.transformer.getAggregationDepth == 10 &&
+ model.transformer.getEpsilon == 4.0 &&
+ model.transformer.getLoss == lossStrategy.sparkValue &&
+ model.transformer.getMaxIter == 23 &&
+ model.transformer.getRegParam == 1.2 &&
+ model.transformer.getTol == 2.3 &&
+ model.transformer.getSolver == solver.sparkValue
+ }
+
+ check(prop)
+ }
+
+  test("apply() compiles only with correct inputs") {
+    illTyped("TypedLinearRegression.apply[Double]")
+    illTyped("TypedLinearRegression.apply[X1[Double]]")
+    illTyped("TypedLinearRegression.apply[X2[Double, Double]]")
+    illTyped("TypedLinearRegression.apply[X3[Vector, Double, Int]]")
+    illTyped("TypedLinearRegression.apply[X2[Vector, String]]")
+  }
+
+  test("TypedLinearRegression should fit a straight line") {
+
+ val ds = Seq(
+ X2(new DenseVector(Array(1.0)): Vector, 1.0),
+ X2(new DenseVector(Array(2.0)): Vector, 2.0),
+ X2(new DenseVector(Array(3.0)): Vector, 3.0),
+ X2(new DenseVector(Array(4.0)): Vector, 4.0),
+ X2(new DenseVector(Array(5.0)): Vector, 5.0),
+ X2(new DenseVector(Array(6.0)): Vector, 6.0)
+ )
+
+ val ds2 = Seq(
+      X3(new DenseVector(Array(1.0)): Vector, 2F, 1.0),
+      X3(new DenseVector(Array(2.0)): Vector, 2F, 2.0),
+      X3(new DenseVector(Array(3.0)): Vector, 2F, 3.0),
+      X3(new DenseVector(Array(4.0)): Vector, 2F, 4.0),
+      X3(new DenseVector(Array(5.0)): Vector, 2F, 5.0),
+      X3(new DenseVector(Array(6.0)): Vector, 2F, 6.0)
+ )
+
+ val tds = TypedDataset.create(ds)
+
+ val lr = TypedLinearRegression[X2[Vector, Double]]
+ .setMaxIter(10)
+
+ val model = lr.fit(tds).run()
+
+ val tds2 = TypedDataset.create(ds2)
+
+ val lr2 = TypedLinearRegression[X3[Vector, Float, Double]]
+ .setMaxIter(10)
+
+ val model2 = lr2.fit(tds2).run()
+
+ model.transformer.coefficients shouldEqual new DenseVector(Array(1.0))
+ model2.transformer.coefficients shouldEqual new DenseVector(Array(1.0))
+ }
+}