# Source code for everest.entities.aggregates
"""
Aggregate implementations.
This file is part of the everest project.
See LICENSE.txt for licensing, CONTRIBUTORS.txt for contributor information.
Created on Sep 25, 2011.
"""
from everest.entities.base import Aggregate
from everest.exceptions import DuplicateException
from everest.querying.base import EXPRESSION_KINDS
from everest.utils import get_filter_specification_visitor
from everest.utils import get_order_specification_visitor
from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.orm.exc import NoResultFound
# Markup language used by the docstrings in this module.
__docformat__ = 'reStructuredText en'

# Names exported by ``from everest.entities.aggregates import *``.
__all__ = [
    'MemoryAggregate',
    'OrmAggregate',
]
class MemoryAggregate(Aggregate):
    """
    In-memory implementation for aggregates.

    Entities are read from the session when the aggregate has no
    relationship (or the relationship has no children collection) and
    from the relationship's children collection otherwise. Filter, order
    and slice settings are applied at read time.

    :note: When "blank" entities without an ID and a slug are added to a
      memory aggregate, they can not be retrieved using the
      :meth:`get_by_id` or :meth:`get_by_slug` methods since there
      is no mechanism to autogenerate IDs or slugs.
    """

    def count(self):
        """
        Returns the number of entities left after the current filter,
        order and slice settings have been applied.
        """
        return len(self.__get_entities())

    def get_by_id(self, id_key):
        """
        Returns the entity with the given ID.

        :param id_key: ID value to look up.
        :returns: The matching entity, or `None` if no entity matches or
          the match does not satisfy the current filter specification.
        :raises DuplicateException: If several children of the
          relationship share the given ID.
        """
        if self.__is_session_backed():
            ent = self._session.get_by_id(self.entity_class, id_key)
            return self.__accept_or_none(ent)
        return self.__filter_by_attr(self._relationship.children,
                                     'id', id_key)

    def get_by_slug(self, slug):
        """
        Returns the entity with the given slug.

        :param slug: Slug value to look up.
        :returns: The matching entity, or `None` if no entity matches or
          the match does not satisfy the current filter specification.
        :raises DuplicateException: If several children of the
          relationship share the given slug.
        """
        if self.__is_session_backed():
            ent = self._session.get_by_slug(self.entity_class, slug)
            return self.__accept_or_none(ent)
        return self.__filter_by_attr(self._relationship.children,
                                     'slug', slug)

    def iterator(self):
        """
        Generates the entities left after the current filter, order and
        slice settings have been applied.
        """
        for ent in self.__get_entities():
            yield ent

    def add(self, entity):
        """
        Adds the given entity to the session or, if a relationship is
        set, to the relationship's children.

        :param entity: Entity to add; must be an instance of the entity
          class of this aggregate.
        :raises ValueError: If the entity has the wrong type, or if it
          has an ID and a child with the same ID or slug already exists.
        """
        if not isinstance(entity, self.entity_class):
            raise ValueError('Can only add entities of type "%s" to this '
                             'aggregate.' % self.entity_class)
        if self._relationship is None:
            self._session.add(self.entity_class, entity)
            return
        # Entities without an ID are appended without a duplicate check.
        # NOTE(review): unlike the read methods, this path does not handle
        # a relationship whose children collection is None - confirm that
        # this cannot happen here.
        children = self._relationship.children
        if entity.id is not None \
           and self.__check_existing(children, entity):
            raise ValueError('Duplicate ID or slug.')
        children.append(entity)

    def remove(self, entity):
        """
        Removes the given entity from the session or, if a relationship
        is set, from the relationship's children.

        :param entity: Entity to remove.
        """
        if self._relationship is None:
            self._session.remove(self.entity_class, entity)
        else:
            self._relationship.children.remove(entity)

    def _apply_filter(self):
        """
        Does nothing; the filter specification is evaluated when the
        entities are read.
        """
        pass

    def _apply_order(self):
        """
        Does nothing; the order specification is evaluated when the
        entities are read.
        """
        pass

    def _apply_slice(self):
        """
        Does nothing; the slice key is applied when the entities are
        read.
        """
        pass

    def __is_session_backed(self):
        # True if entities have to be read from the session because there
        # is no relationship or the relationship has no children.
        return self._relationship is None \
               or self._relationship.children is None

    def __accept_or_none(self, ent):
        # Returns the given entity unless a filter specification is set
        # and rejects it. A lookup result of None is passed through
        # without being handed to the specification, which would
        # otherwise be asked to evaluate a non-entity.
        if ent is None or self._filter_spec is None:
            return ent
        if self._filter_spec.is_satisfied_by(ent):
            return ent
        return None

    def __get_entities(self):
        # Collects the entities of this aggregate and applies the filter
        # specification, the order specification and the slice key, in
        # this order.
        if self.__is_session_backed():
            ents = self._session.get_all(self.entity_class)
        else:
            ents = self._relationship.children
        if self._filter_spec is not None:
            visitor = get_filter_specification_visitor(EXPRESSION_KINDS.EVAL)()
            self._filter_spec.accept(visitor)
            ents = visitor.expression(ents)
        if self._order_spec is not None:
            visitor = get_order_specification_visitor(EXPRESSION_KINDS.EVAL)()
            self._order_spec.accept(visitor)
            ents = visitor.expression(ents)
        if self._slice_key is not None:
            ents = ents[self._slice_key]
        return ents

    def __check_existing(self, ents, entity):
        # True if any of the given entities has the same ID or the same
        # slug as the given entity.
        return any(ent.id == entity.id or ent.slug == entity.slug
                   for ent in ents)

    def __filter_by_attr(self, ents, attr, value):
        # Returns the single entity whose attribute `attr` equals `value`
        # and which satisfies the filter specification (if one is set);
        # returns None if there is no such entity and raises a
        # DuplicateException if there are several.
        spec = self._filter_spec
        if spec is None:
            matching_ents = \
                [ent for ent in ents if getattr(ent, attr) == value]
        else:
            matching_ents = \
                [ent for ent in ents
                 if (getattr(ent, attr) == value
                     and spec.is_satisfied_by(ent))]
        if len(matching_ents) > 1:
            raise DuplicateException('Duplicates found for "%s" value of ' # pragma: no cover
                                     '"%s" attribue.' % (value, attr))
        if len(matching_ents) == 1:
            return matching_ents[0]
        return None
