Note
Click here to download the full example code
Automatically use a function at insert or select
Sometimes the application wants to apply a function in an insert or in a select.
For example, the application might need the geometry with lat/lon coordinates while
the geometries are stored in a projected coordinate system in the DB. To avoid having
to always tweak the query with a ST_Transform(), it is possible to define a
TypeDecorator that adds the function call automatically.
11 from sqlalchemy import Column
12 from sqlalchemy import Integer
13 from sqlalchemy import MetaData
14 from sqlalchemy import func
15 from sqlalchemy import text
16 from sqlalchemy.ext.declarative import declarative_base
17 from sqlalchemy.types import TypeDecorator
18
19 from geoalchemy2 import Geometry
20 from geoalchemy2 import shape
21
22 # Tests imports
23 from tests import test_only_with_dialects
24
# Shared MetaData: every table mapped below is registered here.
metadata = MetaData()

# The declarative base reuses ``metadata`` so that the tests can drop and
# recreate all mapped tables through ``metadata.drop_all``/``create_all``.
Base = declarative_base(metadata=metadata)
28
29
class TransformedGeometry(TypeDecorator):
    """This class is used to insert a ST_Transform() in each insert or select.

    Geometries are stored in the database in ``db_srid`` and exchanged with the
    application in ``app_srid``:

    * bound values are wrapped in ``ST_Transform(..., db_srid)`` on insert;
    * selected columns are wrapped in ``ST_Transform(..., app_srid)`` on select.

    Args:
        db_srid: SRID used to store the geometry in the database.
        app_srid: SRID in which the application reads and writes geometries.
        **kwargs: Extra keyword arguments forwarded to the underlying
            ``Geometry`` type. Any ``srid`` given here is overridden by
            ``db_srid``.
    """

    impl = Geometry

    # The SQL emitted by column_expression()/bind_expression() is parameterised
    # only by ``db_srid`` and ``app_srid``. Both are stored under the same names
    # as the ``__init__`` arguments, which SQLAlchemy (1.4+) uses to build the
    # type's cache key. Declaring the type cacheable therefore avoids the
    # "will not produce a cache key" warning without risking stale SQL.
    cache_ok = True

    def __init__(self, db_srid, app_srid, **kwargs):
        kwargs["srid"] = db_srid
        # TypeDecorator.__init__ instantiates the class-level ``impl``
        # (Geometry) with these keyword arguments and stores it as self.impl.
        super().__init__(**kwargs)
        self.app_srid = app_srid
        self.db_srid = db_srid

    def column_expression(self, col):
        """Return *col* transformed to ``app_srid`` and serialized to binary.

        The column_expression() method is overridden to ensure that the
        SRID of the resulting WKBElement is correct. The result type is a
        plain ``Geometry`` declared with ``srid=app_srid`` rather than this
        decorator, so the value is not transformed a second time.
        """
        return getattr(func, self.impl.as_binary)(
            func.ST_Transform(col, self.app_srid),
            # srid could also be -1 so that the SRID is deduced from the
            # WKB data
            type_=self.__class__.impl(srid=self.app_srid),
        )

    def bind_expression(self, bindvalue):
        """Wrap the impl's bind expression in ``ST_Transform(..., db_srid)``."""
        return func.ST_Transform(
            self.impl.bind_expression(bindvalue), self.db_srid)
53
54
class ThreeDGeometry(TypeDecorator):
    """Geometry type that wraps every inserted (bound) value in ST_Force3D().

    This lets a 2D geometry be written into a 3D (e.g. ``POINTZ``) column.
    Only bind_expression() is customised; column_expression() is not
    overridden here.
    """

    impl = Geometry

    def bind_expression(self, bindvalue):
        """Return ``ST_Force3D(<impl bind expression>)`` for *bindvalue*."""
        inner_expression = self.impl.bind_expression(bindvalue)
        return func.ST_Force3D(inner_expression)
