Sklearn Basics 3: Train a Classifier on a Snowflake Multi-Table Dataset

In this notebook, we will learn how to train a classifier on more complex multi-table data, where a secondary table is itself the parent of another table (i.e. a snowflake schema). It is highly recommended to go through the Sklearn Basics 1 and Sklearn Basics 2 lessons first if you are not familiar with Khiops’ sklearn estimators.

We start by importing the sklearn estimator KhiopsClassifier:

import os
import pandas as pd
from khiops import core as kh
from khiops.sklearn import KhiopsClassifier
from khiops.utils.helpers import train_test_split_dataset
from sklearn import metrics

# If there are any issues, you may print the Khiops status with the following command
# kh.get_runner().print_status()

Training a Multi-Table Classifier

We’ll train a multi-table classifier on an extension of the AccidentsSummary dataset that we used in the previous notebook, Sklearn Basics 2. This dataset, Accidents, contains the additional table Users and is organized in the following relational snowflake schema.

Accidents
|
| -- 1:n -- Vehicles
|              |
|              |-- 1:n -- Users
|
| -- 1:1 -- Places

Note that the target variable is Gravity.

To train the KhiopsClassifier for this setup, we must specify a multi-table dataset. Let’s first check the content of the tables:

  • The main table Accidents.

  • The first secondary table Vehicles which has a 1:n relationship with Accidents.

  • The second secondary table Places which has a 1:1 relationship with Accidents.

  • The tertiary table Users which has a 1:n relationship with Vehicles.

accidents_dataset_dir = os.path.join(kh.get_samples_dir(), "Accidents")

accidents_file = os.path.join(accidents_dataset_dir, "Accidents.txt")
accidents_df = pd.read_csv(accidents_file, sep="\t")
print("Accident dataframe (first 10 rows):")
display(accidents_df.head(10))
print()

vehicles_file = os.path.join(accidents_dataset_dir, "Vehicles.txt")
vehicles_df = pd.read_csv(vehicles_file, sep="\t")
print("Vehicle dataframe (first 10 rows):")
display(vehicles_df.head(10))

# Users is the tertiary table: it is linked to Vehicles, not directly to Accidents
users_file = os.path.join(accidents_dataset_dir, "Users.txt")
users_df = pd.read_csv(users_file, sep="\t")
print("User dataframe (first 10 rows):")
display(users_df.head(10))
print()

places_file = os.path.join(accidents_dataset_dir, "Places.txt")
places_df = pd.read_csv(places_file, sep="\t", low_memory=False)
print("Places dataframe (first 10 rows):")
display(places_df.head(10))
Accident dataframe (first 10 rows):
     AccidentId    Gravity        Date      Hour               Light
0  201800000001  NonLethal  2018-01-24  15:05:00            Daylight
1  201800000002  NonLethal  2018-02-12  10:15:00            Daylight
2  201800000003  NonLethal  2018-03-04  11:35:00            Daylight
3  201800000004  NonLethal  2018-05-05  17:35:00            Daylight
4  201800000005  NonLethal  2018-06-26  16:05:00            Daylight
5  201800000006  NonLethal  2018-09-23  06:30:00      TwilightOrDawn
6  201800000007  NonLethal  2018-09-26  00:40:00  NightStreelightsOn
7  201800000008     Lethal  2018-11-30  17:15:00  NightStreelightsOn
8  201800000009  NonLethal  2018-02-18  15:57:00            Daylight
9  201800000010  NonLethal  2018-03-19  15:30:00            Daylight

   Department  Commune InAgglomeration IntersectionType    Weather
0         590        5              No           Y-type     Normal
1         590       11             Yes           Square   VeryGood
2         590      477             Yes           T-type     Normal
3         590       52             Yes   NoIntersection   VeryGood
4         590      477             Yes   NoIntersection     Normal
5         590       52             Yes   NoIntersection  LightRain
6         590      133             Yes   NoIntersection     Normal
7         590       11             Yes   NoIntersection     Normal
8         590      550              No   NoIntersection     Normal
9         590       51             Yes           X-type     Normal

                      CollisionType             PostalAddress GPSCode
0  2Vehicles-BehindVehicles-Frontal    route des Ansereuilles       M
1                       NoCollision  Place du général de Gaul       M
2                       NoCollision            Rue  nationale       M
3                    2Vehicles-Side       30 rue Jules Guesde       M
4                    2Vehicles-Side        72 rue Victor Hugo       M
5                             Other                       D39       M
6                             Other        4 route de camphin       M
7                             Other         rue saint exupéry       M
8                             Other          rue de l'égalité       M
9  2Vehicles-BehindVehicles-Frontal   face au 59 rue de Lille       M

   Latitude  Longitude
0  50.55737    2.55737
1  50.52936    2.52936
2  50.51243    2.51243
3  50.51974    2.51974
4  50.51607    2.51607
5  50.52132    2.52132
6  50.52211    2.52211
7  50.53146    2.53146
8  50.53707    2.53707
9  50.53639    2.53639
Vehicle dataframe (first 10 rows):
     AccidentId VehicleId Direction          Category  PassengerNumber
0  201800000001       A01   Unknown         Car<=3.5T                0
1  201800000001       B01   Unknown         Car<=3.5T                0
2  201800000002       A01   Unknown         Car<=3.5T                0
3  201800000003       A01   Unknown  Motorbike>125cm3                0
4  201800000003       B01   Unknown         Car<=3.5T                0
5  201800000003       C01   Unknown         Car<=3.5T                0
6  201800000004       A01   Unknown         Car<=3.5T                0
7  201800000004       B01   Unknown           Bicycle                0
8  201800000005       A01   Unknown             Moped                0
9  201800000005       B01   Unknown         Car<=3.5T                0

       FixedObstacle MobileObstacle ImpactPoint           Maneuver
0                NaN        Vehicle  RightFront         TurnToLeft
1                NaN        Vehicle   LeftFront  NoDirectionChange
2                NaN     Pedestrian         NaN  NoDirectionChange
3  StationaryVehicle        Vehicle       Front  NoDirectionChange
4                NaN        Vehicle    LeftSide         TurnToLeft
5                NaN            NaN   RightSide             Parked
6                NaN          Other  RightFront          Avoidance
7                NaN        Vehicle    LeftSide                NaN
8                NaN        Vehicle  RightFront           PassLeft
9                NaN        Vehicle   LeftFront               Park
User dataframe (first 10 rows):
     AccidentId VehicleId  Seat    Category Gender TripReason    SafetyDevice
0  201800000001       A01   1.0      Driver   Male    Leisure        SeatBelt
1  201800000001       B01   1.0      Driver   Male        NaN        SeatBelt
2  201800000002       A01   1.0      Driver   Male        NaN        SeatBelt
3  201800000002       A01   NaN  Pedestrian   Male        NaN          Helmet
4  201800000003       A01   1.0      Driver   Male    Leisure          Helmet
5  201800000003       C01   1.0      Driver   Male        NaN  ChildrenDevice
6  201800000004       A01   1.0      Driver   Male    Leisure        SeatBelt
7  201800000004       B01   1.0      Driver   Male    Leisure          Helmet
8  201800000005       A01   1.0      Driver   Male    Leisure          Helmet
9  201800000005       B01   1.0      Driver   Male    Leisure        SeatBelt

  SafetyDeviceUsed            PedestrianLocation PedestrianAction
0              Yes                           NaN              NaN
1              Yes                           NaN              NaN
2              Yes                           NaN              NaN
3              NaN  OnLane<=OnSidewalk0mCrossing         Crossing
4              Yes                           NaN              NaN
5              NaN                           NaN              NaN
6              Yes                           NaN              NaN
7              NaN                           NaN              NaN
8              Yes                           NaN              NaN
9              Yes                           NaN              NaN

  PedestrianCompany  BirthYear
0           Unknown     1960.0
1           Unknown     1928.0
2           Unknown     1947.0
3             Alone     1959.0
4           Unknown     1987.0
5           Unknown     1977.0
6           Unknown     1982.0
7           Unknown     2013.0
8           Unknown     2001.0
9           Unknown     1946.0
Places dataframe (first 10 rows):
     AccidentId       RoadType RoadNumber  RoadSecNumber RoadLetter
0  201800000001  Departamental         41            NaN          C
1  201800000002       Communal         41            NaN          D
2  201800000003  Departamental         39            NaN          D
3  201800000004  Departamental         39            NaN        NaN
4  201800000005       Communal        NaN            NaN        NaN
5  201800000006  Departamental         39            NaN          D
6  201800000007  Departamental         41            NaN          D
7  201800000008       Communal          -            NaN        NaN
8  201800000009  Departamental        141            NaN          D
9  201800000010  Departamental        641            NaN        NaN

  Circulation  LaneNumber SpecialLane   Slope  RoadMarkerId
0      TwoWay         2.0           0    Flat           NaN
1      TwoWay         2.0           0    Flat           NaN
2      TwoWay         2.0           0    Flat           NaN
3      TwoWay         2.0           0    Flat           NaN
4      OneWay         1.0           0    Flat           NaN
5     Unknown         2.0           0  Uphill           NaN
6      TwoWay         2.0           0    Flat          16.0
7      TwoWay         2.0           0    Flat           NaN
8      TwoWay         2.0           0    Flat           NaN
9      TwoWay         2.0        Bike    Flat           1.0

   RoadMarkerDistance      Layout  StripWidth  LaneWidth SurfaceCondition
0                 NaN  RightCurve         NaN        NaN           Normal
1                 NaN   LeftCurve         NaN        NaN           Normal
2                 NaN    Straight         NaN        NaN           Normal
3                 NaN    Straight         NaN        NaN           Normal
4                 NaN    Straight         NaN        NaN           Normal
5                 NaN   LeftCurve         NaN        NaN              Wet
6               500.0    Straight         NaN        NaN           Normal
7                 NaN    Straight         NaN        NaN           Normal
8                 NaN    Straight         NaN        NaN           Normal
9               670.0    Straight         NaN        NaN           Normal

  Infrastructure Localization  SchoolNear
0        Unknown         Lane         0.0
1        Unknown         Lane         0.0
2        Unknown         Lane         0.0
3        Unknown         Lane         0.0
4        Unknown         Lane         0.0
5        Unknown     Shoulder         0.0
6        Unknown     Shoulder         0.0
7        Unknown         Lane         0.0
8        Unknown     Shoulder         0.0
9        Unknown         Lane         0.0

Create the multi-table dataset specification

Note that the main table Accidents and the secondary table Places have a single key column, AccidentId. Tables Vehicles (the other secondary table) and Users (the tertiary table) have two key columns: AccidentId and VehicleId.
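
The key structure described here can be checked with pandas before building the dataset spec. The sketch below uses small toy tables with hypothetical values (not the real Accidents data) and assumes, as the schema suggests, that the pair (AccidentId, VehicleId) identifies a vehicle. It verifies that the keys are unique where they should be and that every child row points to an existing parent row.

```python
import pandas as pd

# Toy stand-ins for the real tables (hypothetical values)
accidents = pd.DataFrame({"AccidentId": [1, 2, 3]})
vehicles = pd.DataFrame(
    {"AccidentId": [1, 1, 2], "VehicleId": ["A01", "B01", "A01"]}
)
users = pd.DataFrame(
    {"AccidentId": [1, 1, 2, 2], "VehicleId": ["A01", "B01", "A01", "A01"]}
)

# The main table key must identify each row uniquely
main_key_is_unique = accidents["AccidentId"].is_unique

# Vehicles is keyed by the pair (AccidentId, VehicleId)
vehicle_key_is_unique = not vehicles.duplicated(["AccidentId", "VehicleId"]).any()

# Every vehicle must point to an existing accident
vehicles_have_parent = vehicles["AccidentId"].isin(accidents["AccidentId"]).all()

# Every user must point to an existing vehicle: a left merge with
# indicator=True marks the rows that have no match in the parent table
merged = users.merge(
    vehicles, on=["AccidentId", "VehicleId"], how="left", indicator=True
)
users_have_parent = (merged["_merge"] == "both").all()

print(main_key_is_unique, vehicle_key_is_unique)
print(vehicles_have_parent, users_have_parent)
```

Several users may share the same (AccidentId, VehicleId) pair, so no uniqueness check is applied to the toy users table.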

To describe relations between tables, we must add the relations field to the dataset spec. This field contains a list of tuples describing the relations between tables. The first two values (str) of each tuple are the names of the parent table and the child table involved in the relation, in that order. A third value (bool) can optionally be set to True to indicate that the relation is 1:1. For example, if the tuple (table1, table2, True) is contained in this field, it means that:

  • table1 and table2 are in a 1:1 relationship

  • The key of table1 is contained in that of table2 (i.e. keys are hierarchical)

If the relations field is not present then Khiops Python assumes that the tables are in a star schema with main_table as the central table.

X_accidents = {
    "main_table": "Accidents",
    "tables": {
        "Accidents": (accidents_df.drop("Gravity", axis=1), "AccidentId"),
        "Vehicles": (vehicles_df, ["AccidentId", "VehicleId"]),
        "Users": (users_df, ["AccidentId", "VehicleId"]),
        "Places": (places_df, "AccidentId"),
    },
    "relations": [
        ("Accidents", "Vehicles"),
        ("Vehicles", "Users"),
        ("Accidents", "Places", True),
    ],
}
y_accidents = accidents_df["Gravity"]
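
A spec like the one above is easy to get wrong (a misspelled table name, a child listed before its parent). The following sketch is a hypothetical helper, not part of Khiops, that checks the relations of a spec before it is passed to an estimator. It assumes that "hierarchical keys" means the parent key columns are the leading columns of the child key; that reading is an assumption made for this example.

```python
def as_key_list(key):
    """Normalize a key given as one column name or a list of names."""
    return [key] if isinstance(key, str) else list(key)


def check_relations(spec):
    """Return the problems found in the relations of a dataset spec.

    Hypothetical helper written for this example; it is not part of Khiops.
    """
    problems = []
    tables = spec["tables"]
    for relation in spec.get("relations", []):
        parent, child = relation[0], relation[1]
        missing = [name for name in (parent, child) if name not in tables]
        if missing:
            problems.append(f"unknown table(s): {', '.join(missing)}")
            continue
        parent_key = as_key_list(tables[parent][1])
        child_key = as_key_list(tables[child][1])
        # Assumed meaning of "hierarchical": the parent key columns are
        # the leading columns of the child key
        if child_key[: len(parent_key)] != parent_key:
            problems.append(f"key of {parent} does not start key of {child}")
    return problems


# Same structure as X_accidents, with None in place of the dataframes
spec = {
    "main_table": "Accidents",
    "tables": {
        "Accidents": (None, "AccidentId"),
        "Vehicles": (None, ["AccidentId", "VehicleId"]),
        "Users": (None, ["AccidentId", "VehicleId"]),
        "Places": (None, "AccidentId"),
    },
    "relations": [
        ("Accidents", "Vehicles"),
        ("Vehicles", "Users"),
        ("Accidents", "Places", True),
    ],
}
print(check_relations(spec))  # prints []
```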

Split the dataset into train and test

We use the helper function train_test_split_dataset with the X dataset spec to obtain one spec for train and another for test.

(
    X_accidents_train,
    X_accidents_test,
    y_accidents_train,
    y_accidents_test,
) = train_test_split_dataset(X_accidents, y_accidents, test_size=0.3)
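
Splitting a multi-table dataset is more than splitting one dataframe: the rows of the secondary tables have to follow their parent rows in the main table. The sketch below illustrates that idea with toy tables (hypothetical values) and plain pandas. It is only an illustration of the principle, not the implementation of train_test_split_dataset.

```python
import pandas as pd

# Toy tables (hypothetical values)
accidents = pd.DataFrame({"AccidentId": range(1, 11)})
vehicles = pd.DataFrame(
    {
        "AccidentId": [1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10],
        "VehicleId": ["A01", "B01"] + ["A01"] * 9 + ["B01"],
    }
)

# Split the main table: 30% of the accidents go to the test part
accidents_test = accidents.sample(frac=0.3, random_state=0)
accidents_train = accidents.drop(accidents_test.index)

# Secondary rows follow their parent accident
in_test = vehicles["AccidentId"].isin(accidents_test["AccidentId"])
vehicles_test = vehicles[in_test]
vehicles_train = vehicles[~in_test]

print(len(accidents_train), len(accidents_test))
print(len(vehicles_train), len(vehicles_test))
```

With a snowflake schema the same filtering would be repeated one level down, for example keeping in Users only the rows whose (AccidentId, VehicleId) pair remains in the corresponding part of Vehicles.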

Train a classifier with this dataset

  • You may choose the number of features n_features to be created by the Khiops AutoML engine

  • Set the number of trees to zero (n_trees=0)

khc_accidents = KhiopsClassifier(n_trees=0, n_features=1000)
khc_accidents.fit(X_accidents_train, y_accidents_train)
KhiopsClassifier(n_features=1000, n_trees=0)

Deploy the classifier to obtain predictions and probabilities on the test data

y_accidents_test_predicted = khc_accidents.predict(X_accidents_test)
probas_accidents_test = khc_accidents.predict_proba(X_accidents_test)

print("Accidents test predictions (first 10 values):")
display(y_accidents_test_predicted[:10])
print("Accidents test prediction probabilities (first 10 values):")
display(probas_accidents_test[:10])
Accidents test predictions (first 10 values):
array(['NonLethal', 'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal',
       'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal', 'NonLethal'],
      dtype='<U9')
Accidents test prediction probabilities (first 10 values):
array([[0.0526682 , 0.9473318 ],
       [0.00265006, 0.99734994],
       [0.11575197, 0.88424803],
       [0.003901  , 0.996099  ],
       [0.12084546, 0.87915454],
       [0.03023732, 0.96976268],
       [0.02279352, 0.97720648],
       [0.01870448, 0.98129552],
       [0.01280118, 0.98719882],
       [0.00382921, 0.99617079]])
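
To relate the two outputs, the sketch below recomputes the predicted labels from the first three probability rows shown above. It assumes that the probability columns follow the sorted class labels (Lethal, then NonLethal), as in the scikit-learn classes_ convention, and that the predicted class is the one with the highest probability. Both points are assumptions for this illustration; the class list is hard-coded rather than read from the fitted estimator.

```python
# Class order assumed to match the probability columns (hard-coded here;
# on a fitted scikit-learn style estimator it would come from `classes_`)
classes = ["Lethal", "NonLethal"]

# First three probability rows copied from the output above
probas = [
    [0.0526682, 0.9473318],
    [0.00265006, 0.99734994],
    [0.11575197, 0.88424803],
]

# The predicted class is the one with the highest probability in each row
predicted = [classes[row.index(max(row))] for row in probas]
print(predicted)  # ['NonLethal', 'NonLethal', 'NonLethal']
```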

Estimate the accuracy and AUC metrics on the test data

accidents_test_accuracy = metrics.accuracy_score(
    y_accidents_test, y_accidents_test_predicted
)
accidents_test_auc = metrics.roc_auc_score(
    y_accidents_test, probas_accidents_test[:, 1]
)

print(f"Accidents test accuracy: {accidents_test_accuracy}")
print(f"Accidents test auc     : {accidents_test_auc}")
Accidents test accuracy: 0.9435246610902798
Accidents test auc     : 0.824969336370397
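
Since most of the accidents shown above are NonLethal, it can help to read the accuracy next to a simple baseline: the accuracy of a classifier that always predicts the most frequent class. The sketch below computes that baseline on toy labels (hypothetical values, not the real test set).

```python
from collections import Counter

# Toy labels (hypothetical): 1 "Lethal" for 9 "NonLethal"
y_true = ["NonLethal"] * 9 + ["Lethal"]

# A classifier that always predicts the most frequent class
majority_class, majority_count = Counter(y_true).most_common(1)[0]
baseline_accuracy = majority_count / len(y_true)

print(majority_class, baseline_accuracy)  # NonLethal 0.9
```

On the real data, and assuming y_accidents_test is a pandas Series, the same baseline could be obtained with y_accidents_test.value_counts(normalize=True).max(). The AUC is less sensitive to class imbalance, which is one reason to report both metrics.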