From ca96c32309dea12ce9ed65aa04a7148e995f51cd Mon Sep 17 00:00:00 2001 From: John Andersen Date: Wed, 11 Dec 2019 22:36:02 -0800 Subject: [PATCH] tests: integration: CSV source string src_urls Signed-off-by: John Andersen --- tests/integration/common.py | 13 +++- tests/integration/test_sources.py | 115 ++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 3 deletions(-) create mode 100644 tests/integration/test_sources.py diff --git a/tests/integration/common.py b/tests/integration/common.py index 99dee843a4..da4aa5d59b 100644 --- a/tests/integration/common.py +++ b/tests/integration/common.py @@ -18,7 +18,7 @@ import asyncio import contextlib import unittest.mock -from typing import Dict, Any +from typing import Dict, Any, Optional from dffml.repo import Repo from dffml.base import config @@ -78,5 +78,12 @@ def required_plugins(self, *args): f"Required plugins: {', '.join(args)} must be installed in development mode" ) - def mktempfile(self): - return self._stack.enter_context(non_existant_tempfile()) + def mktempfile( + self, suffix: Optional[str] = None, text: Optional[str] = None + ): + filename = self._stack.enter_context(non_existant_tempfile()) + if suffix: + filename = filename + suffix + if text: + pathlib.Path(filename).write_text(inspect.cleandoc(text) + "\n") + return filename diff --git a/tests/integration/test_sources.py b/tests/integration/test_sources.py new file mode 100644 index 0000000000..9e15674260 --- /dev/null +++ b/tests/integration/test_sources.py @@ -0,0 +1,115 @@ +""" +This file contains integration tests. We use the CLI to exercise functionality of +various DFFML classes and constructs. +""" +import re +import os +import io +import json +import inspect +import pathlib +import asyncio +import contextlib +import unittest.mock +from typing import Dict, Any + +from dffml.repo import Repo +from dffml.base import config +from dffml.df.types import Definition, Operation, DataFlow, Input +from dffml.df.base import op +from dffml.cli.cli import CLI +from dffml.model.model import Model +from dffml.service.dev import Develop +from dffml.util.packaging import is_develop +from dffml.util.entrypoint import load +from dffml.config.config import BaseConfigLoader +from dffml.util.asynctestcase import AsyncTestCase + +from .common import IntegrationCLITestCase + + +class TestCSV(IntegrationCLITestCase): + async def test_string_src_urls(self): + # Test for issue #207 + self.required_plugins("dffml-model-scikit") + # Create the training data + train_filename = self.mktempfile( + suffix=".csv", + text=""" + Years,Expertise,Trust,Salary + 0,1,0.2,10 + 1,3,0.4,20 + 2,5,0.6,30 + 3,7,0.8,40 + """, + ) + # Create the test data + test_filename = self.mktempfile( + suffix=".csv", + text=""" + Years,Expertise,Trust,Salary + 4,9,1.0,50 + 5,11,1.2,60 + """, + ) + # Create the prediction data + predict_filename = self.mktempfile( + suffix=".csv", + text=""" + Years,Expertise,Trust + 6,13,1.4 + """, + ) + # Features + features = "-model-features def:Years:int:1 def:Expertise:int:1 def:Trust:float:1".split() + # Train the model + await CLI.cli( + "train", + "-model", + "scikitlr", + *features, + "-model-predict", + "Salary", + "-sources", + "training_data=csv", + "-source-filename", + train_filename, + ) + # Assess accuracy + await CLI.cli( + "accuracy", + "-model", + "scikitlr", + *features, + "-model-predict", + "Salary", + "-sources", + "test_data=csv", + "-source-filename", + test_filename, + ) + # Ensure JSON output works as expected (#261) + with contextlib.redirect_stdout(self.stdout): + # Make prediction + await CLI._main( + "predict", + "all", + "-model", + "scikitlr", + *features, + "-model-predict", + "Salary", + "-sources", + "predict_data=csv", + "-source-filename", + predict_filename, + ) + results = json.loads(self.stdout.getvalue()) + self.assertTrue(isinstance(results, list)) + self.assertTrue(results) + results = results[0] + self.assertIn("src_url", results) + self.assertEqual("0", results["src_url"]) + self.assertIn("prediction", results) + self.assertIn("value", results["prediction"]) + self.assertEqual(70.0, results["prediction"]["value"])