Sklearn Basics 3: Train a Classifier on a Snowflake Multi-Table Dataset¶
In this notebook, we will learn how to train a classifier on more complex multi-table data, where a secondary table is itself the parent table of another table (i.e. a snowflake schema). It is highly recommended to go through the Sklearn Basics 1 and Sklearn Basics 2 lessons first if you are not familiar with Khiops’ sklearn estimators.
We start by importing the Khiops sklearn classifier KhiopsClassifier, along with the modules used later to locate the Khiops Samples directory and load its tables:
from os import path
import pandas as pd
from khiops import core as kh
from khiops.sklearn import KhiopsClassifier
Training a Multi-Table Classifier¶
We’ll train a multi-table classifier on an extension of the AccidentsSummary dataset that we used in the previous notebook, Sklearn Basics 2. This dataset, Accidents, contains the additional table Users and is organized in the following relational snowflake schema:
Accidents
|
| -- 1:n -- Vehicles
|           |
|           |-- 1:n -- Users
|
| -- 1:1 -- Places
Note that the target variable is Gravity.
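The roles of the tables in the schema above (main, secondary, tertiary) can be read as the number of relation hops from the main table. The following is a minimal sketch using only the standard library; `table_depths` is a hypothetical helper written for illustration and is not part of Khiops.

```python
from collections import deque

# Parent -> child relations read off the schema diagram above.
# The optional third value marks a 1:1 relation.
RELATIONS = [
    ("Accidents", "Vehicles"),
    ("Vehicles", "Users"),
    ("Accidents", "Places", True),
]


def table_depths(main_table, relations):
    """Return the number of relation hops from main_table to each table."""
    children = {}
    for relation in relations:
        parent, child = relation[0], relation[1]
        children.setdefault(parent, []).append(child)

    # Breadth-first traversal starting at the main table.
    depths = {main_table: 0}
    queue = deque([main_table])
    while queue:
        parent = queue.popleft()
        for child in children.get(parent, []):
            if child not in depths:
                depths[child] = depths[parent] + 1
                queue.append(child)
    return depths


depths = table_depths("Accidents", RELATIONS)
# In a star schema every table is at depth 0 or 1; a deeper table
# (here Users) is what makes the schema a snowflake.
is_snowflake = max(depths.values()) > 1
print(depths, is_snowflake)
```

Here Vehicles and Places sit one hop from Accidents, while Users sits two hops away through Vehicles.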
To train the KhiopsClassifier for this setup, we must specify a multi-table dataset. Let’s first check the content of the tables:
- The main table Accidents.
- The first secondary table Vehicles, which has a 1:n relationship with Accidents.
- The second secondary table Places, which has a 1:1 relationship with Accidents.
- The tertiary table Users, which has a 1:n relationship with Vehicles.
accidents_dataset_dir = path.join(kh.get_samples_dir(), "Accidents")

# The main table is read from the "AccidentsSummary" sample directory;
# its "Gravity" column is used as the target further below
accidents_file = path.join(kh.get_samples_dir(), "AccidentsSummary", "Accidents.txt")
accidents_df = pd.read_csv(accidents_file, sep="\t", encoding="latin1")
print("Accident dataframe (first 10 rows):")
display(accidents_df.head(10))
print()

vehicles_file = path.join(accidents_dataset_dir, "Vehicles.txt")
vehicles_df = pd.read_csv(vehicles_file, sep="\t", encoding="latin1")
print("Vehicle dataframe (first 10 rows):")
display(vehicles_df.head(10))

# We drop the "Gravity" column as it was used to create the target
users_file = path.join(accidents_dataset_dir, "Users.txt")
users_df = pd.read_csv(users_file, sep="\t", encoding="latin1").drop("Gravity", axis=1)
print("User dataframe (first 10 rows):")
display(users_df.head(10))
print()

places_file = path.join(accidents_dataset_dir, "Places.txt")
places_df = pd.read_csv(places_file, sep="\t", encoding="latin1")
print("Places dataframe (first 10 rows):")
display(places_df.head(10))
Accident dataframe (first 10 rows):
   AccidentId    Gravity    Date        Hour      Light
0  201800000001  NonLethal  2018-01-24  15:05:00  Daylight
1  201800000002  NonLethal  2018-02-12  10:15:00  Daylight
2  201800000003  NonLethal  2018-03-04  11:35:00  Daylight
3  201800000004  NonLethal  2018-05-05  17:35:00  Daylight
4  201800000005  NonLethal  2018-06-26  16:05:00  Daylight
5  201800000006  NonLethal  2018-09-23  06:30:00  TwilightOrDawn
6  201800000007  NonLethal  2018-09-26  00:40:00  NightStreelightsOn
7  201800000008  Lethal     2018-11-30  17:15:00  NightStreelightsOn
8  201800000009  NonLethal  2018-02-18  15:57:00  Daylight
9  201800000010  NonLethal  2018-03-19  15:30:00  Daylight

   Department  Commune  InAgglomeration  IntersectionType  Weather
0  590         5        No               Y-type            Normal
1  590         11       Yes              Square            VeryGood
2  590         477      Yes              T-type            Normal
3  590         52       Yes              NoIntersection    VeryGood
4  590         477      Yes              NoIntersection    Normal
5  590         52       Yes              NoIntersection    LightRain
6  590         133      Yes              NoIntersection    Normal
7  590         11       Yes              NoIntersection    Normal
8  590         550      No               NoIntersection    Normal
9  590         51       Yes              X-type            Normal

   CollisionType                     PostalAddress
0  2Vehicles-BehindVehicles-Frontal  route des Ansereuilles
1  NoCollision                       Place du général de Gaul
2  NoCollision                       Rue nationale
3  2Vehicles-Side                    30 rue Jules Guesde
4  2Vehicles-Side                    72 rue Victor Hugo
5  Other                             D39
6  Other                             4 route de camphin
7  Other                             rue saint exupéry
8  Other                             rue de l'égalité
9  2Vehicles-BehindVehicles-Frontal  face au 59 rue de Lille
Vehicle dataframe (first 10 rows):
   AccidentId    VehicleId  Direction  Category          PassengerNumber
0  201800000001  A01        Unknown    Car<=3.5T         0
1  201800000001  B01        Unknown    Car<=3.5T         0
2  201800000002  A01        Unknown    Car<=3.5T         0
3  201800000003  A01        Unknown    Motorbike>125cm3  0
4  201800000003  B01        Unknown    Car<=3.5T         0
5  201800000003  C01        Unknown    Car<=3.5T         0
6  201800000004  A01        Unknown    Car<=3.5T         0
7  201800000004  B01        Unknown    Bicycle           0
8  201800000005  A01        Unknown    Moped             0
9  201800000005  B01        Unknown    Car<=3.5T         0

   FixedObstacle      MobileObstacle  ImpactPoint  Maneuver
0  NaN                Vehicle         RightFront   TurnToLeft
1  NaN                Vehicle         LeftFront    NoDirectionChange
2  NaN                Pedestrian      NaN          NoDirectionChange
3  StationaryVehicle  Vehicle         Front        NoDirectionChange
4  NaN                Vehicle         LeftSide     TurnToLeft
5  NaN                NaN             RightSide    Parked
6  NaN                Other           RightFront   Avoidance
7  NaN                Vehicle         LeftSide     NaN
8  NaN                Vehicle         RightFront   PassLeft
9  NaN                Vehicle         LeftFront    Park
User dataframe (first 10 rows):
   AccidentId    VehicleId  Seat  Category    Gender  TripReason  SafetyDevice
0  201800000001  A01        1.0   Driver      Male    Leisure     SeatBelt
1  201800000001  B01        1.0   Driver      Male    NaN         SeatBelt
2  201800000002  A01        1.0   Driver      Male    NaN         SeatBelt
3  201800000002  A01        NaN   Pedestrian  Male    NaN         Helmet
4  201800000003  A01        1.0   Driver      Male    Leisure     Helmet
5  201800000003  C01        1.0   Driver      Male    NaN         ChildrenDevice
6  201800000004  A01        1.0   Driver      Male    Leisure     SeatBelt
7  201800000004  B01        1.0   Driver      Male    Leisure     Helmet
8  201800000005  A01        1.0   Driver      Male    Leisure     Helmet
9  201800000005  B01        1.0   Driver      Male    Leisure     SeatBelt

   SafetyDeviceUsed  PedestrianLocation            PedestrianAction
0  Yes               NaN                           NaN
1  Yes               NaN                           NaN
2  Yes               NaN                           NaN
3  NaN               OnLane<=OnSidewalk0mCrossing  Crossing
4  Yes               NaN                           NaN
5  NaN               NaN                           NaN
6  Yes               NaN                           NaN
7  NaN               NaN                           NaN
8  Yes               NaN                           NaN
9  Yes               NaN                           NaN

   PedestrianCompany  BirthYear
0  Unknown            1960.0
1  Unknown            1928.0
2  Unknown            1947.0
3  Alone              1959.0
4  Unknown            1987.0
5  Unknown            1977.0
6  Unknown            1982.0
7  Unknown            2013.0
8  Unknown            2001.0
9  Unknown            1946.0
Places dataframe (first 10 rows):
   AccidentId    RoadType       RoadNumber  RoadSecNumber  RoadLetter
0  201800000001  Departamental  41          NaN            C
1  201800000002  Communal       41          NaN            D
2  201800000003  Departamental  39          NaN            D
3  201800000004  Departamental  39          NaN            NaN
4  201800000005  Communal       NaN         NaN            NaN
5  201800000006  Departamental  39          NaN            D
6  201800000007  Departamental  41          NaN            D
7  201800000008  Communal       -           NaN            NaN
8  201800000009  Departamental  141         NaN            D
9  201800000010  Departamental  641         NaN            NaN

   Circulation  LaneNumber  SpecialLane  Slope   RoadMarkerId
0  TwoWay       2.0         0            Flat    NaN
1  TwoWay       2.0         0            Flat    NaN
2  TwoWay       2.0         0            Flat    NaN
3  TwoWay       2.0         0            Flat    NaN
4  OneWay       1.0         0            Flat    NaN
5  Unknown      2.0         0            Uphill  NaN
6  TwoWay       2.0         0            Flat    16.0
7  TwoWay       2.0         0            Flat    NaN
8  TwoWay       2.0         0            Flat    NaN
9  TwoWay       2.0         Bike         Flat    1.0

   RoadMarkerDistance  Layout      StripWidth  LaneWidth  SurfaceCondition
0  NaN                 RightCurve  NaN         NaN        Normal
1  NaN                 LeftCurve   NaN         NaN        Normal
2  NaN                 Straight    NaN         NaN        Normal
3  NaN                 Straight    NaN         NaN        Normal
4  NaN                 Straight    NaN         NaN        Normal
5  NaN                 LeftCurve   NaN         NaN        Wet
6  500.0               Straight    NaN         NaN        Normal
7  NaN                 Straight    NaN         NaN        Normal
8  NaN                 Straight    NaN         NaN        Normal
9  670.0               Straight    NaN         NaN        Normal

   Infrastructure  Localization  SchoolNear
0  Unknown         Lane          0.0
1  Unknown         Lane          0.0
2  Unknown         Lane          0.0
3  Unknown         Lane          0.0
4  Unknown         Lane          0.0
5  Unknown         Shoulder      0.0
6  Unknown         Shoulder      0.0
7  Unknown         Lane          0.0
8  Unknown         Shoulder      0.0
9  Unknown         Lane          0.0
Create the main feature matrix and the target vector for Accidents¶
accidents_main_df = accidents_df.drop("Gravity", axis=1)
y_accidents_train = accidents_df["Gravity"]
Create the multi-table dataset specification¶
Note that the main table Accidents and the secondary table Places have one key: AccidentId. The tables Vehicles (the other secondary table) and Users (the tertiary table) have two keys: AccidentId and VehicleId.
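In the User output shown earlier, rows 2 and 3 share the pair (201800000002, A01), so in a 1:n child table the key identifies the parent record rather than the row itself. A quick sanity check before building the dataset specification is to confirm that parent keys are unique and that every child row points to an existing parent. The sketch below uses plain Python on made-up toy rows; `duplicated_keys` and `orphan_keys` are hypothetical helpers, not Khiops functions.

```python
from collections import Counter


def duplicated_keys(rows, key_columns):
    """Return the key tuples that occur more than once in rows."""
    counts = Counter(tuple(row[col] for col in key_columns) for row in rows)
    return [key for key, n in counts.items() if n > 1]


def orphan_keys(child_rows, parent_rows, parent_key_columns):
    """Return the child key tuples that have no matching parent row."""
    parent_keys = {tuple(r[c] for c in parent_key_columns) for r in parent_rows}
    child_keys = {tuple(r[c] for c in parent_key_columns) for r in child_rows}
    return sorted(child_keys - parent_keys)


# Toy rows, made up for illustration (not taken from the Accidents dataset).
accidents = [{"AccidentId": 1}, {"AccidentId": 2}]
vehicles = [
    {"AccidentId": 1, "VehicleId": "A01"},
    {"AccidentId": 1, "VehicleId": "B01"},
    {"AccidentId": 2, "VehicleId": "A01"},
]
users = [
    {"AccidentId": 1, "VehicleId": "A01"},
    {"AccidentId": 1, "VehicleId": "A01"},  # two users in the same vehicle
    {"AccidentId": 3, "VehicleId": "A01"},  # refers to a vehicle that does not exist
]

# Parent tables should have unique keys.
dup_accidents = duplicated_keys(accidents, ["AccidentId"])
dup_vehicles = duplicated_keys(vehicles, ["AccidentId", "VehicleId"])
# Child rows should all point to an existing parent.
orphan_vehicles = orphan_keys(vehicles, accidents, ["AccidentId"])
orphan_users = orphan_keys(users, vehicles, ["AccidentId", "VehicleId"])
print(dup_accidents, dup_vehicles, orphan_vehicles, orphan_users)
```

With pandas DataFrames the same checks could be expressed with `DataFrame.duplicated(subset=...)` and `Series.isin(...)`; the plain-Python version is used here only to keep the sketch dependency-free.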
To describe relations between tables, the field relations must be added to the dictionary of table specifications. This field contains a list of tuples describing the relations between tables. The first two values (str) of each tuple are the names of the parent and the child table involved in the relation. A third value (bool) can optionally be set to True to indicate that the relation is 1:1. For example, if the tuple (table1, table2, True) is contained in this field, it means that:
- table1 and table2 are in a 1:1 relationship
- The key of table1 is contained in that of table2 (i.e. keys are hierarchical)
If the relations field is not present, then Khiops Python assumes that the tables are in a star schema with main_table as the central table.
X_accidents_train = {
    "main_table": "Accidents",
    "tables": {
        "Accidents": (accidents_main_df, "AccidentId"),
        "Vehicles": (vehicles_df, ["AccidentId", "VehicleId"]),
        "Users": (users_df, ["AccidentId", "VehicleId"]),
        "Places": (places_df, ["AccidentId"]),
    },
    "relations": [
        ("Accidents", "Vehicles"),
        ("Vehicles", "Users"),
        ("Accidents", "Places", True),
    ],
}
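Before fitting, it can be useful to sanity-check a specification of this shape. The sketch below is one possible check, not part of Khiops: `check_relations` and `as_key_list` are hypothetical helpers, the data slots are filled with `None` instead of DataFrames, and "the key of the parent is contained in that of the child" is interpreted as "the parent key is a prefix of the child key", which is an assumption.

```python
def as_key_list(key):
    """Normalize a key given as a single column name or a list of names."""
    return [key] if isinstance(key, str) else list(key)


def check_relations(spec):
    """Return a list of problems found in a dataset specification dict."""
    problems = []
    tables = spec["tables"]
    for relation in spec.get("relations", []):
        parent, child = relation[0], relation[1]
        missing = [name for name in (parent, child) if name not in tables]
        if missing:
            problems.extend(f"unknown table: {name}" for name in missing)
            continue
        parent_key = as_key_list(tables[parent][1])
        child_key = as_key_list(tables[child][1])
        # Assumption: hierarchical keys mean the child key starts with the parent key.
        if child_key[: len(parent_key)] != parent_key:
            problems.append(f"key of {parent} is not contained in key of {child}")
    return problems


# Same layout as the specification above, with None in place of the DataFrames.
spec = {
    "main_table": "Accidents",
    "tables": {
        "Accidents": (None, "AccidentId"),
        "Vehicles": (None, ["AccidentId", "VehicleId"]),
        "Users": (None, ["AccidentId", "VehicleId"]),
        "Places": (None, ["AccidentId"]),
    },
    "relations": [
        ("Accidents", "Vehicles"),
        ("Vehicles", "Users"),
        ("Accidents", "Places", True),
    ],
}
problems = check_relations(spec)

# A deliberately wrong relation: the parent and child roles are swapped.
bad_spec = dict(spec, relations=[("Users", "Accidents")])
bad_problems = check_relations(bad_spec)
print(problems, bad_problems)
```

The well-formed specification yields no problems, while the swapped relation is reported because the two-column key of Users cannot be a prefix of the one-column key of Accidents.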
Train a classifier with this dataset¶
- You may choose the number of features n_features to be created by the Khiops AutoML engine
- Set the number of trees to zero (n_trees=0)
khc_accidents = KhiopsClassifier(n_trees=0, n_features=1000)
khc_accidents.fit(X_accidents_train, y_accidents_train)
KhiopsClassifier(n_features=1000, n_trees=0)
Print the accuracy and AUC of the model¶
accidents_train_performance = (
    khc_accidents.model_report_.train_evaluation_report.get_snb_performance()
)
print(f"Accidents train accuracy: {accidents_train_performance.accuracy}")
print(f"Accidents train auc : {accidents_train_performance.auc}")
Accidents train accuracy: 0.945018
Accidents train auc : 0.845953
Deploy the classifier to obtain predictions on the training data¶
Note that usually one deploys the model on new test data. We deploy on the train dataset to keep the tutorial simple.
khc_accidents.predict(X_accidents_train)
array(['NonLethal', 'NonLethal', 'NonLethal', ..., 'NonLethal',
       'NonLethal', 'NonLethal'], dtype='<U9')
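As noted above, a model is usually deployed on new test data. With multi-table data, one way to hold out a test set is to split the main-table keys and then keep, in each secondary table, only the rows whose key falls on the same side of the split. The following is a minimal sketch on made-up toy rows using only the standard library; `split_keys` and `rows_for` are hypothetical helpers, and the 30% test fraction is an arbitrary choice.

```python
import random


def split_keys(keys, test_fraction=0.3, seed=0):
    """Shuffle the main-table keys and split them into train and test sets."""
    shuffled = sorted(keys)
    random.Random(seed).shuffle(shuffled)
    n_test = int(len(shuffled) * test_fraction)
    return set(shuffled[n_test:]), set(shuffled[:n_test])


def rows_for(rows, key_column, kept_keys):
    """Keep only the rows whose main-table key is in kept_keys."""
    return [row for row in rows if row[key_column] in kept_keys]


# Toy rows, made up for illustration: 10 accidents with 2 vehicles each.
accidents = [{"AccidentId": i} for i in range(1, 11)]
vehicles = [
    {"AccidentId": i, "VehicleId": vehicle_id}
    for i in range(1, 11)
    for vehicle_id in ("A01", "B01")
]

train_keys, test_keys = split_keys([row["AccidentId"] for row in accidents])

# Secondary rows follow their parent accident into the same split.
vehicles_train = rows_for(vehicles, "AccidentId", train_keys)
vehicles_test = rows_for(vehicles, "AccidentId", test_keys)
print(len(train_keys), len(test_keys), len(vehicles_train), len(vehicles_test))
```

Because Users also carries AccidentId as the first part of its key, the same filter would apply to it. With pandas DataFrames the filtering step could be written as `df[df["AccidentId"].isin(train_keys)]`, after which a separate train and test specification dictionary would be built from the filtered tables.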