Skip to content

Instantly share code, notes, and snippets.

@david90
Created February 14, 2017 09:08
Show Gist options
  • Save david90/cd4e3288a535424fcb926a5ac91ee7ea to your computer and use it in GitHub Desktop.
Save david90/cd4e3288a535424fcb926a5ac91ee7ea to your computer and use it in GitHub Desktop.

Revisions

  1. david90 created this gist Feb 14, 2017.
    58 changes: 58 additions & 0 deletions train_svm.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,58 @@
    import os

    import sklearn
    from sklearn import cross_validation, grid_search
    from sklearn.metrics import confusion_matrix, classification_report
    from sklearn.svm import SVC
    from sklearn.externals import joblib

    def train_svm_classifer(features, labels, model_output_path):
    """
    train_svm_classifer will train a SVM, saved the trained and SVM model and
    report the classification performance
    features: array of input features
    labels: array of labels associated with the input features
    model_output_path: path for storing the trained svm model
    """
    # save 20% of data for performance evaluation
    X_train, X_test, y_train, y_test = cross_validation.train_test_split(features, labels, test_size=0.2)

    param = [
    {
    "kernel": ["linear"],
    "C": [1, 10, 100, 1000]
    },
    {
    "kernel": ["rbf"],
    "C": [1, 10, 100, 1000],
    "gamma": [1e-2, 1e-3, 1e-4, 1e-5]
    }
    ]

    # request probability estimation
    svm = SVC(probability=True)

    # 10-fold cross validation, use 4 thread as each fold and each parameter set can be train in parallel
    clf = grid_search.GridSearchCV(svm, param,
    cv=10, n_jobs=4, verbose=3)

    clf.fit(X_train, y_train)

    if os.path.exists(model_output_path):
    joblib.dump(clf.best_estimator_, model_output_path)
    else:
    print("Cannot save trained svm model to {0}.".format(model_output_path))

    print("\nBest parameters set:")
    print(clf.best_params_)

    y_predict=clf.predict(X_test)

    labels=sorted(list(set(labels)))
    print("\nConfusion matrix:")
    print("Labels: {0}\n".format(",".join(labels)))
    print(confusion_matrix(y_test, y_predict, labels=labels))

    print("\nClassification report:")
    print(classification_report(y_test, y_predict))