Failsafing Multitenancy in SQLAlchemy

July 6, 2016
python sqlalchemy development

At Cratejoy, we run a hosted, multi-tenant application, and for several reasons certain data is centralized across our separate merchants. For obvious reasons, merchant-specific data should never cross between merchants. However, our setup made it easy for simple engineering oversights to cause exactly that. We traced data leaks back to our use of joinedload on SQLAlchemy backrefs of centralized models. That is, we have a centralized model with a child relationship to something that is store-specific. When we joinedload the child model, it includes the children across all merchants, not just the one we want at the time. While we could manually filter this in most cases, we take these data leaks very seriously and do not want a single oversight by an engineer to lead so easily to one. We took some time to develop a failsafe system that should prevent data leakage between stores, even when filters are missing.

For this walkthrough, we’ll use the following two models:


class Customer(db.Model):
  id = db.Column(db.Integer, primary_key=True)
  name = db.Column(db.String(50))

  addresses = db.relationship('Address',
    backref='customer', collection_class=FilteredList)

class Address(db.Model):
  id = db.Column(db.Integer, primary_key=True)
  email = db.Column(db.String(50))

  customer_id = db.Column(db.Integer, db.ForeignKey('customer.id'))
  # Foreign key target elided in the original (the merchant model is not shown)
  merchant_id = db.Column(db.Integer, db.ForeignKey(''))

Here, we have a Customer who can have multiple addresses across multiple stores. However, accessing customer.addresses will give us addresses from all of those stores, which is not what we want. Luckily for us, the majority of client-specific tables already have a merchant_id attribute, and we were able to reduce this problem to a single statement: “If something is loaded through a SQLAlchemy backref during a request, and it has a merchant_id, its merchant_id should always be the same as the merchant that request is for.” From this problem statement, we built a custom collection class for our SQLAlchemy backrefs, which explicitly checks this and transparently removes models that were incorrectly loaded.
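To make that rule concrete, here is a minimal, framework-free sketch of the check written as a predicate. The function name `visible_to_merchant` and the `SimpleNamespace` stand-ins for Address rows are illustrative assumptions, not Cratejoy code; letting through objects whose merchant_id is None mirrors the FilteredList implementation in this post.

```python
from types import SimpleNamespace


def visible_to_merchant(item, merchant_id, attr_name="merchant_id"):
    """Hypothetical predicate for the rule stated above.

    Objects without the attribute, or with it set to None, are allowed
    through; anything else must match the current request's merchant.
    """
    value = getattr(item, attr_name, None)
    return value is None or value == merchant_id


# Illustrative data: one customer's addresses spread across two merchants.
addresses = [
    SimpleNamespace(email="a@example.com", merchant_id=1),
    SimpleNamespace(email="b@example.com", merchant_id=2),
    SimpleNamespace(email="c@example.com", merchant_id=None),
]

# A request for merchant 1 should only ever see merchant 1's rows
# (plus rows that carry no merchant at all).
kept = [a.email for a in addresses if visible_to_merchant(a, 1)]
print(kept)  # ['a@example.com', 'c@example.com']
```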

A custom collection_class is the class SQLAlchemy uses to hold the collection of related child models; for a list-based relationship it must implement the list interface. With this in mind, we can subclass SQLAlchemy’s InstrumentedList and create the following collection_class:


from sqlalchemy.orm.collections import InstrumentedList

class FilteredList(InstrumentedList):
  """
  List-like collection allowing us to prevent cross-account
  data leakage from centralized resources. Works by
  preventing non-matching objects from being added to
  backref containers.

  With the filter enabled as
    attr_name = u'merchant_id'
    attr_val = 1
  a FilteredList instance will refuse anything with a
  `merchant_id` attribute whose value does not equal `1`.

  Use by setting the `collection_class` attribute on
  SQLAlchemy backrefs.
  """

  # Attributes (class-level, so shared by every instance)
  attr_name = None
  attr_val = None

  @classmethod
  def enable_filter(cls, filter_name, filter_val):
    """'Activate' the filtering using the provided attribute
    name and value."""
    cls.attr_name = filter_name
    cls.attr_val = filter_val

  @classmethod
  def reset_filter(cls):
    """'Deactivate' the filtering."""
    cls.attr_name, cls.attr_val = None, None

  def append(self, item):
    """Append iff:
      - the filters are not set, or
      - the object matches the filters
    """
    # Compare against None so a falsy filter value (e.g. 0) still filters
    if (FilteredList.attr_name is not None and
        FilteredList.attr_val is not None):

      if hasattr(item, FilteredList.attr_name):
        object_val = getattr(item, FilteredList.attr_name)
        if (object_val is not None and
            object_val != FilteredList.attr_val):
          # Prevent actually appending, effectively dropping
          # the non-matching object
          return

    super(FilteredList, self).append(item)

As you can see, this checks the merchant_id against a static value on the class on each append. Only if the merchant_id matches the static ID do we allow the object to actually be added to the list. In our case, we’re scoping the filtering to a given request, so the ID to match must be explicitly set at the beginning of each request and cleared at the end. We do this in request hooks (shown below); however, your use may vary.


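def register_hooks(app):
  @app.before_request
  def _before_request():
    # current_merchant_id() is a hypothetical stand-in for however
    # your app resolves the merchant this request is for
    FilteredList.enable_filter(u'merchant_id', current_merchant_id())

  @app.after_request
  def _after_request(response):
    FilteredList.reset_filter()
    return response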

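Note that this works for us since each request is synchronous and blocking, but it would have to be modified to work in an asynchronous environment where a single thread could be handling multiple requests simultaneously. Because the filter is stored on the class, it is also shared by every thread in the process, so the same applies to a multi-threaded server.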

Now you’re probably thinking something like:

But wait, Nick, this seems very inefficient. You’re loading things that you know you don’t want, and then filtering them out afterwards. Why not just filter them explicitly?!

You would be right. This is very inefficient, and if we load a lot of unrelated data, it will be very slow. New code should always filter explicitly, and we specifically look for oversights like this in code review. This method is meant as our fallback for two cases:

  1. Both the engineer writing the code and the reviewers overlook bad data being loaded.
  2. Situations where explicit filtering isn’t possible or practical. This is the case with some libraries, where rewriting and refactoring our usage would have taken much longer and been more error-prone.

In practice, this approach works really well. By slightly modifying the FilteredList implementation above, we can add logging to identify when objects are being loaded that shouldn’t be, then go back and add the missing explicit filter.
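A minimal sketch of such a logging variant follows. It is not the post’s actual modification: to stay runnable without SQLAlchemy it subclasses a plain `list` as a stand-in for InstrumentedList, and the class name `LoggingFilteredList`, the logger name, and the message format are all assumptions. In a real application you would keep the InstrumentedList base and add only the `log.warning` call.

```python
import logging
from types import SimpleNamespace

log = logging.getLogger("filtered_list")  # hypothetical logger name


class LoggingFilteredList(list):
    """FilteredList variant that logs every object it drops.

    Subclasses `list` only so this sketch runs without SQLAlchemy; the
    real collection would keep subclassing InstrumentedList.
    """

    attr_name = None
    attr_val = None

    @classmethod
    def enable_filter(cls, filter_name, filter_val):
        cls.attr_name, cls.attr_val = filter_name, filter_val

    @classmethod
    def reset_filter(cls):
        cls.attr_name, cls.attr_val = None, None

    def append(self, item):
        cls = type(self)
        if cls.attr_name is not None and cls.attr_val is not None:
            value = getattr(item, cls.attr_name, None)
            if value is not None and value != cls.attr_val:
                # Still drop the object, but leave a trail pointing at the
                # code path that is missing an explicit filter.
                log.warning("dropped %r: %s=%r, expected %r",
                            item, cls.attr_name, value, cls.attr_val)
                return
        super().append(item)


LoggingFilteredList.enable_filter("merchant_id", 1)
addresses = LoggingFilteredList()
addresses.append(SimpleNamespace(email="a@example.com", merchant_id=1))
addresses.append(SimpleNamespace(email="b@example.com", merchant_id=2))  # logged, dropped
LoggingFilteredList.reset_filter()
print([a.email for a in addresses])  # ['a@example.com']
```

Logging instead of raising keeps the failsafe transparent to the end user while still surfacing each dropped object for follow-up.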

One final thing: we later discovered an intermittent bug in this implementation. When objects were being excluded by the filters, sometimes (but only sometimes) data that we actually wanted was also being dropped. More information, as well as a fix, can be found here.