Enforcing Abstract Class Variables in Python

November 5, 2021

Overview

An abstract class is a class that defines an interface to which its subclasses must conform. It is an object-oriented concept: a class declared abstract should not be instantiated and may contain abstract methods that its subclasses are required to implement.

A class variable is a variable defined in the class body; it is owned by the class itself and shared by all of its instances. Python has no native way to declare a class variable abstract, but we will look at how to implement this behaviour for the following use case.
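To make the distinction concrete, here is a minimal sketch (the `Counter` class is hypothetical and unrelated to the pipeline) showing that a class variable lives on the class and is shared, while instance variables live on each object separately:

```python
class Counter:
    # Class variable: defined in the class body, stored on the class object
    # and shared by every instance.
    unit = "items"

    def __init__(self, count: int):
        # Instance variable: stored separately on each instance.
        self.count = count


a = Counter(1)
b = Counter(2)

# Both instances read the same class variable through the class.
print(a.unit, b.unit, Counter.unit)  # items items items

# Rebinding it on the class is visible from every instance.
Counter.unit = "widgets"
print(a.unit, b.unit)  # widgets widgets

# Instance variables are independent per instance.
print(a.count, b.count)  # 1 2

# The class variable lives in the class namespace, not the instance's.
print("unit" in Counter.__dict__, "unit" in vars(a))  # True False
```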

Abstract Class Requirements

In our use case, suppose we want to create various data models that represent data in a data pipeline. Each model must conform to a pre-defined schema, expose the column names within that schema and expose the data within that model. Therefore, we need to define a base class that behaves like an abstract class (though it is not strictly one in the abc sense, since it has no abstract methods) with the following requirements:

  1. the base class must not be instantiated
  2. the base class has an abstract class variable (emphasis on class variable, which defines the schema for that model)
  3. the base class has no abstract methods, but it has a class method (which returns the column names of that schema) and a concrete method (which returns the data within that model)

We can create an abstract class that meets these requirements as follows:

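from abc import ABC
from typing import List

import pandas as pd
import pandera as pa


class BaseModel(ABC):
    # Sentinel value: every subclass must replace it with a real schema.
    schema: pa.DataFrameSchema = NotImplemented

    def __new__(cls, *args, **kwargs):
        # Requirement 1: block direct instantiation of the base class.
        if cls is BaseModel:
            raise TypeError(f"Only children of '{cls.__name__}' may be instantiated")
        # Pass only cls: object.__new__ rejects extra arguments once __new__ is overridden.
        return super().__new__(cls)

    def __init__(self, df: pd.DataFrame):
        # Validate the incoming data against the subclass's schema.
        self._df = self.schema.validate(df)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Requirement 2: runs when a subclass is defined and rejects it
        # if it does not override the schema class variable.
        if cls.schema is NotImplemented:
            raise NotImplementedError("Please implement the `schema` class variable")

    @classmethod
    def get_columns(cls) -> List[str]:
        # Requirement 3: column names, available without an instance.
        return list(cls.schema.columns.keys())

    def get_df(self) -> pd.DataFrame:
        # Requirement 3: the validated data held by this instance.
        return self._df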

In terms of requirement 1, if we try to instantiate BaseModel, a TypeError is raised. This is enforced by overriding the __new__ dunder method, which controls how instances of a class are created (we cannot rely on extending Python's abc.ABC class alone, because ABC only blocks instantiation while a class still has unimplemented abstract methods, and this class has none). __new__ normally returns the newly created instance; if that object is an instance of the class, Python then calls the __init__ dunder method on it with the same arguments.
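To see why abc.ABC alone is not enough here, consider this stdlib-only check (both class names are hypothetical). ABC refuses instantiation only while abstract methods remain unimplemented, so a class with none can be instantiated freely:

```python
from abc import ABC, abstractmethod


class NoAbstractMethods(ABC):
    pass


class WithAbstractMethod(ABC):
    @abstractmethod
    def run(self) -> None: ...


# Succeeds even though the class inherits from ABC,
# because there are no abstract methods left to implement.
instance = NoAbstractMethods()
print(type(instance).__name__)  # NoAbstractMethods

# Fails, because `run` is still abstract.
try:
    WithAbstractMethod()
    abstract_rejected = False
except TypeError as exc:
    abstract_rejected = True
    print("TypeError:", exc)
```

This is why BaseModel, which deliberately has no abstract methods, needs its own guard in __new__.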

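Because BaseModel overrides __new__, we must not forward the constructor's positional or keyword arguments to object.__new__: it raises a TypeError when it receives extra arguments and __new__ has been overridden. That is why the example above returns super().__new__(cls) instead of super().__new__(cls, *args, **kwargs). The arguments are not lost; Python still passes them to __init__.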
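A small sketch of the difference, using two hypothetical classes that differ only in what they pass to object.__new__:

```python
class Forwards:
    def __new__(cls, *args, **kwargs):
        # Forwarding the constructor arguments fails: object.__new__
        # rejects extra arguments because __new__ is overridden.
        return super().__new__(cls, *args, **kwargs)

    def __init__(self, value):
        self.value = value


class DoesNotForward:
    def __new__(cls, *args, **kwargs):
        # Pass only the class; __init__ still receives the original arguments.
        return super().__new__(cls)

    def __init__(self, value):
        self.value = value


try:
    Forwards(42)
    forwards_failed = False
except TypeError as exc:
    forwards_failed = True
    print("TypeError:", exc)

obj = DoesNotForward(42)
print(obj.value)  # 42
```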

In terms of requirement 2, if we try to define a subclass that does not override the abstract schema class variable, a NotImplementedError is raised as soon as the class statement runs, before any instance is created. This is enforced by overriding the __init_subclass__ dunder method, a hook that Python calls, with the new subclass as cls, whenever a class inheriting from BaseModel is created (we cannot rely on combining @property with @abstractmethod, because an abstract property is resolved on instances, whereas get_columns is a class method that needs schema to be a real value on the class itself). More generally, __init_subclass__ is a hook for customising subclass creation, for example to validate class attributes or to configure default values for subclasses.
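The timing of the check is easier to see in a stripped-down stand-in that does not need pandera. Here a plain dict replaces `pa.DataFrameSchema`, and `Base`, `Good`, `Bad` and `Child` are hypothetical names used only for illustration:

```python
class Base:
    # Sentinel meaning "not provided"; subclasses must override it.
    schema = NotImplemented

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs once per subclass, right after its class statement executes.
        if cls.schema is NotImplemented:
            raise NotImplementedError("Please implement the `schema` class variable")


class Good(Base):
    schema = {"col_1": int}  # stand-in for a real schema


class Child(Good):
    # No override needed: `cls.schema` is found on Good through the MRO.
    pass


try:
    class Bad(Base):  # does not override schema
        pass
    bad_rejected = False
except NotImplementedError as exc:
    bad_rejected = True
    print("NotImplementedError:", exc)
```

Note that the check uses normal attribute lookup, so a subclass of an already-valid model inherits its parent's schema and passes without redefining it.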

In terms of requirement 3, we define get_columns, a class method that can be called on the class object itself without any data, and get_df, a concrete method that returns the validated data of a particular instance of the class.
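The difference between the two kinds of method can be sketched without pandas. In this hypothetical `Model`, a plain dict stands in for the schema and a list of dicts stands in for the DataFrame; `get_rows` plays the role of `get_df`:

```python
from typing import Dict, List


class Model:
    # Stand-in schema: column name -> type (a plain dict, not pandera).
    schema: Dict[str, type] = {"col_1": int, "col_2": int}

    def __init__(self, rows: List[dict]):
        self._rows = rows

    @classmethod
    def get_columns(cls) -> List[str]:
        # Needs only the class, so it works before any data exists.
        return list(cls.schema)

    def get_rows(self) -> List[dict]:
        # Needs instance state, so it must be called on an instance.
        return self._rows


print(Model.get_columns())  # ['col_1', 'col_2']

model = Model([{"col_1": 1, "col_2": 4}])
print(model.get_columns())  # class methods also work on instances
print(model.get_rows())

try:
    Model.get_rows()  # no instance to bind to `self`
    unbound_failed = False
except TypeError:
    unbound_failed = True
```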

Creating a Class Instance

We define a DataModel class that extends BaseModel:

class DataModel(BaseModel):
    schema = pa.DataFrameSchema(
        {
            "col_1": pa.Column(int, unique=False, nullable=False),
            "col_2": pa.Column(int, unique=False, nullable=False),
            "col_3": pa.Column(int, unique=False, nullable=False),
        },
        strict=True,
    )

We can now create a DataFrame and use it to instantiate DataModel:

df = pd.DataFrame({"col_1": [1, 2, 3], "col_2": [4, 5, 6], "col_3": [7, 8, 9]})
data_model = DataModel(df)