class OrmAggregate(Aggregate):
    """
    ORM implementation for aggregates.

    Queries are built on the session of the aggregate; filter, order and
    slice settings are translated into query expressions when a query is
    built.
    """

    def __init__(self, entity_class, session_factory, search_mode=False):
        """
        :param entity_class: Entity class managed by this aggregate.
        :param session_factory: Passed on to the base class constructor.
        :param search_mode: If set, the aggregate is treated as empty by
          :meth:`count` and :meth:`iterator` as long as no filter
          specification is set.
        """
        Aggregate.__init__(self, entity_class, session_factory)
        self._search_mode = search_mode

    def count(self):
        """
        Returns the number of entities matched by the filtered query, or
        0 if the aggregate is in search mode without a filter
        specification.
        """
        if self._relationship is not None:
            # We need a flush here because we may have newly added entities
            # in the aggregate which need to get an ID *before* we build the
            # relation filter spec.
            self._session.flush()
        if self.__defaults_empty:
            return 0
        return self.__get_filtered_query(None).count()

    def get_by_id(self, id_key):
        """
        Returns the entity with the given ID.

        :param id_key: ID value to look up.
        :returns: The matching entity, or `None` if the filtered query
          has no result for the ID.
        :raises DuplicateException: If the query finds several entities.
        """
        query = self.__get_filtered_query(id_key)
        try:
            ent = query.filter_by(id=id_key).one()
        except NoResultFound:
            ent = None
        except MultipleResultsFound: # pragma: no cover
            raise DuplicateException('Duplicates found for ID "%s".' % id_key)
        return ent

    def get_by_slug(self, slug):
        """
        Returns the entity with the given slug.

        :param slug: Slug value to look up.
        :returns: The matching entity, or `None` if the filtered query
          has no result for the slug.
        :raises DuplicateException: If the query finds several entities.
        """
        query = self.__get_filtered_query(slug)
        try:
            ent = query.filter_by(slug=slug).one()
        except NoResultFound:
            ent = None
        except MultipleResultsFound: # pragma: no cover
            raise DuplicateException('Duplicates found for slug "%s".' % slug)
        return ent

    def iterator(self):
        """
        Generates the entities returned by the data query (filtered,
        ordered and sliced). Generates nothing if the aggregate is in
        search mode without a filter specification.
        """
        if self.__defaults_empty:
            # End the generator with a plain return. Raising StopIteration
            # inside a generator body is converted into a RuntimeError
            # under PEP 479 (default behavior since Python 3.7).
            return
        if len(self._session.new) > 0:
            # We need a flush here because we may have newly added
            # entities in the aggregate which need to get an ID *before*
            # we build the query expression.
            self._session.flush()
        query = self._get_data_query()
        for obj in iter(query):
            yield obj

    def add(self, entity):
        """
        Adds the given entity to the session or, if a relationship is
        set, to the relationship's children.

        :param entity: Entity to add.
        """
        if self._relationship is None:
            self._session.add(entity)
        else:
            self._relationship.children.append(entity)

    def remove(self, entity):
        """
        Deletes the given entity from the session or, if a relationship
        is set, removes it from the relationship's children.

        :param entity: Entity to remove.
        """
        if self._relationship is None:
            self._session.delete(entity)
        else:
            self._relationship.children.remove(entity)

    def _apply_filter(self):
        """
        Does nothing; the filter specification is applied when a query
        is built.
        """
        pass

    def _apply_order(self):
        """
        Does nothing; the order specification is applied when a query is
        built.
        """
        pass

    def _apply_slice(self):
        """
        Does nothing; the slice key is applied when the data query is
        built.
        """
        pass

    def _query_generator(self, query, key): # unused pylint: disable=W0613
        """
        Hook called with the base query and the lookup key before the
        filter specification is applied. This implementation returns the
        query unchanged.

        :param query: Base query.
        :param key: Key passed by the caller (the ID, the slug, the slice
          key or `None`); not used here.
        :returns: The query to continue with.
        """
        return query

    def _filter_visitor_factory(self):
        """
        Returns a new SQL filter specification visitor for the entity
        class of this aggregate.
        """
        visitor_cls = get_filter_specification_visitor(EXPRESSION_KINDS.SQL)
        return visitor_cls(self.entity_class)

    def _order_visitor_factory(self):
        """
        Returns a new SQL order specification visitor for the entity
        class of this aggregate.
        """
        visitor_cls = get_order_specification_visitor(EXPRESSION_KINDS.SQL)
        return visitor_cls(self.entity_class)

    def _get_base_query(self):
        """
        Returns a query for the entity class of this aggregate. If a
        relationship is set, the query is restricted by the expression
        built from the relationship's specification.
        """
        query = self._session.query(self.entity_class)
        if self._relationship is not None:
            # Pre-filter the base query with the relation specification.
            rel_spec = self._relationship.specification
            visitor = self._filter_visitor_factory()
            rel_spec.accept(visitor)
            query = query.filter(visitor.expression)
        return query

    def _get_data_query(self):
        """
        Returns the filtered and ordered query, restricted to the range
        given by the slice key if one is set.
        """
        slice_key = self._slice_key
        query = self.__get_ordered_query(slice_key)
        if slice_key is not None:
            query = query.slice(slice_key.start, slice_key.stop)
        return query

    def __get_filtered_query(self, key):
        # Passes the base query through the query generator hook and then
        # applies the filter specification, if one is set.
        query = self._query_generator(self._get_base_query(), key)
        if self._filter_spec is not None:
            visitor = self._filter_visitor_factory()
            self._filter_spec.accept(visitor)
            query = query.filter(visitor.expression)
        return query

    def __get_ordered_query(self, key):
        # Extends the filtered query with the joins and the ORDER BY
        # expression built from the order specification, if one is set.
        query = self.__get_filtered_query(key)
        if self._order_spec is not None:
            visitor = self._order_visitor_factory()
            self._order_spec.accept(visitor)
            for join_expr in visitor.get_joins():
                # FIXME: only join when needed here.
                query = query.outerjoin(join_expr)
            query = query.order_by(visitor.expression)
        return query

    @property
    def __defaults_empty(self):
        # True if the aggregate is in search mode and no filter
        # specification has been set.
        return self._filter_spec is None and self._search_mode