Write Custom Accessors to Avoid Subclassing Pandas DataFrames

This blog post shows how to create custom accessor classes for pandas DataFrames which can be used to provide an encapsulated suite of functionality for your DataFrames. The custom accessors are registered with the pandas API, giving you easy access to namespaced attributes and methods. Writing accessor classes is simpler than subclassing, and the resulting code is more extensible and reusable.

If you prefer to skip the verbiage that follows, find the example Jupyter notebook and accompanying data files on GitHub.

Custom Accessors

A custom accessor is a class that contains a reference to a DataFrame object. The accessor methods operate directly on the DataFrame. Pandas provides a class decorator that you use to register your custom accessor with the pandas programming interface. Once the accessor class has been written, decorated, and loaded, any DataFrame can access the methods in the accessor with the syntax:

my_dataframe.my_namespace.my_method_name

There is no need to subclass DataFrame, and the accessor methods are available to any DataFrame that passes a custom validation method that you provide.
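As a minimal sketch of the pattern described above, the snippet below registers a tiny accessor and calls one of its methods. The namespace "demo", the class name DemoAccessor, and the method column_count are hypothetical names chosen for illustration; only the registration decorator comes from pandas.

```python
import pandas as pd

# Hypothetical accessor registered under the illustrative namespace "demo".
@pd.api.extensions.register_dataframe_accessor("demo")
class DemoAccessor:
    def __init__(self, pandas_obj):
        # Keep a reference to the DataFrame the accessor was reached from.
        self._obj = pandas_obj

    def column_count(self):
        # Accessor methods operate directly on the wrapped DataFrame.
        return len(self._obj.columns)

frame = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
print(frame.demo.column_count())  # prints 2
```

The call follows the my_dataframe.my_namespace.my_method_name shape shown above, and the DataFrame itself is an ordinary pandas DataFrame.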

Use Case

Suppose we are developing a tool to plot and analyze daily temperature at weather stations from around the world. For our initial exploration we’ve downloaded a CSV file containing one month of daily weather logged at the Los Angeles Airport weather station from NOAA’s Climate Data Online. We load the data, do a little formatting, and end up with the following.

import pandas as pd
df = pd.read_csv('data/LAX_weather.csv')

# Drop columns for snow depth, wind, and weather type 
df = df.drop(['SNWD','AWND','WDF2','WDF5',
        'WSF2','WSF5','WT01','WT02','WT08'],axis=1)

# Convert strings in 'DATE' column to pd.Timestamp
df['DATE']=pd.to_datetime(df['DATE'])

# Store the station name as user data.
df.attrs['StationName']=df['NAME'].iloc[0]

df.head()

       STATION                                      NAME  ... TMAX  TMIN
0  USW00023174  LOS ANGELES INTERNATIONAL AIRPORT, CA US  ...   72    62
1  USW00023174  LOS ANGELES INTERNATIONAL AIRPORT, CA US  ...   73    61
2  USW00023174  LOS ANGELES INTERNATIONAL AIRPORT, CA US  ...   72    59
3  USW00023174  LOS ANGELES INTERNATIONAL AIRPORT, CA US  ...   74    59
4  USW00023174  LOS ANGELES INTERNATIONAL AIRPORT, CA US  ...   80    63

We write a plotting routine to show the daily temperature range and averages.

import matplotlib.pyplot as plt
from matplotlib import dates

fig,ax = plt.subplots()
ax.plot(df['DATE'],df['TAVG'],marker='o',markersize=4,markerfacecolor='w',lw=1,markevery=2,label='Average Daily Temperature')
ax.fill_between(df['DATE'],df['TMIN'],df['TMAX'],alpha=.4,color='yellow')
ax.axhline(1.0, linestyle=':', lw=1)
title_str = 'Daily Temperature Range: {start} to {end}\n{station_name}'.format(
    start=df['DATE'].iloc[0].strftime('%Y-%m-%d'),
    end=df['DATE'].iloc[-1].strftime('%Y-%m-%d'),
    station_name=(df.attrs['StationName']))
ax.set_title(title_str)
ax.xaxis.set_major_formatter(dates.DateFormatter('%m/%d'))
ax.set_ylabel('Temperature (F)')
ax.set_xlabel('Date')
ax.legend()

We like what we have, so we decide to create a new method, “plot_temperature,” that can run the above plotting routine on any DataFrame that has columns ‘DATE’, ‘TMAX’, ‘TMIN’, and ‘TAVG’. We’d also like to standardize the plot title, so we’d like one-liners that return the station name and the start and end dates of the time series.

Implementation

Here’s how to implement the custom accessor.

Step 1. Decorator and __init__

We begin writing the class by providing an initialization method and decorating the class with pandas’ DataFrame accessor registration function. We will use “weather” as our accessor namespace.

@pd.api.extensions.register_dataframe_accessor("weather")
class WeatherAccessor:
    def __init__(self, pandas_obj):
        self._obj=pandas_obj

Step 2. Validation

We write a validation method to ensure that the DataFrame has the columns required for temperature plotting and that all data come from the same weather station. We call the validation method from our initialization method.

@pd.api.extensions.register_dataframe_accessor("weather")
class WeatherAccessor:
    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj=pandas_obj
        
    @staticmethod
    def _validate(obj):
        # verify there are columns for name, date, tmin, tmax, and tavg
        if not all(col in obj.columns for col in ['NAME','DATE','TMAX','TMIN','TAVG']):
            raise AttributeError("Columns must include 'NAME','DATE','TMAX','TMIN', and 'TAVG'")
        # verify that all rows come from the same weather station
        if obj['NAME'].nunique(dropna=False) > 1:
            raise AttributeError("All values in NAME column must be the same")
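To see what validation buys us, here is a small self-contained sketch. It uses a hypothetical namespace, "weather_check", and a trimmed-down validator that only tests for the four temperature columns, so it is not the full class from this post. Because the validator raises AttributeError, a DataFrame that fails validation behaves as if the accessor did not exist, which means hasattr can be used to test for it.

```python
import pandas as pd

# Hypothetical namespace, so this sketch does not clash with "weather".
@pd.api.extensions.register_dataframe_accessor("weather_check")
class WeatherCheckAccessor:
    REQUIRED = ["DATE", "TMAX", "TMIN", "TAVG"]

    def __init__(self, pandas_obj):
        # Validate first; the accessor is only created for valid frames.
        self._validate(pandas_obj)
        self._obj = pandas_obj

    @staticmethod
    def _validate(obj):
        missing = [c for c in WeatherCheckAccessor.REQUIRED if c not in obj.columns]
        if missing:
            raise AttributeError(f"Missing required columns: {missing}")

# Made-up single-row frames, for illustration only.
good = pd.DataFrame({"DATE": ["2020-05-01"], "TMAX": [72], "TMIN": [62], "TAVG": [66]})
bad = pd.DataFrame({"DATE": ["2020-05-01"], "TMAX": [72]})

print(hasattr(good, "weather_check"))  # True
print(hasattr(bad, "weather_check"))   # False
```

Validation runs when the accessor is reached through the DataFrame, so the check happens at the point of use rather than when the data is loaded.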

Step 3. Finish the Class

Finally, we add properties that return the weather station name and the start and end dates of the time series, as well as a method to create the desired plot. The full class definition is as follows.

@pd.api.extensions.register_dataframe_accessor("weather")
class WeatherAccessor:
    def __init__(self, pandas_obj):
        self._validate(pandas_obj)
        self._obj=pandas_obj
        
    @staticmethod
    def _validate(obj):
        # verify there are columns for name, date, tmin, tmax, and tavg
        if not all(col in obj.columns for col in ['NAME','DATE','TMAX','TMIN','TAVG']):
            raise AttributeError("Columns must include 'NAME','DATE','TMAX','TMIN', and 'TAVG'")
        # verify that all rows come from the same weather station
        if obj['NAME'].nunique(dropna=False) > 1:
            raise AttributeError("All values in NAME column must be the same")
  
    @property
    def start_date(self):
        # return the time series start date (earliest value in 'DATE')
        return pd.to_datetime(self._obj['DATE']).min().strftime('%Y-%m-%d')
    
    @property
    def end_date(self):
        # return the time series end date (latest value in 'DATE')
        return pd.to_datetime(self._obj['DATE']).max().strftime('%Y-%m-%d')
  
    @property
    def station_name(self):
        # return the station name from the first row, whatever the index labels are
        return self._obj['NAME'].iloc[0]
    
    def plot_temperature(self):
        fig,ax = plt.subplots()
        ax.plot(pd.to_datetime(self._obj['DATE']),self._obj['TAVG'],marker='o',markersize=4,markerfacecolor='w',lw=1,markevery=2,label='Average Daily Temperature')
        ax.fill_between(pd.to_datetime(self._obj['DATE']),self._obj['TMIN'],self._obj['TMAX'],alpha=.4,color='yellow')
        ax.axhline(1.0, linestyle=':', lw=1)
        title_str = 'Daily Temperature Range: {start} to {end}\n{station_name}'.format(
            start=self.start_date,
            end=self.end_date,
            station_name=self.station_name)
        ax.set_title(title_str)
        ax.xaxis.set_major_formatter(dates.DateFormatter('%m/%d'))
        ax.set_ylabel('Temperature (F)')
        ax.set_xlabel('Date')
        ax.legend()
    

Usage

And that’s it! Now, using the ‘weather’ namespace, we have access to all the custom accessor-defined attributes and have a one-liner for creating the plot.

df.weather.start_date

'2020-05-01'

df.weather.end_date

'2020-05-31'

df.weather.station_name

'LOS ANGELES INTERNATIONAL AIRPORT, CA US'

df.weather.plot_temperature()

Reuse

To put our work to the test we grab another dataset, this time from New York’s LaGuardia Airport. We load the data into a pandas DataFrame and try it out.

lga_df = pd.read_csv('data/LGA_weather.csv')
lga_df.weather.plot_temperature()

As easy as it gets! We’ve created code that is clear, easy to maintain, and will work on any pandas DataFrame that satisfies the validation criteria.
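The same idea can be sketched without the CSV files. The example below assumes two small, made-up DataFrames standing in for the station files, and a hypothetical "station" namespace with only the name and date-range properties (no plotting), to show one accessor serving several DataFrames in a loop.

```python
import pandas as pd

# Hypothetical, trimmed-down accessor used only for this illustration.
@pd.api.extensions.register_dataframe_accessor("station")
class StationAccessor:
    def __init__(self, pandas_obj):
        if "DATE" not in pandas_obj.columns or "NAME" not in pandas_obj.columns:
            raise AttributeError("Columns must include 'DATE' and 'NAME'")
        self._obj = pandas_obj

    @property
    def date_range(self):
        # Convert before comparing so string dates are ordered chronologically.
        dates = pd.to_datetime(self._obj["DATE"])
        return dates.min().strftime("%Y-%m-%d"), dates.max().strftime("%Y-%m-%d")

    @property
    def name(self):
        return self._obj["NAME"].iloc[0]

# Made-up rows standing in for the two station files.
frames = [
    pd.DataFrame({"NAME": ["STATION A"] * 2, "DATE": ["2020-05-02", "2020-05-01"]}),
    pd.DataFrame({"NAME": ["STATION B"] * 2, "DATE": ["2020-06-01", "2020-06-03"]}),
]

# The same accessor works on every frame that passes validation.
summary = {f.station.name: f.station.date_range for f in frames}
print(summary)
```

Each DataFrame gets its own accessor instance, so nothing needs to be reset or reconfigured between stations.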

If you’d like to play around with this example, the dataset files and a Jupyter notebook containing the code are available on GitHub.