@eugeneyan
Created February 21, 2021 19:16
Revisions

  1. eugeneyan created this gist Feb 21, 2021.
    21 changes: 21 additions & 0 deletions test_rf_better_at_same_depth.py
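    import numpy as np
    # Assumed: the metric functions are scikit-learn's.
    from sklearn.metrics import accuracy_score, roc_auc_score

    # DecisionTree and RandomForest are the implementations under test; import
    # them from wherever they live in your project. dummy_titanic is a pytest
    # fixture defined elsewhere.


    def test_rf_better_than_dt(dummy_titanic):
        X_train, y_train, X_test, y_test = dummy_titanic

        # Fit both models with the same depth limit so the comparison is fair.
        dt = DecisionTree(depth_limit=10)
        dt.fit(X_train, y_train)

        rf = RandomForest(depth_limit=10, num_trees=7, col_subsampling=0.8, row_subsampling=0.8)
        rf.fit(X_train, y_train)

        # Accuracy uses rounded predictions; AUC uses the raw predicted scores.
        pred_test_dt = dt.predict(X_test)
        pred_test_binary_dt = np.round(pred_test_dt)
        acc_test_dt = accuracy_score(y_test, pred_test_binary_dt)
        auc_test_dt = roc_auc_score(y_test, pred_test_dt)

        pred_test_rf = rf.predict(X_test)
        pred_test_binary_rf = np.round(pred_test_rf)
        acc_test_rf = accuracy_score(y_test, pred_test_binary_rf)
        auc_test_rf = roc_auc_score(y_test, pred_test_rf)

        assert acc_test_rf > acc_test_dt, 'RandomForest should have higher accuracy than DecisionTree on test set.'
        assert auc_test_rf > auc_test_dt, 'RandomForest should have higher AUC ROC than DecisionTree on test set.'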
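
The test above compares two metrics: accuracy on rounded predictions and ROC AUC on the raw scores. As a minimal, dependency-free sketch of what those two numbers measure, the helpers below reimplement them in plain Python. The names `accuracy` and `roc_auc` and the sample data are hypothetical and not part of the gist; the test itself calls `accuracy_score` and `roc_auc_score`, which are presumably scikit-learn's.

```python
def accuracy(y_true, y_prob):
    """Fraction of labels matched after rounding each probability to 0 or 1."""
    y_pred = [round(p) for p in y_prob]
    return sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)


def roc_auc(y_true, y_score):
    """Chance that a random positive scores above a random negative (ties count 0.5)."""
    pos = [s for t, s in zip(y_true, y_score) if t == 1]
    neg = [s for t, s in zip(y_true, y_score) if t == 0]
    if not pos or not neg:
        raise ValueError("ROC AUC needs at least one positive and one negative label.")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical labels and predicted probabilities, for illustration only.
y_true = [0, 0, 1, 1]
y_prob = [0.1, 0.4, 0.35, 0.8]

print(accuracy(y_true, y_prob))
print(roc_auc(y_true, y_prob))
```

Accuracy depends only on which side of 0.5 each prediction falls, while AUC depends only on how the scores rank positives against negatives. That is why the test asserts on both: a model can improve one without improving the other.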