61
62
class Point(Base):
    """Mapped ``point`` table holding one geometry through three column types.

    * ``raw_geom``: plain ``Geometry`` POINT column in SRID 4326.
    * ``geom``: ``TransformedGeometry`` stored in SRID 2154 and exchanged
      with the application in SRID 4326.
    * ``three_d_geom``: ``ThreeDGeometry`` POINTZ column (dimension 3) in
      SRID 4326 whose inserted values are wrapped in ST_Force3D().
    """

    __tablename__ = "point"

    id = Column(Integer, primary_key=True)
    raw_geom = Column(Geometry(geometry_type="POINT", srid=4326))
    geom = Column(
        TransformedGeometry(
            geometry_type="POINT",
            db_srid=2154,
            app_srid=4326,
        )
    )
    three_d_geom = Column(
        ThreeDGeometry(geometry_type="POINTZ", srid=4326, dimension=3)
    )
72
73
def check_wkb(wkb, x, y):
    """Assert that *wkb* decodes to a point at (*x*, *y*).

    Both coordinates are rounded to 5 decimals before being compared, so the
    expected values must be given with at most 5 decimals.
    """
    decoded = shape.to_shape(wkb)
    for actual, expected in ((decoded.x, x), (decoded.y, y)):
        assert round(actual, 5) == expected
78
79
80 @test_only_with_dialects("postgresql")
81 class TestTypeDecorator():
82
83 def _create_one_point(self, session, conn):
84 metadata.drop_all(conn, checkfirst=True)
85 metadata.create_all(conn)
86
87 # Create new point instance
88 p = Point()
89 p.raw_geom = "SRID=4326;POINT(5 45)"
90 p.geom = "SRID=4326;POINT(5 45)"
91 p.three_d_geom = "SRID=4326;POINT(5 45)" # Insert 2D geometry into 3D column
92
93 # Insert point
94 session.add(p)
95 session.flush()
96 session.expire(p)
97
98 return p.id
99
100 def test_transform(self, session, conn):
101 self._create_one_point(session, conn)
102
103 # Query the point and check the result
104 pt = session.query(Point).one()
105 assert pt.id == 1
106 assert pt.raw_geom.srid == 4326
107 check_wkb(pt.raw_geom, 5, 45)
108
109 assert pt.geom.srid == 4326
110 check_wkb(pt.geom, 5, 45)
111
112 # Check that the data is correct in DB using raw query
113 q = text("SELECT id, ST_AsEWKT(geom) AS geom FROM point;")
114 res_q = session.execute(q).fetchone()
115 assert res_q.id == 1
116 assert res_q.geom == "SRID=2154;POINT(857581.899319668 6435414.7478354)"
117
118 # Compare geom, raw_geom with auto transform and explicit transform
119 pt_trans = session.query(
120 Point,
121 Point.raw_geom,
122 func.ST_Transform(Point.raw_geom, 2154).label("trans")
123 ).one()
124
125 assert pt_trans[0].id == 1
126
127 assert pt_trans[0].geom.srid == 4326
128 check_wkb(pt_trans[0].geom, 5, 45)
129
130 assert pt_trans[0].raw_geom.srid == 4326
131 check_wkb(pt_trans[0].raw_geom, 5, 45)
132
133 assert pt_trans[1].srid == 4326
134 check_wkb(pt_trans[1], 5, 45)
135
136 assert pt_trans[2].srid == 2154
137 check_wkb(pt_trans[2], 857581.89932, 6435414.74784)
138
139 def test_force_3d(self, session, conn):
140 self._create_one_point(session, conn)
141
142 # Query the point and check the result
143 pt = session.query(Point).one()
144
145 assert pt.id == 1
146 assert pt.three_d_geom.srid == 4326
147 assert pt.three_d_geom.desc.lower() == (
148 '01010000a0e6100000000000000000144000000000008046400000000000000000')