Open3D (C++ API)  0.17.0
TensorFlowHelper.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
8#pragma once
9#include <tensorflow/core/framework/op_kernel.h>
10#include <tensorflow/core/framework/shape_inference.h>
11#include <tensorflow/core/framework/tensor.h>
12#include <tensorflow/core/lib/core/errors.h>
13
15
16inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
17 ::tensorflow::shape_inference::InferenceContext* c,
18 ::tensorflow::shape_inference::ShapeHandle shape_handle) {
19 using namespace open3d::ml::op_util;
20 if (!c->RankKnown(shape_handle)) {
21 return std::vector<DimValue>();
22 }
23
24 std::vector<DimValue> shape;
25 const int rank = c->Rank(shape_handle);
26 for (int i = 0; i < rank; ++i) {
27 auto d = c->DimKnownRank(shape_handle, i);
28 if (c->ValueKnown(d)) {
29 shape.push_back(c->Value(d));
30 } else {
31 shape.push_back(DimValue());
32 }
33 }
34 return shape;
35}
36
38 class TDimX,
39 class... TArgs>
40std::tuple<bool, std::string> CheckShape(
41 ::tensorflow::shape_inference::InferenceContext* c,
42 ::tensorflow::shape_inference::ShapeHandle shape_handle,
43 TDimX&& dimex,
44 TArgs&&... args) {
45 if (!c->RankKnown(shape_handle)) {
46 // without rank we cannot check
47 return std::make_tuple(true, std::string());
48 }
49 return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(c, shape_handle),
50 std::forward<TDimX>(dimex),
51 std::forward<TArgs>(args)...);
52}
53
54inline std::vector<open3d::ml::op_util::DimValue> GetShapeVector(
55 const tensorflow::Tensor& tensor) {
56 using namespace open3d::ml::op_util;
57
58 std::vector<DimValue> shape;
59 for (int i = 0; i < tensor.dims(); ++i) {
60 shape.push_back(tensor.dim_size(i));
61 }
62 return shape;
63}
64
66 class TDimX,
67 class... TArgs>
68std::tuple<bool, std::string> CheckShape(const tensorflow::Tensor& tensor,
69 TDimX&& dimex,
70 TArgs&&... args) {
71 return open3d::ml::op_util::CheckShape<Opt>(GetShapeVector(tensor),
72 std::forward<TDimX>(dimex),
73 std::forward<TArgs>(args)...);
74}
75
76//
77// Helper function for creating a ShapeHandle from dim expressions.
78// Dim expressions which are not constant will translate to unknown dims in
79// the returned shape handle.
80//
81// Usage:
82// // ctx is of type tensorflow::shape_inference::InferenceContext*
83// {
84// using namespace open3d::ml::op_util;
85// Dim w("w");
86// Dim h("h");
87// CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first dim is
88// // 10 and assigns w and h
89// // based on the shape of
90// // handle1
91//
92// CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // this checks if the the
93// // last dim of handle2 matches the
94// // last dim of handle1. The first
95// // two dims must match 10, 20.
96//
97// ShapeHandle out_shape = MakeShapeHandle(ctx, Dim(), h, w);
98// ctx->set_output(0, out_shape);
99// }
100//
101//
102// See "../ShapeChecking.h" for more info and limitations.
103//
104template <class TDimX, class... TArgs>
105::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(
106 ::tensorflow::shape_inference::InferenceContext* ctx,
107 TDimX&& dimex,
108 TArgs&&... args) {
109 using namespace tensorflow::shape_inference;
110 using namespace open3d::ml::op_util;
111 std::vector<int64_t> shape = CreateDimVector(
112 int64_t(InferenceContext::kUnknownDim), dimex, args...);
113 std::vector<DimensionHandle> dims;
114 for (int64_t d : shape) {
115 dims.push_back(ctx->MakeDim(d));
116 }
117 return ctx->MakeShape(dims);
118}
119
//
// Macros for checking the shape of ShapeHandle during shape inference.
// On failure they return an InvalidArgument error from the enclosing
// function, so they can only be used in functions that return a Status.
//
// Usage:
//     // ctx is of type tensorflow::shape_inference::InferenceContext*
//     {
//         using namespace open3d::ml::op_util;
//         Dim w("w");
//         Dim h("h");
//         CHECK_SHAPE_HANDLE(ctx, handle1, 10, w, h); // checks if the first
//                                                     // dim is 10 and assigns
//                                                     // w and h based on the
//                                                     // shape of handle1
//
//         CHECK_SHAPE_HANDLE(ctx, handle2, 10, 20, h); // checks if the last
//                                                      // dim of handle2 matches
//                                                      // the last dim of
//                                                      // handle1. The first two
//                                                      // dims must match 10, 20.
//     }
//
// The _COMBINE_FIRST_DIMS, _IGNORE_FIRST_DIMS, _COMBINE_LAST_DIMS and
// _IGNORE_LAST_DIMS variants pass the corresponding
// open3d::ml::op_util::CSOpt value to CheckShape.
//
// See "../ShapeChecking.h" for more info and limitations.
//

