Sklearn Basics 3: Train a Classifier on a Snowflake Multi-Table Dataset¶
In this notebook, we will learn how to train a classifier on a more complex multi-table dataset, in which a secondary table is itself the parent of another table (i.e., a snowflake schema). If you are not familiar with Khiops’ sklearn estimators, we highly recommend going through the Sklearn Basics 1 and Sklearn Basics 2 lessons first.
We start by importing the sklearn estimator KhiopsClassifier:
import os
import pandas as pd
from khiops import core as kh
from khiops.sklearn import KhiopsClassifier
from khiops.utils.helpers import train_test_split_dataset
from sklearn import metrics
# If there are any issues, you may print the Khiops status with the following command:
# kh.get_runner().print_status()
Training a Multi-Table Classifier¶
We’ll train a multi-table classifier on an extension of the AccidentsSummary dataset that we used in the previous notebook, Sklearn Basics 2. This dataset, Accidents, contains the additional table Users and is organized in the following relational snowflake schema:
Accidents
|
| -- 1:n -- Vehicles
|           |
|           |-- 1:n -- Users
|
| -- 1:1 -- Places
Note that the target variable is Gravity.
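The schema is a tree rooted at the main table, and the extra Users level under Vehicles is what turns a star schema into a snowflake. As a minimal standard-library sketch, the diagram can be encoded as (parent, child, is_one_to_one) tuples (an illustrative encoding of our own, not a Khiops structure) and walked from Accidents to compute how deep each table sits:

```python
# Illustrative encoding of the Accidents snowflake schema:
# (parent, child, is_one_to_one) tuples mirroring the diagram.
RELATIONS = [
    ("Accidents", "Vehicles", False),
    ("Vehicles", "Users", False),
    ("Accidents", "Places", True),
]


def table_depths(main_table, relations):
    """Return the depth of each table below main_table (main_table has depth 0)."""
    children = {}
    for parent, child, _ in relations:
        children.setdefault(parent, []).append(child)
    depths = {main_table: 0}
    stack = [main_table]
    # Depth-first walk from the main table down to the leaf tables.
    while stack:
        table = stack.pop()
        for child in children.get(table, []):
            depths[child] = depths[table] + 1
            stack.append(child)
    return depths


depths = table_depths("Accidents", RELATIONS)
print(depths)  # {'Accidents': 0, 'Vehicles': 1, 'Places': 1, 'Users': 2}
```

In a star schema every secondary table would sit at depth 1; here Users sits at depth 2.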
To train the KhiopsClassifier for this setup, we must specify a multi-table dataset. Let’s first check the content of the tables:
- The main table Accidents.
- The first secondary table Vehicles, which has a 1:n relationship with Accidents.
- The second secondary table Places, which has a 1:1 relationship with Accidents.
- The tertiary table Users, which has a 1:n relationship with Vehicles.
accidents_dataset_dir = os.path.join(kh.get_samples_dir(), "Accidents")
accidents_file = os.path.join(accidents_dataset_dir, "Accidents.txt")
accidents_df = pd.read_csv(accidents_file, sep="\t")
print(f"Accident dataframe (first 10 rows):")
display(accidents_df.head(10))
print()
vehicles_file = os.path.join(accidents_dataset_dir, "Vehicles.txt")
vehicles_df = pd.read_csv(vehicles_file, sep="\t")
print(f"Vehicle dataframe (first 10 rows):")
display(vehicles_df.head(10))
print()
users_file = os.path.join(accidents_dataset_dir, "Users.txt")
users_df = pd.read_csv(users_file, sep="\t")
print(f"User dataframe (first 10 rows):")
display(users_df.head(10))
print()
places_file = os.path.join(accidents_dataset_dir, "Places.txt")
places_df = pd.read_csv(places_file, sep="\t", low_memory=False)
print(f"Places dataframe (first 10 rows):")
display(places_df.head(10))
Accident dataframe (first 10 rows):
   AccidentId    Gravity    Date        Hour      Light
0  201800000001  NonLethal  2018-01-24  15:05:00  Daylight
1  201800000002  NonLethal  2018-02-12  10:15:00  Daylight
2  201800000003  NonLethal  2018-03-04  11:35:00  Daylight
3  201800000004  NonLethal  2018-05-05  17:35:00  Daylight
4  201800000005  NonLethal  2018-06-26  16:05:00  Daylight
5  201800000006  NonLethal  2018-09-23  06:30:00  TwilightOrDawn
6  201800000007  NonLethal  2018-09-26  00:40:00  NightStreelightsOn
7  201800000008  Lethal     2018-11-30  17:15:00  NightStreelightsOn
8  201800000009  NonLethal  2018-02-18  15:57:00  Daylight
9  201800000010  NonLethal  2018-03-19  15:30:00  Daylight

   Department  Commune  InAgglomeration  IntersectionType  Weather
0  590         5        No               Y-type            Normal
1  590         11       Yes              Square            VeryGood
2  590         477      Yes              T-type            Normal
3  590         52       Yes              NoIntersection    VeryGood
4  590         477      Yes              NoIntersection    Normal
5  590         52       Yes              NoIntersection    LightRain
6  590         133      Yes              NoIntersection    Normal
7  590         11       Yes              NoIntersection    Normal
8  590         550      No               NoIntersection    Normal
9  590         51       Yes              X-type            Normal

   CollisionType                     PostalAddress             GPSCode
0  2Vehicles-BehindVehicles-Frontal  route des Ansereuilles    M
1  NoCollision                       Place du général de Gaul  M
2  NoCollision                       Rue nationale             M
3  2Vehicles-Side                    30 rue Jules Guesde       M
4  2Vehicles-Side                    72 rue Victor Hugo        M
5  Other                             D39                       M
6  Other                             4 route de camphin        M
7  Other                             rue saint exupéry         M
8  Other                             rue de l'égalité          M
9  2Vehicles-BehindVehicles-Frontal  face au 59 rue de Lille   M

   Latitude  Longitude
0  50.55737  2.55737
1  50.52936  2.52936
2  50.51243  2.51243
3  50.51974  2.51974
4  50.51607  2.51607
5  50.52132  2.52132
6  50.52211  2.52211
7  50.53146  2.53146
8  50.53707  2.53707
9  50.53639  2.53639
Vehicle dataframe (first 10 rows):
   AccidentId    VehicleId  Direction  Category          PassengerNumber
0  201800000001  A01        Unknown    Car<=3.5T         0
1  201800000001  B01        Unknown    Car<=3.5T         0
2  201800000002  A01        Unknown    Car<=3.5T         0
3  201800000003  A01        Unknown    Motorbike>125cm3  0
4  201800000003  B01        Unknown    Car<=3.5T         0
5  201800000003  C01        Unknown    Car<=3.5T         0
6  201800000004  A01        Unknown    Car<=3.5T         0
7  201800000004  B01        Unknown    Bicycle           0
8  201800000005  A01        Unknown    Moped             0
9  201800000005  B01        Unknown    Car<=3.5T         0

   FixedObstacle      MobileObstacle  ImpactPoint  Maneuver
0  NaN                Vehicle         RightFront   TurnToLeft
1  NaN                Vehicle         LeftFront    NoDirectionChange
2  NaN                Pedestrian      NaN          NoDirectionChange
3  StationaryVehicle  Vehicle         Front        NoDirectionChange
4  NaN                Vehicle         LeftSide     TurnToLeft
5  NaN                NaN             RightSide    Parked
6  NaN                Other           RightFront   Avoidance
7  NaN                Vehicle         LeftSide     NaN
8  NaN                Vehicle         RightFront   PassLeft
9  NaN                Vehicle         LeftFront    Park
User dataframe (first 10 rows):
   AccidentId    VehicleId  Seat  Category    Gender  TripReason  SafetyDevice
0  201800000001  A01        1.0   Driver      Male    Leisure     SeatBelt
1  201800000001  B01        1.0   Driver      Male    NaN         SeatBelt
2  201800000002  A01        1.0   Driver      Male    NaN         SeatBelt
3  201800000002  A01        NaN   Pedestrian  Male    NaN         Helmet
4  201800000003  A01        1.0   Driver      Male    Leisure     Helmet
5  201800000003  C01        1.0   Driver      Male    NaN         ChildrenDevice
6  201800000004  A01        1.0   Driver      Male    Leisure     SeatBelt
7  201800000004  B01        1.0   Driver      Male    Leisure     Helmet
8  201800000005  A01        1.0   Driver      Male    Leisure     Helmet
9  201800000005  B01        1.0   Driver      Male    Leisure     SeatBelt

   SafetyDeviceUsed  PedestrianLocation            PedestrianAction
0  Yes               NaN                           NaN
1  Yes               NaN                           NaN
2  Yes               NaN                           NaN
3  NaN               OnLane<=OnSidewalk0mCrossing  Crossing
4  Yes               NaN                           NaN
5  NaN               NaN                           NaN
6  Yes               NaN                           NaN
7  NaN               NaN                           NaN
8  Yes               NaN                           NaN
9  Yes               NaN                           NaN

   PedestrianCompany  BirthYear
0  Unknown            1960.0
1  Unknown            1928.0
2  Unknown            1947.0
3  Alone              1959.0
4  Unknown            1987.0
5  Unknown            1977.0
6  Unknown            1982.0
7  Unknown            2013.0
8  Unknown            2001.0
9  Unknown            1946.0
Places dataframe (first 10 rows):
   AccidentId    RoadType       RoadNumber  RoadSecNumber  RoadLetter
0  201800000001  Departamental  41          NaN            C
1  201800000002  Communal       41          NaN            D
2  201800000003  Departamental  39          NaN            D
3  201800000004  Departamental  39          NaN            NaN
4  201800000005  Communal       NaN         NaN            NaN
5  201800000006  Departamental  39          NaN            D
6  201800000007  Departamental  41          NaN            D
7  201800000008  Communal       -           NaN            NaN
8  201800000009  Departamental  141         NaN            D
9  201800000010  Departamental  641         NaN            NaN

   Circulation  LaneNumber  SpecialLane  Slope   RoadMarkerId
0  TwoWay       2.0         0            Flat    NaN
1  TwoWay       2.0         0            Flat    NaN
2  TwoWay       2.0         0            Flat    NaN
3  TwoWay       2.0         0            Flat    NaN
4  OneWay       1.0         0            Flat    NaN
5  Unknown      2.0         0            Uphill  NaN
6  TwoWay       2.0         0            Flat    16.0
7  TwoWay       2.0         0            Flat    NaN
8  TwoWay       2.0         0            Flat    NaN
9  TwoWay       2.0         Bike         Flat    1.0

   RoadMarkerDistance  Layout      StripWidth  LaneWidth  SurfaceCondition
0  NaN                 RightCurve  NaN         NaN        Normal
1  NaN                 LeftCurve   NaN         NaN        Normal
2  NaN                 Straight    NaN         NaN        Normal
3  NaN                 Straight    NaN         NaN        Normal
4  NaN                 Straight    NaN         NaN        Normal
5  NaN                 LeftCurve   NaN         NaN        Wet
6  500.0               Straight    NaN         NaN        Normal
7  NaN                 Straight    NaN         NaN        Normal
8  NaN                 Straight    NaN         NaN        Normal
9  670.0               Straight    NaN         NaN        Normal

   Infrastructure  Localization  SchoolNear
0  Unknown         Lane          0.0
1  Unknown         Lane          0.0
2  Unknown         Lane          0.0
3  Unknown         Lane          0.0
4  Unknown         Lane          0.0
5  Unknown         Shoulder      0.0
6  Unknown         Shoulder      0.0
7  Unknown         Lane          0.0
8  Unknown         Shoulder      0.0
9  Unknown         Lane          0.0
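The 1:1 and 1:n relationships listed earlier can be checked on the key columns: in a 1:1 relation each parent key appears at most once in the child table, whereas in a 1:n relation it may appear several times. This standard-library sketch uses AccidentId values copied by hand from the first rows of the Accidents, Vehicles and Places previews above, and counts the largest number of child rows per parent. On the full tables you could instead pass columns such as vehicles_df["AccidentId"], since Counter accepts any iterable.

```python
from collections import Counter

# AccidentId values copied from the first rows of the previews (toy subset).
accident_ids = ["201800000001", "201800000002", "201800000003"]
vehicle_accident_ids = [
    "201800000001", "201800000001",
    "201800000002",
    "201800000003", "201800000003", "201800000003",
]
place_accident_ids = ["201800000001", "201800000002", "201800000003"]


def max_children_per_parent(parent_keys, child_keys):
    """Largest number of child rows that share a single parent key."""
    counts = Counter(child_keys)
    return max(counts.get(key, 0) for key in parent_keys)


print(max_children_per_parent(accident_ids, vehicle_accident_ids))  # 3 -> 1:n
print(max_children_per_parent(accident_ids, place_accident_ids))  # 1 -> consistent with 1:1
```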
Create the multi-table dataset specification¶
Note that the main table Accidents and the secondary table Places both have a single key, AccidentId. The tables Vehicles (the other secondary table) and Users (the tertiary table) have two keys: AccidentId and VehicleId.
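VehicleId alone does not identify a vehicle: A01 appears in every accident listed in the Vehicles and Users previews. This is why Vehicles and Users use the composite key (AccidentId, VehicleId), whose first column is the key of Accidents. The following standard-library sketch takes the (AccidentId, VehicleId) pairs of the first ten Users rows displayed earlier and groups them by the full key (users per vehicle) and by its first column only (users per accident):

```python
from collections import Counter

# (AccidentId, VehicleId) pairs of the first 10 rows of the Users preview.
user_keys = [
    ("201800000001", "A01"), ("201800000001", "B01"),
    ("201800000002", "A01"), ("201800000002", "A01"),
    ("201800000003", "A01"), ("201800000003", "C01"),
    ("201800000004", "A01"), ("201800000004", "B01"),
    ("201800000005", "A01"), ("201800000005", "B01"),
]

# Users per vehicle: group on the full two-column key shared with Vehicles.
users_per_vehicle = Counter(user_keys)
# Users per accident: the first key column alone points to the Accidents parent.
users_per_accident = Counter(accident_id for accident_id, _ in user_keys)
# Grouping on VehicleId alone would wrongly merge vehicles from different accidents.
users_per_vehicle_id_only = Counter(vehicle_id for _, vehicle_id in user_keys)

print(users_per_vehicle[("201800000002", "A01")])  # 2
print(users_per_accident["201800000003"])  # 2
print(users_per_vehicle_id_only["A01"])  # 6
```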
To describe the relations between tables, we must add the relations field to the dataset spec. This field contains a list of tuples describing the relations between tables. The first two values (str) of each tuple are the names of the parent and child tables involved in the relation. A third value (bool) can optionally be set to True to indicate that the relation is 1:1. For example, if the tuple (table1, table2, True) is contained in this field, it means that:
- table1 and table2 are in a 1:1 relationship
- The key of table1 is contained in that of table2 (i.e., keys are hierarchical)
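The tuple format described above can be sketched as a small validator. check_relations is a hypothetical helper, not part of Khiops Python: it accepts 2- or 3-value tuples, defaults the third value to False (1:n), and requires every key column of the parent to also be a key column of the child. We apply this hierarchical-key rule to every relation, an assumption that holds for the Accidents keys used in this notebook.

```python
def check_relations(table_keys, relations):
    """Hypothetical validator for relation tuples.

    table_keys maps each table name to its key: a str or a list of str.
    Returns normalized (parent, child, is_one_to_one) tuples.
    """

    def as_list(key):
        return [key] if isinstance(key, str) else list(key)

    checked = []
    for relation in relations:
        if len(relation) not in (2, 3):
            raise ValueError(f"Expected 2 or 3 values, got {relation!r}")
        parent, child = relation[0], relation[1]
        is_one_to_one = relation[2] if len(relation) == 3 else False
        if not isinstance(is_one_to_one, bool):
            raise TypeError(f"Third value must be a bool in {relation!r}")
        # Hierarchical keys: every parent key column must be a child key column.
        if not set(as_list(table_keys[parent])) <= set(as_list(table_keys[child])):
            raise ValueError(f"Key of {parent} is not contained in key of {child}")
        checked.append((parent, child, is_one_to_one))
    return checked


TABLE_KEYS = {
    "Accidents": "AccidentId",
    "Vehicles": ["AccidentId", "VehicleId"],
    "Users": ["AccidentId", "VehicleId"],
    "Places": "AccidentId",
}
relations = check_relations(
    TABLE_KEYS,
    [("Accidents", "Vehicles"), ("Vehicles", "Users"), ("Accidents", "Places", True)],
)
print(relations)
```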
If the relations field is not present, Khiops Python assumes that the tables are in a star schema with main_table as the central table.
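Without a relations field, the assumed star schema can be pictured as every non-main table being a direct child of main_table. The hypothetical helper below builds that list of (parent, child) pairs; it deliberately leaves out the 1:1 flag, since how Khiops Python sets the cardinality in that default case is not covered here.

```python
def implied_star_relations(main_table, table_names):
    """Hypothetical helper: attach every other table directly to main_table."""
    return [(main_table, name) for name in table_names if name != main_table]


star = implied_star_relations("Accidents", ["Accidents", "Vehicles", "Users", "Places"])
print(star)  # [('Accidents', 'Vehicles'), ('Accidents', 'Users'), ('Accidents', 'Places')]
```

For this dataset, the star default would attach Users directly to Accidents and lose the Vehicles to Users link, which is why the relations field is needed here.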
X_accidents = {
"main_table": "Accidents",
"tables": {
# Drop the target column "Gravity" from the main table
"Accidents": (accidents_df.drop("Gravity", axis=1), "AccidentId"),
"Vehicles": (vehicles_df, ["AccidentId", "VehicleId"]),
"Users": (users_df, ["AccidentId", "VehicleId"]),
"Places": (places_df, "AccidentId"),
},
"relations": [
("Accidents", "Vehicles"),
("Vehicles", "Users"),
("Accidents", "Places", True),
],
}
y_accidents = accidents_df["Gravity"]
Split the dataset into train and test¶
We use the helper function train_test_split_dataset with the X dataset spec to obtain one spec for training and another for testing.
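The key point when splitting a multi-table dataset is that all rows belonging to one accident must land on the same side. Conceptually, this means splitting the main-table keys first and then routing each child row according to its parent key. The standard-library sketch below illustrates that idea with a hypothetical split_by_main_key function on keys taken from the previews; it is not the implementation of train_test_split_dataset, which we assume does something equivalent on the dataframes of the spec.

```python
import random


def split_by_main_key(main_keys, child_rows, key="AccidentId", test_size=0.3, seed=0):
    """Hypothetical sketch: split the main-table keys, then send every child
    row to the same side as its parent key so that no accident is split."""
    keys = list(main_keys)
    random.Random(seed).shuffle(keys)
    test_keys = set(keys[: round(len(keys) * test_size)])
    train_rows = [row for row in child_rows if row[key] not in test_keys]
    test_rows = [row for row in child_rows if row[key] in test_keys]
    return test_keys, train_rows, test_rows


# Toy data: the first 10 AccidentId values and the first 10 Vehicles rows shown earlier.
accident_ids = [f"2018000000{i:02d}" for i in range(1, 11)]
vehicle_rows = [
    {"AccidentId": a, "VehicleId": v}
    for a, v in [
        ("201800000001", "A01"), ("201800000001", "B01"),
        ("201800000002", "A01"), ("201800000003", "A01"),
        ("201800000003", "B01"), ("201800000003", "C01"),
        ("201800000004", "A01"), ("201800000004", "B01"),
        ("201800000005", "A01"), ("201800000005", "B01"),
    ]
]

test_keys, train_rows, test_rows = split_by_main_key(accident_ids, vehicle_rows)
print(sorted(test_keys))
print(len(train_rows), "train vehicle rows,", len(test_rows), "test vehicle rows")
```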
(
X_accidents_train,
X_accidents_test,
y_accidents_train,
y_accidents_test,
) = train_test_split_dataset(X_accidents, y_accidents, test_size=0.3)
Train a classifier with this dataset¶
- You may choose the number of features n_features to be created by the Khiops AutoML engine.
- Set the number of trees to zero (n_trees=0).
khc_accidents = KhiopsClassifier(n_trees=0, n_features=1000)
khc_accidents.fit(X_accidents_train, y_accidents_train)
KhiopsClassifier(n_features=1000, n_trees=0)
Print the train accuracy and train AUC of the model¶
accidents_train_performance = (
khc_accidents.model_report_.train_evaluation_report.get_snb_performance()
)
print(f"Accidents train accuracy: {accidents_train_performance.accuracy}")
print(f"Accidents train auc : {accidents_train_performance.auc}")
Accidents train accuracy: 0.945486
Accidents train auc : 0.847124
Deploy the classifier to obtain predictions and probabilities on the test data¶
y_accidents_test_predicted = khc_accidents.predict(X_accidents_test)
probas_accidents_test = khc_accidents.predict_proba(X_accidents_test)
print("Accidents test predictions (first 10 values):")
display(y_accidents_test_predicted[:10])
print("Accidents test prediction probabilities (first 10 values):")
display(probas_accidents_test[:10])
Accidents test predictions (first 10 values):
array(['NonLethal', 'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal',
'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal'],
dtype='<U9')
Accidents test prediction probabilities (first 10 values):
array([[0.0526682 , 0.9473318 ],
[0.00265006, 0.99734994],
[0.11575197, 0.88424803],
[0.003901 , 0.996099 ],
[0.12084546, 0.87915454],
[0.03023732, 0.96976268],
[0.02279352, 0.97720648],
[0.01870448, 0.98129552],
[0.01280118, 0.98719882],
[0.00382921, 0.99617079]])
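Each row of predict_proba holds one probability per class and sums to one. In scikit-learn classifiers, the column order matches the classes_ attribute, which holds the sorted class labels; we assume KhiopsClassifier follows this convention, which you can check with khc_accidents.classes_. Under that assumption column 1 is the probability of NonLethal, and the printed predictions are consistent with picking the most probable column, as this standard-library sketch shows on the first three rows printed above:

```python
# First three rows of the test probabilities printed above.
probas = [
    [0.0526682, 0.9473318],
    [0.00265006, 0.99734994],
    [0.11575197, 0.88424803],
]
# Assumption: columns follow the sorted class labels (sklearn's classes_ convention).
classes = sorted(["NonLethal", "Lethal"])

# For each row, take the label of the column with the highest probability.
predicted = [classes[max(range(len(row)), key=row.__getitem__)] for row in probas]
print(classes)  # ['Lethal', 'NonLethal']
print(predicted)  # ['NonLethal', 'NonLethal', 'NonLethal']
```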
Estimate the accuracy and AUC metrics on the test data¶
accidents_test_accuracy = metrics.accuracy_score(
y_accidents_test, y_accidents_test_predicted
)
accidents_test_auc = metrics.roc_auc_score(
y_accidents_test, probas_accidents_test[:, 1]
)
print(f"Accidents test accuracy: {accidents_test_accuracy}")
print(f"Accidents test auc : {accidents_test_auc}")
Accidents test accuracy: 0.9435246610902798
Accidents test auc : 0.824969336370397
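metrics.accuracy_score is the fraction of correct predictions, and metrics.roc_auc_score can be read as the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counting one half. As we understand scikit-learn's handling of binary string labels, the label that sorts last (here NonLethal) is treated as the positive class, which is why column 1 of the probabilities is passed. The sketch below recomputes both metrics by hand on a hypothetical five-example toy set (not taken from the notebook); the pairwise AUC is quadratic in the number of examples, so it is meant for illustration only.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions equal to the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def pairwise_auc(y_true, scores, pos_label):
    """ROC AUC as the share of (positive, negative) pairs ranked correctly.

    Ties count as one half (Mann-Whitney formulation), O(n_pos * n_neg).
    """
    pos = [s for t, s in zip(y_true, scores) if t == pos_label]
    neg = [s for t, s in zip(y_true, scores) if t != pos_label]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


# Hypothetical toy labels, predictions and NonLethal probabilities.
y_true = ["NonLethal", "NonLethal", "Lethal", "NonLethal", "Lethal"]
y_pred = ["NonLethal", "NonLethal", "NonLethal", "NonLethal", "Lethal"]
p_nonlethal = [0.95, 0.90, 0.80, 0.60, 0.30]

print(accuracy(y_true, y_pred))  # 0.8
print(pairwise_auc(y_true, p_nonlethal, pos_label="NonLethal"))  # 5 of 6 pairs: 0.8333...
```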