Home Reference Source Test

test/unit/cross_validation_spec.mjs

import * as ms from '../../index.mjs';
import chai from 'chai';
import path from 'path';
import expose from './expose.js';
// Base path for resolving the ../mock CSV fixtures; supplied by ./expose.js.
const { __dirname, } = expose;
const { expect, } = chai;

// Ten-element numeric fixture shared by the split tests.
const testArray = [20, 25, 10, 33, 50, 42, 19, 34, 90, 23, ];

// Shorthand handles for the models exercised below.
const { ml, } = ms;
const { DecisionTreeClassifier: DTClassifier, } = ml.SL;
const { RandomForestRegression: RFRegression, } = ml.Regression;

// Column selections handed to cross_validate_score / grid_search under the
// option keys of the same name.
// NOTE(review): `dependentFeatures` holds what look like the model inputs and
// `independentFeatures` the predicted column - presumably the library's naming
// convention; confirm before renaming anything.
const dependentFeatures = [['Age', ], ['EstimatedSalary', ], ];
const independentFeatures = [['Purchased', ], ];
const regressionDependentFeatures = [['R&D Spend', ], ['Administration', ], ['Marketing Spend', ], ];
const regressionIndependentFeatures = [['Profit', ], ];

// Assigned by the suite's `before` hook once the CSV files have been loaded.
let SNA_csv;
let Start50_csv;
let Start50DataSet;

