Created
February 21, 2021 19:16
-
-
Save eugeneyan/5ca7d2aa5683be497fa3c03bf4c608cb to your computer and use it in GitHub Desktop.
Revisions
-
eugeneyan created this gist
Feb 21, 2021 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,21 @@ def test_rf_better_than_dt(dummy_titanic): X_train, y_train, X_test, y_test = dummy_titanic dt = DecisionTree(depth_limit=10) dt.fit(X_train, y_train) rf = RandomForest(depth_limit=10, num_trees=7, col_subsampling=0.8, row_subsampling=0.8) rf.fit(X_train, y_train) pred_test_dt = dt.predict(X_test) pred_test_binary_dt = np.round(pred_test_dt) acc_test_dt = accuracy_score(y_test, pred_test_binary_dt) auc_test_dt = roc_auc_score(y_test, pred_test_dt) pred_test_rf = rf.predict(X_test) pred_test_binary_rf = np.round(pred_test_rf) acc_test_rf = accuracy_score(y_test, pred_test_binary_rf) auc_test_rf = roc_auc_score(y_test, pred_test_rf) assert acc_test_rf > acc_test_dt, 'RandomForest should have higher accuracy than DecisionTree on test set.' assert auc_test_rf > auc_test_dt, 'RandomForest should have higher AUC ROC than DecisionTree on test set.'