// Shared implementation of the CHECK_SHAPE_HANDLE* macros.
// `check_expr` must yield a std::tuple<bool, std::string> (success, error
// message); `name_str` is the stringized name of the checked shape handle.
// The stringizing is done by the public macros so that the name in the error
// message is the unexpanded argument text.
#define CHECK_SHAPE_HANDLE_IMPL_(check_expr, name_str)                  \
    do {                                                                \
        bool cs_success_;                                               \
        std::string cs_errstr_;                                         \
        std::tie(cs_success_, cs_errstr_) = (check_expr);               \
        if (TF_PREDICT_FALSE(!cs_success_)) {                           \
            return tensorflow::errors::InvalidArgument(                 \
                    "invalid shape for '" name_str "', " + cs_errstr_); \
        }                                                               \
    } while (0)

#define CHECK_SHAPE_HANDLE(ctx, shape_handle, ...)                       \
    CHECK_SHAPE_HANDLE_IMPL_(CheckShape(ctx, shape_handle, __VA_ARGS__), \
                             #shape_handle)

#define CHECK_SHAPE_HANDLE_COMBINE_FIRST_DIMS(ctx, shape_handle, ...)    \
    CHECK_SHAPE_HANDLE_IMPL_(                                            \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>(  \
                    ctx, shape_handle, __VA_ARGS__),                     \
            #shape_handle)

#define CHECK_SHAPE_HANDLE_IGNORE_FIRST_DIMS(ctx, shape_handle, ...)     \
    CHECK_SHAPE_HANDLE_IMPL_(                                            \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(   \
                    ctx, shape_handle, __VA_ARGS__),                     \
            #shape_handle)

#define CHECK_SHAPE_HANDLE_COMBINE_LAST_DIMS(ctx, shape_handle, ...)     \
    CHECK_SHAPE_HANDLE_IMPL_(                                            \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(   \
                    ctx, shape_handle, __VA_ARGS__),                     \
            #shape_handle)

#define CHECK_SHAPE_HANDLE_IGNORE_LAST_DIMS(ctx, shape_handle, ...)      \
    CHECK_SHAPE_HANDLE_IMPL_(                                            \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(    \
                    ctx, shape_handle, __VA_ARGS__),                     \
            #shape_handle)

//
// Macros for checking the shape of Tensors.
// On failure an InvalidArgument error is reported through OP_REQUIRES on
// `ctx`.
//
// Usage:
//     // ctx is of type tensorflow::OpKernelContext*
//     {
//         using namespace open3d::ml::op_util;
//         Dim w("w");
//         Dim h("h");
//         CHECK_SHAPE(ctx, tensor1, 10, w, h); // checks if the first dim is 10
//                                              // and assigns w and h based on
//                                              // the shape of tensor1
//
//         CHECK_SHAPE(ctx, tensor2, 10, 20, h); // checks if the last dim of
//                                               // tensor2 matches the last dim
//                                               // of tensor1. The first two
//                                               // dims must match 10, 20.
//     }
//
// The _COMBINE_FIRST_DIMS, _IGNORE_FIRST_DIMS, _COMBINE_LAST_DIMS and
// _IGNORE_LAST_DIMS variants pass the corresponding
// open3d::ml::op_util::CSOpt value to CheckShape.
//
// See "../ShapeChecking.h" for more info and limitations.
//

// Shared implementation of the CHECK_SHAPE* macros.
// `check_expr` must yield a std::tuple<bool, std::string> (success, error
// message); `name_str` is the stringized name of the checked tensor.
// The stringizing is done by the public macros so that the name in the error
// message is the unexpanded argument text.
#define CHECK_SHAPE_IMPL_(ctx, check_expr, name_str)               \
    do {                                                           \
        bool cs_success_;                                          \
        std::string cs_errstr_;                                    \
        std::tie(cs_success_, cs_errstr_) = (check_expr);          \
        OP_REQUIRES(ctx, cs_success_,                              \
                    tensorflow::errors::InvalidArgument(           \
                            "invalid shape for '" name_str "', " + \
                            cs_errstr_));                          \
    } while (0)

#define CHECK_SHAPE(ctx, tensor, ...) \
    CHECK_SHAPE_IMPL_(ctx, CheckShape(tensor, __VA_ARGS__), #tensor)

#define CHECK_SHAPE_COMBINE_FIRST_DIMS(ctx, tensor, ...)                 \
    CHECK_SHAPE_IMPL_(                                                   \
            ctx,                                                         \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_FIRST_DIMS>(  \
                    tensor, __VA_ARGS__),                                \
            #tensor)

#define CHECK_SHAPE_IGNORE_FIRST_DIMS(ctx, tensor, ...)                  \
    CHECK_SHAPE_IMPL_(                                                   \
            ctx,                                                         \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_FIRST_DIMS>(   \
                    tensor, __VA_ARGS__),                                \
            #tensor)

#define CHECK_SHAPE_COMBINE_LAST_DIMS(ctx, tensor, ...)                  \
    CHECK_SHAPE_IMPL_(                                                   \
            ctx,                                                         \
            CheckShape<open3d::ml::op_util::CSOpt::COMBINE_LAST_DIMS>(   \
                    tensor, __VA_ARGS__),                                \
            #tensor)

#define CHECK_SHAPE_IGNORE_LAST_DIMS(ctx, tensor, ...)                   \
    CHECK_SHAPE_IMPL_(                                                   \
            ctx,                                                         \
            CheckShape<open3d::ml::op_util::CSOpt::IGNORE_LAST_DIMS>(    \
                    tensor, __VA_ARGS__),                                \
            #tensor)
std::vector< open3d::ml::op_util::DimValue > GetShapeVector(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle)
Definition: TensorFlowHelper.h:16
std::tuple< bool, std::string > CheckShape(::tensorflow::shape_inference::InferenceContext *c, ::tensorflow::shape_inference::ShapeHandle shape_handle, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:40
::tensorflow::shape_inference::ShapeHandle MakeShapeHandle(::tensorflow::shape_inference::InferenceContext *ctx, TDimX &&dimex, TArgs &&... args)
Definition: TensorFlowHelper.h:105
Class for representing a possibly unknown dimension value.
Definition: ShapeChecking.h:19
Definition: ShapeChecking.h:16
CSOpt
Check shape options.
Definition: ShapeChecking.h:405
void CreateDimVector(std::vector< int64_t > &out, int64_t unknown_dim_value, TDimX dimex)
Definition: ShapeChecking.h:358