Skip to content

Instantly share code, notes, and snippets.

@vaibkumr
Created October 22, 2018 13:05
Show Gist options
  • Select an option

  • Save vaibkumr/c40de566027a54e9126961b926706b84 to your computer and use it in GitHub Desktop.

Select an option

Save vaibkumr/c40de566027a54e9126961b926706b84 to your computer and use it in GitHub Desktop.

Revisions

  1. vaibkumr created this gist Oct 22, 2018.
    22 changes: 22 additions & 0 deletions bettersplit.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,22 @@
    def std_agg(cnt, s1, s2): return math.sqrt((s2/cnt) - (s1/cnt)**2)

    def find_better_split(self, var_idx):
    x, y = self.x.values[self.idxs,var_idx], self.y[self.idxs]
    sort_idx = np.argsort(x)
    sort_y,sort_x = y[sort_idx], x[sort_idx]
    rhs_cnt,rhs_sum,rhs_sum2 = self.n, sort_y.sum(), (sort_y**2).sum()
    lhs_cnt,lhs_sum,lhs_sum2 = 0,0.,0.

    for i in range(0,self.n-self.min_leaf-1):
    xi,yi = sort_x[i],sort_y[i]
    lhs_cnt += 1; rhs_cnt -= 1
    lhs_sum += yi; rhs_sum -= yi
    lhs_sum2 += yi**2; rhs_sum2 -= yi**2
    if i<self.min_leaf or xi==sort_x[i+1]:
    continue

    lhs_std = std_agg(lhs_cnt, lhs_sum, lhs_sum2)
    rhs_std = std_agg(rhs_cnt, rhs_sum, rhs_sum2)
    curr_score = lhs_std*lhs_cnt + rhs_std*rhs_cnt
    if curr_score<self.score:
    self.var_idx,self.score,self.split = var_idx,curr_score,xi