sqlalchemy_schemadisplay3.py
changeset 313 a88add2b3eea
parent 312 42fd5075a5d1
child 314 353541614dfa
equal deleted inserted replaced
312:42fd5075a5d1 313:a88add2b3eea
     1 # Copyright (c) 2012 netzguerilla.net <iro@netzguerilla.net>
       
     2 # 
       
     3 # This file is part of Iro.
       
     4 # 
       
     5 # Permission is hereby granted, free of charge, to any person obtaining a copy of
       
     6 # this software and associated documentation files (the "Software"), to deal in
       
     7 # the Software without restriction, including without limitation the rights to use,
       
     8 # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
       
     9 # Software, and to permit persons to whom the Software is furnished to do so,
       
    10 # subject to the following conditions:
       
    11 # 
       
    12 # The above copyright notice and this permission notice shall be included in
       
    13 # all copies or substantial portions of the Software.
       
    14 # 
       
    15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
       
    16 # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
       
    17 # PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
       
    18 # HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
       
    19 # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
       
    20 # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
       
    21 
       
    22 # Updated the SQLAlchemy SchemaDisplay recipe to work with pydot 1.0.2.

    23 # Original recipe: http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SchemaDisplay
       
    24 
       
    25 
       
    26 from sqlalchemy.orm.properties import PropertyLoader
       
    27 import pydot
       
    28 import types
       
    29 
       
    30 __all__ = ['create_uml_graph', 'create_schema_graph', 'show_uml_graph', 'show_schema_graph']
       
    31 
       
    32 def _mk_label(mapper, show_operations, show_attributes, show_datatypes, bordersize):
       
    33     html = '<<TABLE CELLSPACING="0" CELLPADDING="1" BORDER="0" CELLBORDER="%d" BALIGN="LEFT"><TR><TD><FONT POINT-SIZE="10">%s</FONT></TD></TR>' % (bordersize, mapper.class_.__name__)
       
    34     def format_col(col):
       
    35         colstr = '+%s' % (col.name)
       
    36         if show_datatypes:
       
    37             colstr += ' : %s' % (col.type.__class__.__name__)
       
    38         return colstr
       
    39             
       
    40     if show_attributes:
       
    41         html += '<TR><TD ALIGN="LEFT">%s</TD></TR>' % '<BR ALIGN="LEFT"/>'.join(format_col(col) for col in sorted(mapper.columns, key=lambda col:not col.primary_key))
       
    42     else:
       
    43         [format_col(col) for col in sorted(mapper.columns, key=lambda col:not col.primary_key)]
       
    44     if show_operations:
       
    45         html += '<TR><TD ALIGN="LEFT">%s</TD></TR>' % '<BR ALIGN="LEFT"/>'.join(
       
    46             '%s(%s)' % (name,", ".join(default is _mk_label and ("%s") % arg or ("%s=%s" % (arg,repr(default))) for default,arg in 
       
    47                 zip((func.func_defaults and len(func.func_code.co_varnames)-1-(len(func.func_defaults) or 0) or func.func_code.co_argcount-1)*[_mk_label]+list(func.func_defaults or []), func.func_code.co_varnames[1:])
       
    48             ))
       
    49             for name,func in mapper.class_.__dict__.items() if isinstance(func, types.FunctionType) and func.__module__ == mapper.class_.__module__
       
    50         )
       
    51     html+= '</TABLE>>'
       
    52     return html
       
    53 
       
    54 
       
    55 def create_uml_graph(mappers, show_operations=True, show_attributes=True, show_multiplicity_one=False, show_datatypes=True, linewidth=1.0, font="Bitstream-Vera Sans"):
       
    56     graph = pydot.Dot(prog='neato',mode="major",overlap="0", sep="0.01",dim="3", pack="True", ratio=".75")
       
    57     relations = set()
       
    58     for mapper in mappers:
       
    59         graph.add_node(pydot.Node(mapper.class_.__name__,
       
    60             shape="plaintext", label=_mk_label(mapper, show_operations, show_attributes, show_datatypes, linewidth),
       
    61             fontname=font, fontsize="8.0",
       
    62         ))
       
    63         if mapper.inherits:
       
    64             graph.add_edge(pydot.Edge(mapper.inherits.class_.__name__,mapper.class_.__name__,
       
    65                 arrowhead='none',arrowtail='empty', style="setlinewidth(%s)" % linewidth, arrowsize=str(linewidth)))
       
    66         for loader in mapper.iterate_properties:
       
    67             if isinstance(loader, PropertyLoader) and loader.mapper in mappers:
       
    68                 if hasattr(loader, 'reverse_property'):
       
    69                     relations.add(frozenset([loader, loader.reverse_property]))
       
    70                 else:
       
    71                     relations.add(frozenset([loader]))
       
    72 
       
    73     for relation in relations:
       
    74         #if len(loaders) > 2:
       
    75         #    raise Exception("Warning: too many loaders for join %s" % join)
       
    76         args = {}
       
    77         def multiplicity_indicator(prop):
       
    78             if prop.uselist:
       
    79                 return ' *'
       
    80             if any(col.nullable for col in prop.local_side):
       
    81                 return ' 0..1'
       
    82             if show_multiplicity_one:
       
    83                 return ' 1'
       
    84             return ''
       
    85         
       
    86         if len(relation) == 2:
       
    87             src, dest = relation
       
    88             from_name = src.parent.class_.__name__
       
    89             to_name = dest.parent.class_.__name__
       
    90             
       
    91             def calc_label(src,dest):
       
    92                 return '+' + src.key + multiplicity_indicator(src)
       
    93             args['headlabel'] = calc_label(src,dest)
       
    94             
       
    95             args['taillabel'] = calc_label(dest,src)
       
    96             args['arrowtail'] = 'none'
       
    97             args['arrowhead'] = 'none'
       
    98             args['constraint'] = False
       
    99         else:
       
   100             prop, = relation
       
   101             from_name = prop.parent.class_.__name__
       
   102             to_name = prop.mapper.class_.__name__
       
   103             args['headlabel'] = '+%s%s' % (prop.key, multiplicity_indicator(prop))
       
   104             args['arrowtail'] = 'none'
       
   105             args['arrowhead'] = 'vee'
       
   106         
       
   107         graph.add_edge(pydot.Edge(from_name,to_name,
       
   108             fontname=font, fontsize="7.0", style="setlinewidth(%s)"%linewidth, arrowsize=str(linewidth),
       
   109             **args)
       
   110         )
       
   111 
       
   112     return graph
       
   113 
       
   114 #from sqlalchemy import Table, text
       
   115 
       
   116 def _render_table_html(table, metadata, show_indexes, show_datatypes):
       
   117     def format_col_type(col):
       
   118         try:
       
   119             return col.type.get_col_spec()
       
   120         except NotImplementedError:
       
   121             return str(col.type)
       
   122         except AttributeError:
       
   123             return str(col.type)
       
   124     def format_col_str(col):
       
   125          if show_datatypes:
       
   126              return "- %s : %s" % (col.name, format_col_type(col))
       
   127          else:
       
   128              return "- %s" % col.name
       
   129     html = '<<TABLE BORDER="1" CELLBORDER="0" CELLSPACING="0"><TR><TD ALIGN="CENTER">%s</TD></TR><TR><TD BORDER="1" CELLPADDING="0"></TD></TR>' % table.name 
       
   130 
       
   131     html += ''.join('<TR><TD ALIGN="LEFT" PORT="%s">%s</TD></TR>' % (col.name, format_col_str(col)) for col in table.columns)
       
   132     html += '</TABLE>>'
       
   133     return html
       
   134 
       
   135 def create_schema_graph(tables=None, metadata=None, show_indexes=True, show_datatypes=True, font="Bitstream-Vera Sans",
       
   136     concentrate=True, relation_options={}, rankdir='TB'):
       
   137     relation_kwargs = {
       
   138         'fontsize':"7.0"
       
   139     }
       
   140     relation_kwargs.update(relation_options)
       
   141     
       
   142     if not metadata and len(tables):
       
   143         metadata = tables[0].metadata
       
   144     elif not tables and metadata:
       
   145         if not len(metadata.tables):
       
   146             metadata.reflect()
       
   147         tables = metadata.tables.values()
       
   148     else:
       
   149         raise Exception("You need to specify at least tables or metadata")
       
   150     
       
   151     graph = pydot.Dot(prog="dot",mode="ipsep",overlap="ipsep",sep="0.01",concentrate=str(concentrate), rankdir=rankdir)
       
   152     for table in tables:
       
   153         graph.add_node(pydot.Node(str(table.name),
       
   154             shape="plaintext",
       
   155             label=_render_table_html(table, metadata, show_indexes, show_datatypes),
       
   156             fontname=font, fontsize="7.0"
       
   157         ))
       
   158     
       
   159     for table in tables:
       
   160         for fk in table.foreign_keys:
       
   161             edge = [table.name, fk.column.table.name]
       
   162             is_inheritance = fk.parent.primary_key and fk.column.primary_key
       
   163             if is_inheritance:
       
   164                 edge = edge[::-1]
       
   165             graph_edge = pydot.Edge(
       
   166                 headlabel="+ %s"%fk.column.name, taillabel='+ %s'%fk.parent.name,
       
   167                 arrowhead=is_inheritance and 'none' or 'odot' ,
       
   168                 arrowtail=(fk.parent.primary_key or fk.parent.unique) and 'empty' or 'crow' ,
       
   169                 fontname=font, 
       
   170                 #samehead=fk.column.name, sametail=fk.parent.name,
       
   171                 *edge, **relation_kwargs
       
   172             )
       
   173             graph.add_edge(graph_edge)
       
   174 
       
   175 # not sure what this part is for, doesn't work with pydot 1.0.2
       
   176 #            graph_edge.parent_graph = graph.parent_graph
       
   177 #            if table.name not in [e.get_source() for e in graph.get_edge_list()]:
       
   178 #                graph.edge_src_list.append(table.name)
       
   179 #            if fk.column.table.name not in graph.edge_dst_list:
       
   180 #                graph.edge_dst_list.append(fk.column.table.name)
       
   181 #            graph.sorted_graph_elements.append(graph_edge)
       
   182     return graph
       
   183 
       
   184 def show_uml_graph(*args, **kwargs):
       
   185     from cStringIO import StringIO
       
   186     from PIL import Image
       
   187     iostream = StringIO(create_uml_graph(*args, **kwargs).create_png())
       
   188     Image.open(iostream).show(command=kwargs.get('command','gwenview'))
       
   189 
       
   190 def show_schema_graph(*args, **kwargs):
       
   191     from cStringIO import StringIO
       
   192     from PIL import Image
       
   193     iostream = StringIO(create_schema_graph(*args, **kwargs).create_png())
       
   194     Image.open(iostream).show(command=kwargs.get('command','gwenview'))