We can use cross validation to evaluate the prediction accuracy of a model. The idea is to hold back a subset of the dataset and not use it for training. Once the model is trained on the rest of the data, those held-back rows are new, unseen data for the model, so we can use them to measure how accurately the trained model predicts. In practice, we first partition the data into a training dataset and a test dataset, then train the model on the training dataset, and finally evaluate it on the test dataset. This process is called "Cross Validation". A single train/test split like this is its simplest form, often called holdout validation; k-fold cross validation repeats the split several times so that every row is used for testing once.
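A minimal sketch of the split itself, in plain Python with no libraries, may help make the idea concrete. The function name `holdout_split` and the list of ten integers standing in for passengers are hypothetical; the post itself does this step with pandas.

```python
import random

def holdout_split(rows, train_frac=0.8, seed=1):
    """Randomly partition rows into (train, test) with no overlap (hypothetical helper)."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)      # reproducible shuffle of row positions
    cut = int(len(rows) * train_frac)         # number of rows kept for training
    train = [rows[i] for i in indices[:cut]]
    test = [rows[i] for i in indices[cut:]]   # rows the model never sees while training
    return train, test

data = list(range(10))  # stand-in for ten passengers
train, test = holdout_split(data)
print(len(train), len(test))  # 8 2
```

Every row ends up in exactly one of the two lists, which is what makes the test rows "unknown" to a model fitted on the training rows.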
In this blog post I would like to demonstrate how we can cross validate a decision tree classification model built using scikit-learn and pandas. Please visit the decision-tree-classification-using scikit-learn post if you haven't created your classification model yet. As a recap, at this point we have a decision tree model that predicts whether a given person aboard the Titanic is going to survive the tragedy or die in the cold, dark sea :(.
In the previous blog post we used the entire Titanic dataset to train the model. Let's see how we can use only 80% of the data for training and keep the remaining 20% for evaluation.
# randomly sample 80% of the rows for training (random_state makes the split reproducible)
train = df.sample(frac=0.8, random_state=1)
# the remaining 20% of the rows are kept for evaluation
test = df.loc[~df.index.isin(train.index)]
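To check that these two pandas lines really produce a disjoint 80/20 partition, here is a small sketch on a made-up ten-row DataFrame standing in for the Titanic `df` (the column values are invented). Note that the `isin` trick assumes the index labels of `df` are unique; with duplicate labels, rows sharing a label with a training row would also drop out of `test`.

```python
import pandas as pd

# Hypothetical toy frame standing in for the Titanic data (`df` in the post)
df = pd.DataFrame({
    "Pclass":   [1, 3, 2, 3, 1, 2, 3, 3, 1, 2],
    "Survived": [1, 0, 1, 0, 1, 0, 0, 1, 1, 0],
})

train = df.sample(frac=0.8, random_state=1)      # 8 randomly chosen rows
test = df.loc[~df.index.isin(train.index)]       # the 2 rows left over

print(len(train), len(test))                     # 8 2
# No row appears in both partitions
print(train.index.intersection(test.index).empty)  # True
```

`df.drop(train.index)` gives the same `test` frame here, and `train_test_split(df, test_size=0.2, random_state=1)` from `sklearn.model_selection` is a common alternative, although it will pick a different random set of rows.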
Then we train the model as before, but this time only on the training dataset.
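from sklearn.tree import DecisionTreeClassifier
dt = DecisionTreeClassifier(min_samples_split=20, random_state=9)
# fit only on the training rows; the test rows stay unseen
dt.fit(train[features], train["Survived"])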
Next, we predict the outcome for the remaining 20% of the data.
predictions = dt.predict(test[features])
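The split, training and prediction steps can be run end to end as a self-contained sketch. The toy DataFrame, the `Pclass`/`Sex` encoding and the `features` list below are assumptions standing in for the Titanic data from the previous post. In this invented data survival depends only on `Sex`, so the tree should predict every held-out row correctly.

```python
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data (50 rows); Sex encoded as 0 = female, 1 = male
df = pd.DataFrame({
    "Pclass":   [1, 3, 2, 3, 1, 2, 3, 3, 1, 2] * 5,
    "Sex":      [0, 1, 0, 1, 0, 1, 1, 0, 0, 1] * 5,
    "Survived": [1, 0, 1, 0, 1, 0, 0, 1, 1, 0] * 5,
})
features = ["Pclass", "Sex"]  # hypothetical feature list

# Same 80/20 split as in the post
train = df.sample(frac=0.8, random_state=1)
test = df.loc[~df.index.isin(train.index)]

# Fit only on the training rows...
dt = DecisionTreeClassifier(min_samples_split=20, random_state=9)
dt.fit(train[features], train["Survived"])

# ...then predict the held-out rows the tree has never seen
predictions = dt.predict(test[features])
print(len(predictions))  # 10
```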
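Then we can calculate the Mean Squared Error (MSE) of the predictions against the actual values as a measure of the prediction accuracy of the trained model. Lower is better: since Survived is either 0 or 1, each squared error is 0 for a correct prediction and 1 for a wrong one, so the MSE equals the fraction of test passengers the model misclassifies.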
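Why the MSE behaves as an error rate here can be shown with a few lines of plain Python. The helper names `mse` and `error_rate` and the two label lists are made up for illustration.

```python
def mse(y_true, y_pred):
    """Mean of the squared differences between actual and predicted values."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def error_rate(y_true, y_pred):
    """Fraction of predictions that differ from the actual label."""
    return sum(t != p for t, p in zip(y_true, y_pred)) / len(y_true)

actual    = [1, 0, 0, 1, 1, 0, 1, 0]  # made-up Survived labels
predicted = [1, 0, 1, 1, 0, 0, 1, 0]  # made-up model output with 2 mistakes

print(mse(actual, predicted))         # 0.25
print(error_rate(actual, predicted))  # 0.25
```

With 0/1 labels, each squared difference is either 0 or 1, so both functions count the same mistakes.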
We can use scikit-learn's built-in mean_squared_error function for this. First, import it into the current module.
from sklearn.metrics import mean_squared_error
Then we can do the calculation as follows:
# scikit-learn expects the actual values first, then the predictions
mse = mean_squared_error(test["Survived"], predictions)
print(mse)
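Since the target is a class label, scikit-learn's `accuracy_score` is the more conventional classification metric, and for 0/1 labels it is exactly one minus the MSE. A small sketch with invented labels:

```python
from sklearn.metrics import accuracy_score, mean_squared_error

y_true = [1, 0, 0, 1, 1, 0, 1, 0]  # made-up actual labels
y_pred = [1, 0, 1, 1, 0, 0, 1, 0]  # made-up predictions (2 of 8 wrong)

mse = mean_squared_error(y_true, y_pred)  # fraction misclassified for 0/1 labels
acc = accuracy_score(y_true, y_pred)      # fraction classified correctly
print(mse, acc, 1 - acc)  # 0.25 0.75 0.25
```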
You can experiment with the data partition ratio and the feature set and observe how the Mean Squared Error changes with them.
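That experiment can be automated. This sketch uses a synthetic dataset from scikit-learn's `make_classification` (an assumption standing in for the Titanic frame; the column names `f0` to `f3` are invented) and records the held-out MSE for several training fractions. Its last part swaps in k-fold cross validation via `cross_val_score` with 5 folds, which averages the MSE over five different splits so the result depends less on one lucky or unlucky partition.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the Titanic frame: 4 numeric features, 0/1 target
X, y = make_classification(n_samples=300, n_features=4, n_informative=3,
                           n_redundant=1, random_state=0)
features = ["f0", "f1", "f2", "f3"]  # hypothetical column names
df = pd.DataFrame(X, columns=features)
df["Survived"] = y

# Holdout MSE for several training fractions
results = {}
for frac in (0.6, 0.7, 0.8, 0.9):
    train = df.sample(frac=frac, random_state=1)
    test = df.loc[~df.index.isin(train.index)]
    dt = DecisionTreeClassifier(min_samples_split=20, random_state=9)
    dt.fit(train[features], train["Survived"])
    results[frac] = mean_squared_error(test["Survived"], dt.predict(test[features]))
    print(f"train fraction {frac:.1f}: test MSE = {results[frac]:.3f}")

# 5-fold cross validation: every row is used for testing exactly once
scores = -cross_val_score(DecisionTreeClassifier(min_samples_split=20, random_state=9),
                          df[features], df["Survived"], cv=5,
                          scoring="neg_mean_squared_error")
print("per-fold MSE:", scores.round(3))
print("mean MSE:", scores.mean().round(3))
```

The `neg_mean_squared_error` scorer returns negated MSE values, because scikit-learn scorers follow a higher-is-better convention; the minus sign in front of `cross_val_score` turns them back into ordinary MSE.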