sqlalchemy_schemadisplay3.py
branchdevel
changeset 77 7dce6c0f06fb
child 294 0e75bd39767d
equal deleted inserted replaced
76:9f7da8dc8df8 77:7dce6c0f06fb
       
# Updated SQLAlchemy schema-display recipe to work with pydot 1.0.2.
# Original recipe: http://www.sqlalchemy.org/trac/wiki/UsageRecipes/SchemaDisplay
       
     3 
       
     4 
       
     5 from sqlalchemy.orm.properties import PropertyLoader
       
     6 import pydot
       
     7 import types
       
     8 
       
     9 __all__ = ['create_uml_graph', 'create_schema_graph', 'show_uml_graph', 'show_schema_graph']
       
    10 
       
    11 def _mk_label(mapper, show_operations, show_attributes, show_datatypes, bordersize):
       
    12     html = '<<TABLE CELLSPACING="0" CELLPADDING="1" BORDER="0" CELLBORDER="%d" BALIGN="LEFT"><TR><TD><FONT POINT-SIZE="10">%s</FONT></TD></TR>' % (bordersize, mapper.class_.__name__)
       
    13     def format_col(col):
       
    14         colstr = '+%s' % (col.name)
       
    15         if show_datatypes:
       
    16             colstr += ' : %s' % (col.type.__class__.__name__)
       
    17         return colstr
       
    18             
       
    19     if show_attributes:
       
    20         html += '<TR><TD ALIGN="LEFT">%s</TD></TR>' % '<BR ALIGN="LEFT"/>'.join(format_col(col) for col in sorted(mapper.columns, key=lambda col:not col.primary_key))
       
    21     else:
       
    22         [format_col(col) for col in sorted(mapper.columns, key=lambda col:not col.primary_key)]
       
    23     if show_operations:
       
    24         html += '<TR><TD ALIGN="LEFT">%s</TD></TR>' % '<BR ALIGN="LEFT"/>'.join(
       
    25             '%s(%s)' % (name,", ".join(default is _mk_label and ("%s") % arg or ("%s=%s" % (arg,repr(default))) for default,arg in 
       
    26                 zip((func.func_defaults and len(func.func_code.co_varnames)-1-(len(func.func_defaults) or 0) or func.func_code.co_argcount-1)*[_mk_label]+list(func.func_defaults or []), func.func_code.co_varnames[1:])
       
    27             ))
       
    28             for name,func in mapper.class_.__dict__.items() if isinstance(func, types.FunctionType) and func.__module__ == mapper.class_.__module__
       
    29         )
       
    30     html+= '</TABLE>>'
       
    31     return html
       
    32 
       
    33 
       
    34 def create_uml_graph(mappers, show_operations=True, show_attributes=True, show_multiplicity_one=False, show_datatypes=True, linewidth=1.0, font="Bitstream-Vera Sans"):
       
    35     graph = pydot.Dot(prog='neato',mode="major",overlap="0", sep="0.01",dim="3", pack="True", ratio=".75")
       
    36     relations = set()
       
    37     for mapper in mappers:
       
    38         graph.add_node(pydot.Node(mapper.class_.__name__,
       
    39             shape="plaintext", label=_mk_label(mapper, show_operations, show_attributes, show_datatypes, linewidth),
       
    40             fontname=font, fontsize="8.0",
       
    41         ))
       
    42         if mapper.inherits:
       
    43             graph.add_edge(pydot.Edge(mapper.inherits.class_.__name__,mapper.class_.__name__,
       
    44                 arrowhead='none',arrowtail='empty', style="setlinewidth(%s)" % linewidth, arrowsize=str(linewidth)))
       
    45         for loader in mapper.iterate_properties:
       
    46             if isinstance(loader, PropertyLoader) and loader.mapper in mappers:
       
    47                 if hasattr(loader, 'reverse_property'):
       
    48                     relations.add(frozenset([loader, loader.reverse_property]))
       
    49                 else:
       
    50                     relations.add(frozenset([loader]))
       
    51 
       
    52     for relation in relations:
       
    53         #if len(loaders) > 2:
       
    54         #    raise Exception("Warning: too many loaders for join %s" % join)
       
    55         args = {}
       
    56         def multiplicity_indicator(prop):
       
    57             if prop.uselist:
       
    58                 return ' *'
       
    59             if any(col.nullable for col in prop.local_side):
       
    60                 return ' 0..1'
       
    61             if show_multiplicity_one:
       
    62                 return ' 1'
       
    63             return ''
       
    64         
       
    65         if len(relation) == 2:
       
    66             src, dest = relation
       
    67             from_name = src.parent.class_.__name__
       
    68             to_name = dest.parent.class_.__name__
       
    69             
       
    70             def calc_label(src,dest):
       
    71                 return '+' + src.key + multiplicity_indicator(src)
       
    72             args['headlabel'] = calc_label(src,dest)
       
    73             
       
    74             args['taillabel'] = calc_label(dest,src)
       
    75             args['arrowtail'] = 'none'
       
    76             args['arrowhead'] = 'none'
       
    77             args['constraint'] = False
       
    78         else:
       
    79             prop, = relation
       
    80             from_name = prop.parent.class_.__name__
       
    81             to_name = prop.mapper.class_.__name__
       
    82             args['headlabel'] = '+%s%s' % (prop.key, multiplicity_indicator(prop))
       
    83             args['arrowtail'] = 'none'
       
    84             args['arrowhead'] = 'vee'
       
    85         
       
    86         graph.add_edge(pydot.Edge(from_name,to_name,
       
    87             fontname=font, fontsize="7.0", style="setlinewidth(%s)"%linewidth, arrowsize=str(linewidth),
       
    88             **args)
       
    89         )
       
    90 
       
    91     return graph
       
    92 
       
    93 #from sqlalchemy import Table, text
       
    94 
       
    95 def _render_table_html(table, metadata, show_indexes, show_datatypes):
       
    96     def format_col_type(col):
       
    97         try:
       
    98             return col.type.get_col_spec()
       
    99         except NotImplementedError:
       
   100             return str(col.type)
       
   101         except AttributeError:
       
   102             return str(col.type)
       
   103     def format_col_str(col):
       
   104          if show_datatypes:
       
   105              return "- %s : %s" % (col.name, format_col_type(col))
       
   106          else:
       
   107              return "- %s" % col.name
       
   108     html = '<<TABLE BORDER="1" CELLBORDER="0" CELLSPACING="0"><TR><TD ALIGN="CENTER">%s</TD></TR><TR><TD BORDER="1" CELLPADDING="0"></TD></TR>' % table.name 
       
   109 
       
   110     html += ''.join('<TR><TD ALIGN="LEFT" PORT="%s">%s</TD></TR>' % (col.name, format_col_str(col)) for col in table.columns)
       
   111     html += '</TABLE>>'
       
   112     return html
       
   113 
       
   114 def create_schema_graph(tables=None, metadata=None, show_indexes=True, show_datatypes=True, font="Bitstream-Vera Sans",
       
   115     concentrate=True, relation_options={}, rankdir='TB'):
       
   116     relation_kwargs = {
       
   117         'fontsize':"7.0"
       
   118     }
       
   119     relation_kwargs.update(relation_options)
       
   120     
       
   121     if not metadata and len(tables):
       
   122         metadata = tables[0].metadata
       
   123     elif not tables and metadata:
       
   124         if not len(metadata.tables):
       
   125             metadata.reflect()
       
   126         tables = metadata.tables.values()
       
   127     else:
       
   128         raise Exception("You need to specify at least tables or metadata")
       
   129     
       
   130     graph = pydot.Dot(prog="dot",mode="ipsep",overlap="ipsep",sep="0.01",concentrate=str(concentrate), rankdir=rankdir)
       
   131     for table in tables:
       
   132         graph.add_node(pydot.Node(str(table.name),
       
   133             shape="plaintext",
       
   134             label=_render_table_html(table, metadata, show_indexes, show_datatypes),
       
   135             fontname=font, fontsize="7.0"
       
   136         ))
       
   137     
       
   138     for table in tables:
       
   139         for fk in table.foreign_keys:
       
   140             edge = [table.name, fk.column.table.name]
       
   141             is_inheritance = fk.parent.primary_key and fk.column.primary_key
       
   142             if is_inheritance:
       
   143                 edge = edge[::-1]
       
   144             graph_edge = pydot.Edge(
       
   145                 headlabel="+ %s"%fk.column.name, taillabel='+ %s'%fk.parent.name,
       
   146                 arrowhead=is_inheritance and 'none' or 'odot' ,
       
   147                 arrowtail=(fk.parent.primary_key or fk.parent.unique) and 'empty' or 'crow' ,
       
   148                 fontname=font, 
       
   149                 #samehead=fk.column.name, sametail=fk.parent.name,
       
   150                 *edge, **relation_kwargs
       
   151             )
       
   152             graph.add_edge(graph_edge)
       
   153 
       
   154 # not sure what this part is for, doesn't work with pydot 1.0.2
       
   155 #            graph_edge.parent_graph = graph.parent_graph
       
   156 #            if table.name not in [e.get_source() for e in graph.get_edge_list()]:
       
   157 #                graph.edge_src_list.append(table.name)
       
   158 #            if fk.column.table.name not in graph.edge_dst_list:
       
   159 #                graph.edge_dst_list.append(fk.column.table.name)
       
   160 #            graph.sorted_graph_elements.append(graph_edge)
       
   161     return graph
       
   162 
       
   163 def show_uml_graph(*args, **kwargs):
       
   164     from cStringIO import StringIO
       
   165     from PIL import Image
       
   166     iostream = StringIO(create_uml_graph(*args, **kwargs).create_png())
       
   167     Image.open(iostream).show(command=kwargs.get('command','gwenview'))
       
   168 
       
   169 def show_schema_graph(*args, **kwargs):
       
   170     from cStringIO import StringIO
       
   171     from PIL import Image
       
   172     iostream = StringIO(create_schema_graph(*args, **kwargs).create_png())
       
   173     Image.open(iostream).show(command=kwargs.get('command','gwenview'))