describe('cross_validation', function () { 
  this.timeout(10000);
  // Loads both CSV fixtures (requests are issued together) and prepares the
  // startups data set before any test runs. A rejected load or a throw while
  // fitting rejects the returned promise, which fails the hook.
  before(async () => {
    const socialAdsRequest = ms.loadCSV(path.join(__dirname, '../mock/Social_Network_Ads.csv'), {
      colParser: {
        Age: 'number',
        EstimatedSalary: 'number',
        Purchased: 'number',
      },
    });
    const startupsRequest = ms.loadCSV(path.join(__dirname, '../mock/50_Startups.csv'), {
      colParser: {
        'Administration': 'number',
        'R&D Spend': 'number',
        'Marketing Spend': 'number',
        'Profit': 'number',
      },
    });
    const [socialAdsRows, startupRows, ] = await Promise.all([socialAdsRequest, startupsRequest, ]);

    SNA_csv = socialAdsRows;
    Start50_csv = startupRows;

    // The textual `State` column is fitted with the one-hot strategy; the value
    // returned by fitColumns is what the regression tests split.
    const startups = new ms.DataSet(Start50_csv);
    Start50DataSet = startups.fitColumns({
      columns: [
        {
          name: 'State',
          options: {
            strategy: 'onehot',
          },
        },
      ],
    });
  });
  /**
   * train_test_split: partitions a collection into a training and a testing
   * subset. Every case uses the ten-element `testArray` fixture.
   */
  describe('train_test_split', () => {
    it('should split dataset with default values', () => {
      expect(ms.cross_validation).to.be.an('object');
      // Computed inside the test rather than in the describe body: describe
      // callbacks run while Mocha is still collecting tests, so a throw there
      // would abort the whole run instead of failing this single case.
      const defaultTrainTestSplit = ms.cross_validation.train_test_split(testArray);
      // With no options, 8 of the 10 items are expected in `train`, 2 in `test`.
      expect(defaultTrainTestSplit.train.length).to.equal(8);
      expect(defaultTrainTestSplit.test.length).to.equal(2);
    });
    it('should split the data into two arrays', () => {
      // `return_array: true` yields a [train, test] pair instead of an object.
      const defaultTrainTestSplitArray = ms.cross_validation.train_test_split(testArray, { return_array: true, });
      const [train, test,] = defaultTrainTestSplitArray;
      expect(train.length).to.equal(8);
      expect(test.length).to.equal(2);
    });
    it('should split with defined test size', () => {
      const defaultTrainTestSplitSize = ms.cross_validation.train_test_split(testArray, { test_size: 0.4, });
      const { train, test, } = defaultTrainTestSplitSize;
      expect(train.length).to.equal(6);
      expect(test.length).to.equal(4);
    });
    it('should split with defined train size', () => {
      const defaultTrainTestSplitSize = ms.cross_validation.train_test_split(testArray, { train_size: 0.5, });
      const { train, test, } = defaultTrainTestSplitSize;
      expect(train.length).to.equal(5);
      expect(test.length).to.equal(5);
    });
    it('should use a randomized seed', () => {
      // Same `random_state` => identical training subset; a different one => a
      // different subset.
      const defaultTrainTestSplitSeed0 = ms.cross_validation.train_test_split(testArray, { random_state: 0, return_array: true, });
      const defaultTrainTestSplitSeed0a = ms.cross_validation.train_test_split(testArray, { random_state: 0, return_array: true, });
      const defaultTrainTestSplitSeed1 = ms.cross_validation.train_test_split(testArray, { random_state: 1, return_array: true, });
      const [train0,] = defaultTrainTestSplitSeed0;
      const [train0a,] = defaultTrainTestSplitSeed0a;
      const [train1,] = defaultTrainTestSplitSeed1;
      expect(train0.toString()).to.equal(train0a.toString());
      expect(train0.toString()).to.not.equal(train1.toString());
    });
  });
  /**
   * cross_validation_split: divides a collection into folds. Every case uses
   * the ten-element `testArray` fixture.
   */
  describe('cross_validation_split', () => {
    it('should split dataset with default values', () => {
      // Computed inside the test rather than in the describe body: describe
      // callbacks run while Mocha is still collecting tests, so a throw there
      // would abort the whole run instead of failing this single case.
      const defaultCrossValidation = ms.cross_validation.cross_validation_split(testArray);
      // With no options, three folds are expected.
      expect(defaultCrossValidation.length).to.equal(3);
    });
    it('should split the data into k-folds', () => {
      const defaultCrossValidationArray = ms.cross_validation.cross_validation_split(testArray, {
        folds: 2,
      });
      // Ten items over two folds => five items in the first fold.
      expect(defaultCrossValidationArray[0].length).to.equal(5);
      expect(defaultCrossValidationArray.length).to.equal(2);
    });
    it('should use a randomized seed', () => {
      // Same `random_state` => identical folds; a different one => different folds.
      const defaultCrossValidationSeed0 = ms.cross_validation.cross_validation_split(testArray, { random_state: 0, });
      const defaultCrossValidationSeed0a = ms.cross_validation.cross_validation_split(testArray, { random_state: 0, });
      const defaultCrossValidationSeed1 = ms.cross_validation.cross_validation_split(testArray, { random_state: 1, });

      expect(JSON.stringify(defaultCrossValidationSeed0)).to.equal(JSON.stringify(defaultCrossValidationSeed0a));
      expect(JSON.stringify(defaultCrossValidationSeed0)).to.not.equal(JSON.stringify(defaultCrossValidationSeed1));
    });
  });
  // cross_validate_score: each case splits a fixture, hands a model to
  // cross_validate_score and checks the list of scores that comes back. Models
  // are passed either as instances (DTClassifier, RFRegression) or
  // un-instantiated (ml.KNN, ml.MultivariateLinearRegression).
  describe('cross_validate_score', () => {
    // Shared by every case: hold out a quarter of the rows with a fixed seed.
    const holdOutQuarter = rows => ms.cross_validation.train_test_split(rows, {
      test_size: 0.25,
      random_state: 0,
      parse_int_train_size: true,
    });

    it('should validate classification', () => {
      const split = holdOutQuarter(SNA_csv);
      const scores = ms.cross_validation.cross_validate_score({
        dataset: split.train,
        testingset: split.test,
        classifier: new DTClassifier({
          gainFunction: 'gini',
          minNumSamples: 4,
        }),
        dependentFeatures,
        independentFeatures,
      });
      // No `folds` option is passed; ten scores are expected.
      expect(scores).to.have.lengthOf(10);
      expect(ms.util.mean(scores)).to.be.greaterThan(0.75);
      expect(ms.util.sd(scores)).to.be.lessThan(0.08);
    });

    it('should validate classification constructor', () => {
      const split = holdOutQuarter(SNA_csv);
      // ml.KNN is passed as-is, with its settings travelling in `modelOptions`.
      const scores = ms.cross_validation.cross_validate_score({
        dataset: split.train,
        testingset: split.test,
        classifier: ml.KNN,
        modelOptions: { k: 5, },
        use_estimates_y_vector: true,
        dependentFeatures,
        independentFeatures,
      });
      expect(scores).to.have.lengthOf(10);
      expect(ms.util.mean(scores)).to.be.greaterThan(0.65);
      expect(ms.util.sd(scores)).to.be.lessThan(0.08);
    });

    it('should validate regression with train method', () => {
      const split = holdOutQuarter(Start50DataSet);
      const scores = ms.cross_validation.cross_validate_score({
        dataset: split.train,
        testingset: split.test,
        folds: 2,
        use_train_y_vector: true,
        regression: new RFRegression({
          seed: 3,
          maxFeatures: 2,
          replacement: false,
          nEstimators: 300,
        }),
        dependentFeatures: regressionDependentFeatures,
        independentFeatures: regressionIndependentFeatures,
      });
      // One score per requested fold; only the mean is bounded here.
      expect(scores).to.have.lengthOf(2);
      expect(ms.util.mean(scores)).to.be.lessThan(40000);
    });

    it('should validate regression with constructor methods', () => {
      const split = holdOutQuarter(Start50DataSet);
      const scores = ms.cross_validation.cross_validate_score({
        dataset: split.train,
        testingset: split.test,
        folds: 2,
        use_train_y_vector: false,
        regression: ml.MultivariateLinearRegression,
        dependentFeatures: regressionDependentFeatures,
        independentFeatures: regressionIndependentFeatures,
      });
      // One score per requested fold; only the mean is bounded here.
      expect(scores).to.have.lengthOf(2);
      expect(ms.util.mean(scores)).to.be.lessThan(9000);
    });
  });
  /**
   * grid_search: runs a model over every combination of the supplied
   * `parameters` and checks the shape of what comes back. The two ordering
   * cases at the end are not written yet and are declared as pending.
   */
  describe('grid_search', () => {
    // Synchronous test: no `done` callback is needed.
    it('should return sorted best regression parameters', () => {
      const { train, test, } = ms.cross_validation.train_test_split(Start50DataSet, {
        test_size: 0.25,
        random_state: 0,
        parse_int_train_size: true,
      });
      const optimizedParameters = ms.cross_validation.grid_search({
        regression: ml.Regression.RandomForestRegression,
        parameters: {
          seed: [2, ],
          maxFeatures: [2, 3, ],
          replacement: [false, true, ],
          nEstimators: [300, 500, ],
        },
        dataset: train,
        testingset: test,
        folds: 2,
        // accuracy:'rSquared',
        use_train_y_vector:true,
        dependentFeatures: regressionDependentFeatures,
        independentFeatures: regressionIndependentFeatures,
      });
      // Without `return_parameters` the result is expected to expose both keys.
      expect(optimizedParameters).to.haveOwnProperty('params');
      expect(optimizedParameters).to.haveOwnProperty('results');
    });
    it('should return sorted best classification parameters', () => {
      const { train, test, } = ms.cross_validation.train_test_split(SNA_csv, {
        test_size: 0.25,
        random_state: 0,
        parse_int_train_size: true,
      });
      const optimizedParameters = ms.cross_validation.grid_search({
        dataset: train,
        testingset: test,
        classifier: ml.SL.RandomForestClassifier,
        return_parameters:true,
        parameters: {
          maxFeatures: [1.0, 0.5,],
          // replacement: [true, false, ],
          // useSampleBagging: [true, false, ],
          nEstimators: [10, 20,],
          treeOptions: [
            { minNumSamples: 3, },
            { minNumSamples: 2, },
          ],
        },
        dependentFeatures,
        independentFeatures,
      });
      // Three parameters with two candidate values each: 2 x 2 x 2 = 8 entries.
      expect(optimizedParameters).to.be.an('array');
      expect(optimizedParameters).to.have.lengthOf(8);
    });
    // TODO: not implemented yet. Declared without a callback so Mocha reports
    // them as pending instead of counting empty bodies as passing tests.
    it('should sort parameter performance logically for regression');
    it('should sort parameter performance logically for classification');
  });
});