Z3Solver.cpp revision 360784
1//== Z3Solver.cpp -----------------------------------------------*- C++ -*--==//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/ADT/Twine.h"
10#include "llvm/Config/config.h"
11#include "llvm/Support/SMTAPI.h"
12#include <set>
13
14using namespace llvm;
15
16#if LLVM_WITH_Z3
17
18#include <z3.h>
19
20namespace {
21
22/// Configuration class for Z3
23class Z3Config {
24  friend class Z3Context;
25
26  Z3_config Config;
27
28public:
29  Z3Config() : Config(Z3_mk_config()) {
30    // Enable model finding
31    Z3_set_param_value(Config, "model", "true");
32    // Disable proof generation
33    Z3_set_param_value(Config, "proof", "false");
34    // Set timeout to 15000ms = 15s
35    Z3_set_param_value(Config, "timeout", "15000");
36  }
37
38  ~Z3Config() { Z3_del_config(Config); }
39}; // end class Z3Config
40
41// Function used to report errors
42void Z3ErrorHandler(Z3_context Context, Z3_error_code Error) {
43  llvm::report_fatal_error("Z3 error: " +
44                           llvm::Twine(Z3_get_error_msg(Context, Error)));
45}
46
47/// Wrapper for Z3 context
48class Z3Context {
49public:
50  Z3_context Context;
51
52  Z3Context() {
53    Context = Z3_mk_context_rc(Z3Config().Config);
54    // The error function is set here because the context is the first object
55    // created by the backend
56    Z3_set_error_handler(Context, Z3ErrorHandler);
57  }
58
59  virtual ~Z3Context() {
60    Z3_del_context(Context);
61    Context = nullptr;
62  }
63}; // end class Z3Context
64
65/// Wrapper for Z3 Sort
66class Z3Sort : public SMTSort {
67  friend class Z3Solver;
68
69  Z3Context &Context;
70
71  Z3_sort Sort;
72
73public:
74  /// Default constructor, mainly used by make_shared
75  Z3Sort(Z3Context &C, Z3_sort ZS) : Context(C), Sort(ZS) {
76    Z3_inc_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
77  }
78
79  /// Override implicit copy constructor for correct reference counting.
80  Z3Sort(const Z3Sort &Other) : Context(Other.Context), Sort(Other.Sort) {
81    Z3_inc_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
82  }
83
84  /// Override implicit copy assignment constructor for correct reference
85  /// counting.
86  Z3Sort &operator=(const Z3Sort &Other) {
87    Z3_inc_ref(Context.Context, reinterpret_cast<Z3_ast>(Other.Sort));
88    Z3_dec_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
89    Sort = Other.Sort;
90    return *this;
91  }
92
93  Z3Sort(Z3Sort &&Other) = delete;
94  Z3Sort &operator=(Z3Sort &&Other) = delete;
95
96  ~Z3Sort() {
97    if (Sort)
98      Z3_dec_ref(Context.Context, reinterpret_cast<Z3_ast>(Sort));
99  }
100
101  void Profile(llvm::FoldingSetNodeID &ID) const override {
102    ID.AddInteger(
103        Z3_get_ast_id(Context.Context, reinterpret_cast<Z3_ast>(Sort)));
104  }
105
106  bool isBitvectorSortImpl() const override {
107    return (Z3_get_sort_kind(Context.Context, Sort) == Z3_BV_SORT);
108  }
109
110  bool isFloatSortImpl() const override {
111    return (Z3_get_sort_kind(Context.Context, Sort) == Z3_FLOATING_POINT_SORT);
112  }
113
114  bool isBooleanSortImpl() const override {
115    return (Z3_get_sort_kind(Context.Context, Sort) == Z3_BOOL_SORT);
116  }
117
118  unsigned getBitvectorSortSizeImpl() const override {
119    return Z3_get_bv_sort_size(Context.Context, Sort);
120  }
121
122  unsigned getFloatSortSizeImpl() const override {
123    return Z3_fpa_get_ebits(Context.Context, Sort) +
124           Z3_fpa_get_sbits(Context.Context, Sort);
125  }
126
127  bool equal_to(SMTSort const &Other) const override {
128    return Z3_is_eq_sort(Context.Context, Sort,
129                         static_cast<const Z3Sort &>(Other).Sort);
130  }
131
132  void print(raw_ostream &OS) const override {
133    OS << Z3_sort_to_string(Context.Context, Sort);
134  }
135}; // end class Z3Sort
136
137static const Z3Sort &toZ3Sort(const SMTSort &S) {
138  return static_cast<const Z3Sort &>(S);
139}
140
141class Z3Expr : public SMTExpr {
142  friend class Z3Solver;
143
144  Z3Context &Context;
145
146  Z3_ast AST;
147
148public:
149  Z3Expr(Z3Context &C, Z3_ast ZA) : SMTExpr(), Context(C), AST(ZA) {
150    Z3_inc_ref(Context.Context, AST);
151  }
152
153  /// Override implicit copy constructor for correct reference counting.
154  Z3Expr(const Z3Expr &Copy) : SMTExpr(), Context(Copy.Context), AST(Copy.AST) {
155    Z3_inc_ref(Context.Context, AST);
156  }
157
158  /// Override implicit copy assignment constructor for correct reference
159  /// counting.
160  Z3Expr &operator=(const Z3Expr &Other) {
161    Z3_inc_ref(Context.Context, Other.AST);
162    Z3_dec_ref(Context.Context, AST);
163    AST = Other.AST;
164    return *this;
165  }
166
167  Z3Expr(Z3Expr &&Other) = delete;
168  Z3Expr &operator=(Z3Expr &&Other) = delete;
169
170  ~Z3Expr() {
171    if (AST)
172      Z3_dec_ref(Context.Context, AST);
173  }
174
175  void Profile(llvm::FoldingSetNodeID &ID) const override {
176    ID.AddInteger(Z3_get_ast_id(Context.Context, AST));
177  }
178
179  /// Comparison of AST equality, not model equivalence.
180  bool equal_to(SMTExpr const &Other) const override {
181    assert(Z3_is_eq_sort(Context.Context, Z3_get_sort(Context.Context, AST),
182                         Z3_get_sort(Context.Context,
183                                     static_cast<const Z3Expr &>(Other).AST)) &&
184           "AST's must have the same sort");
185    return Z3_is_eq_ast(Context.Context, AST,
186                        static_cast<const Z3Expr &>(Other).AST);
187  }
188
189  void print(raw_ostream &OS) const override {
190    OS << Z3_ast_to_string(Context.Context, AST);
191  }
192}; // end class Z3Expr
193
194static const Z3Expr &toZ3Expr(const SMTExpr &E) {
195  return static_cast<const Z3Expr &>(E);
196}
197
198class Z3Model {
199  friend class Z3Solver;
200
201  Z3Context &Context;
202
203  Z3_model Model;
204
205public:
206  Z3Model(Z3Context &C, Z3_model ZM) : Context(C), Model(ZM) {
207    Z3_model_inc_ref(Context.Context, Model);
208  }
209
210  Z3Model(const Z3Model &Other) = delete;
211  Z3Model(Z3Model &&Other) = delete;
212  Z3Model &operator=(Z3Model &Other) = delete;
213  Z3Model &operator=(Z3Model &&Other) = delete;
214
215  ~Z3Model() {
216    if (Model)
217      Z3_model_dec_ref(Context.Context, Model);
218  }
219
220  void print(raw_ostream &OS) const {
221    OS << Z3_model_to_string(Context.Context, Model);
222  }
223
224  LLVM_DUMP_METHOD void dump() const { print(llvm::errs()); }
225}; // end class Z3Model
226
227/// Get the corresponding IEEE floating-point type for a given bitwidth.
228static const llvm::fltSemantics &getFloatSemantics(unsigned BitWidth) {
229  switch (BitWidth) {
230  default:
231    llvm_unreachable("Unsupported floating-point semantics!");
232    break;
233  case 16:
234    return llvm::APFloat::IEEEhalf();
235  case 32:
236    return llvm::APFloat::IEEEsingle();
237  case 64:
238    return llvm::APFloat::IEEEdouble();
239  case 128:
240    return llvm::APFloat::IEEEquad();
241  }
242}
243
244// Determine whether two float semantics are equivalent
245static bool areEquivalent(const llvm::fltSemantics &LHS,
246                          const llvm::fltSemantics &RHS) {
247  return (llvm::APFloat::semanticsPrecision(LHS) ==
248          llvm::APFloat::semanticsPrecision(RHS)) &&
249         (llvm::APFloat::semanticsMinExponent(LHS) ==
250          llvm::APFloat::semanticsMinExponent(RHS)) &&
251         (llvm::APFloat::semanticsMaxExponent(LHS) ==
252          llvm::APFloat::semanticsMaxExponent(RHS)) &&
253         (llvm::APFloat::semanticsSizeInBits(LHS) ==
254          llvm::APFloat::semanticsSizeInBits(RHS));
255}
256
257class Z3Solver : public SMTSolver {
258  friend class Z3ConstraintManager;
259
260  Z3Context Context;
261
262  Z3_solver Solver;
263
264  // Cache Sorts
265  std::set<Z3Sort> CachedSorts;
266
267  // Cache Exprs
268  std::set<Z3Expr> CachedExprs;
269
270public:
271  Z3Solver() : Solver(Z3_mk_simple_solver(Context.Context)) {
272    Z3_solver_inc_ref(Context.Context, Solver);
273  }
274
275  Z3Solver(const Z3Solver &Other) = delete;
276  Z3Solver(Z3Solver &&Other) = delete;
277  Z3Solver &operator=(Z3Solver &Other) = delete;
278  Z3Solver &operator=(Z3Solver &&Other) = delete;
279
280  ~Z3Solver() {
281    if (Solver)
282      Z3_solver_dec_ref(Context.Context, Solver);
283  }
284
285  void addConstraint(const SMTExprRef &Exp) const override {
286    Z3_solver_assert(Context.Context, Solver, toZ3Expr(*Exp).AST);
287  }
288
289  // Given an SMTSort, adds/retrives it from the cache and returns
290  // an SMTSortRef to the SMTSort in the cache
291  SMTSortRef newSortRef(const SMTSort &Sort) {
292    auto It = CachedSorts.insert(toZ3Sort(Sort));
293    return &(*It.first);
294  }
295
296  // Given an SMTExpr, adds/retrives it from the cache and returns
297  // an SMTExprRef to the SMTExpr in the cache
298  SMTExprRef newExprRef(const SMTExpr &Exp) {
299    auto It = CachedExprs.insert(toZ3Expr(Exp));
300    return &(*It.first);
301  }
302
303  SMTSortRef getBoolSort() override {
304    return newSortRef(Z3Sort(Context, Z3_mk_bool_sort(Context.Context)));
305  }
306
307  SMTSortRef getBitvectorSort(unsigned BitWidth) override {
308    return newSortRef(
309        Z3Sort(Context, Z3_mk_bv_sort(Context.Context, BitWidth)));
310  }
311
312  SMTSortRef getSort(const SMTExprRef &Exp) override {
313    return newSortRef(
314        Z3Sort(Context, Z3_get_sort(Context.Context, toZ3Expr(*Exp).AST)));
315  }
316
317  SMTSortRef getFloat16Sort() override {
318    return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_16(Context.Context)));
319  }
320
321  SMTSortRef getFloat32Sort() override {
322    return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_32(Context.Context)));
323  }
324
325  SMTSortRef getFloat64Sort() override {
326    return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_64(Context.Context)));
327  }
328
329  SMTSortRef getFloat128Sort() override {
330    return newSortRef(Z3Sort(Context, Z3_mk_fpa_sort_128(Context.Context)));
331  }
332
333  SMTExprRef mkBVNeg(const SMTExprRef &Exp) override {
334    return newExprRef(
335        Z3Expr(Context, Z3_mk_bvneg(Context.Context, toZ3Expr(*Exp).AST)));
336  }
337
338  SMTExprRef mkBVNot(const SMTExprRef &Exp) override {
339    return newExprRef(
340        Z3Expr(Context, Z3_mk_bvnot(Context.Context, toZ3Expr(*Exp).AST)));
341  }
342
343  SMTExprRef mkNot(const SMTExprRef &Exp) override {
344    return newExprRef(
345        Z3Expr(Context, Z3_mk_not(Context.Context, toZ3Expr(*Exp).AST)));
346  }
347
348  SMTExprRef mkBVAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
349    return newExprRef(
350        Z3Expr(Context, Z3_mk_bvadd(Context.Context, toZ3Expr(*LHS).AST,
351                                    toZ3Expr(*RHS).AST)));
352  }
353
354  SMTExprRef mkBVSub(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
355    return newExprRef(
356        Z3Expr(Context, Z3_mk_bvsub(Context.Context, toZ3Expr(*LHS).AST,
357                                    toZ3Expr(*RHS).AST)));
358  }
359
360  SMTExprRef mkBVMul(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
361    return newExprRef(
362        Z3Expr(Context, Z3_mk_bvmul(Context.Context, toZ3Expr(*LHS).AST,
363                                    toZ3Expr(*RHS).AST)));
364  }
365
366  SMTExprRef mkBVSRem(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
367    return newExprRef(
368        Z3Expr(Context, Z3_mk_bvsrem(Context.Context, toZ3Expr(*LHS).AST,
369                                     toZ3Expr(*RHS).AST)));
370  }
371
372  SMTExprRef mkBVURem(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
373    return newExprRef(
374        Z3Expr(Context, Z3_mk_bvurem(Context.Context, toZ3Expr(*LHS).AST,
375                                     toZ3Expr(*RHS).AST)));
376  }
377
378  SMTExprRef mkBVSDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
379    return newExprRef(
380        Z3Expr(Context, Z3_mk_bvsdiv(Context.Context, toZ3Expr(*LHS).AST,
381                                     toZ3Expr(*RHS).AST)));
382  }
383
384  SMTExprRef mkBVUDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
385    return newExprRef(
386        Z3Expr(Context, Z3_mk_bvudiv(Context.Context, toZ3Expr(*LHS).AST,
387                                     toZ3Expr(*RHS).AST)));
388  }
389
390  SMTExprRef mkBVShl(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
391    return newExprRef(
392        Z3Expr(Context, Z3_mk_bvshl(Context.Context, toZ3Expr(*LHS).AST,
393                                    toZ3Expr(*RHS).AST)));
394  }
395
396  SMTExprRef mkBVAshr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
397    return newExprRef(
398        Z3Expr(Context, Z3_mk_bvashr(Context.Context, toZ3Expr(*LHS).AST,
399                                     toZ3Expr(*RHS).AST)));
400  }
401
402  SMTExprRef mkBVLshr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
403    return newExprRef(
404        Z3Expr(Context, Z3_mk_bvlshr(Context.Context, toZ3Expr(*LHS).AST,
405                                     toZ3Expr(*RHS).AST)));
406  }
407
408  SMTExprRef mkBVXor(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
409    return newExprRef(
410        Z3Expr(Context, Z3_mk_bvxor(Context.Context, toZ3Expr(*LHS).AST,
411                                    toZ3Expr(*RHS).AST)));
412  }
413
414  SMTExprRef mkBVOr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
415    return newExprRef(
416        Z3Expr(Context, Z3_mk_bvor(Context.Context, toZ3Expr(*LHS).AST,
417                                   toZ3Expr(*RHS).AST)));
418  }
419
420  SMTExprRef mkBVAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
421    return newExprRef(
422        Z3Expr(Context, Z3_mk_bvand(Context.Context, toZ3Expr(*LHS).AST,
423                                    toZ3Expr(*RHS).AST)));
424  }
425
426  SMTExprRef mkBVUlt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
427    return newExprRef(
428        Z3Expr(Context, Z3_mk_bvult(Context.Context, toZ3Expr(*LHS).AST,
429                                    toZ3Expr(*RHS).AST)));
430  }
431
432  SMTExprRef mkBVSlt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
433    return newExprRef(
434        Z3Expr(Context, Z3_mk_bvslt(Context.Context, toZ3Expr(*LHS).AST,
435                                    toZ3Expr(*RHS).AST)));
436  }
437
438  SMTExprRef mkBVUgt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
439    return newExprRef(
440        Z3Expr(Context, Z3_mk_bvugt(Context.Context, toZ3Expr(*LHS).AST,
441                                    toZ3Expr(*RHS).AST)));
442  }
443
444  SMTExprRef mkBVSgt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
445    return newExprRef(
446        Z3Expr(Context, Z3_mk_bvsgt(Context.Context, toZ3Expr(*LHS).AST,
447                                    toZ3Expr(*RHS).AST)));
448  }
449
450  SMTExprRef mkBVUle(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
451    return newExprRef(
452        Z3Expr(Context, Z3_mk_bvule(Context.Context, toZ3Expr(*LHS).AST,
453                                    toZ3Expr(*RHS).AST)));
454  }
455
456  SMTExprRef mkBVSle(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
457    return newExprRef(
458        Z3Expr(Context, Z3_mk_bvsle(Context.Context, toZ3Expr(*LHS).AST,
459                                    toZ3Expr(*RHS).AST)));
460  }
461
462  SMTExprRef mkBVUge(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
463    return newExprRef(
464        Z3Expr(Context, Z3_mk_bvuge(Context.Context, toZ3Expr(*LHS).AST,
465                                    toZ3Expr(*RHS).AST)));
466  }
467
468  SMTExprRef mkBVSge(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
469    return newExprRef(
470        Z3Expr(Context, Z3_mk_bvsge(Context.Context, toZ3Expr(*LHS).AST,
471                                    toZ3Expr(*RHS).AST)));
472  }
473
474  SMTExprRef mkAnd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
475    Z3_ast Args[2] = {toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST};
476    return newExprRef(Z3Expr(Context, Z3_mk_and(Context.Context, 2, Args)));
477  }
478
479  SMTExprRef mkOr(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
480    Z3_ast Args[2] = {toZ3Expr(*LHS).AST, toZ3Expr(*RHS).AST};
481    return newExprRef(Z3Expr(Context, Z3_mk_or(Context.Context, 2, Args)));
482  }
483
484  SMTExprRef mkEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
485    return newExprRef(
486        Z3Expr(Context, Z3_mk_eq(Context.Context, toZ3Expr(*LHS).AST,
487                                 toZ3Expr(*RHS).AST)));
488  }
489
490  SMTExprRef mkFPNeg(const SMTExprRef &Exp) override {
491    return newExprRef(
492        Z3Expr(Context, Z3_mk_fpa_neg(Context.Context, toZ3Expr(*Exp).AST)));
493  }
494
495  SMTExprRef mkFPIsInfinite(const SMTExprRef &Exp) override {
496    return newExprRef(Z3Expr(
497        Context, Z3_mk_fpa_is_infinite(Context.Context, toZ3Expr(*Exp).AST)));
498  }
499
500  SMTExprRef mkFPIsNaN(const SMTExprRef &Exp) override {
501    return newExprRef(
502        Z3Expr(Context, Z3_mk_fpa_is_nan(Context.Context, toZ3Expr(*Exp).AST)));
503  }
504
505  SMTExprRef mkFPIsNormal(const SMTExprRef &Exp) override {
506    return newExprRef(Z3Expr(
507        Context, Z3_mk_fpa_is_normal(Context.Context, toZ3Expr(*Exp).AST)));
508  }
509
510  SMTExprRef mkFPIsZero(const SMTExprRef &Exp) override {
511    return newExprRef(Z3Expr(
512        Context, Z3_mk_fpa_is_zero(Context.Context, toZ3Expr(*Exp).AST)));
513  }
514
515  SMTExprRef mkFPMul(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
516    SMTExprRef RoundingMode = getFloatRoundingMode();
517    return newExprRef(
518        Z3Expr(Context,
519               Z3_mk_fpa_mul(Context.Context, toZ3Expr(*LHS).AST,
520                             toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST)));
521  }
522
523  SMTExprRef mkFPDiv(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
524    SMTExprRef RoundingMode = getFloatRoundingMode();
525    return newExprRef(
526        Z3Expr(Context,
527               Z3_mk_fpa_div(Context.Context, toZ3Expr(*LHS).AST,
528                             toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST)));
529  }
530
531  SMTExprRef mkFPRem(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
532    return newExprRef(
533        Z3Expr(Context, Z3_mk_fpa_rem(Context.Context, toZ3Expr(*LHS).AST,
534                                      toZ3Expr(*RHS).AST)));
535  }
536
537  SMTExprRef mkFPAdd(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
538    SMTExprRef RoundingMode = getFloatRoundingMode();
539    return newExprRef(
540        Z3Expr(Context,
541               Z3_mk_fpa_add(Context.Context, toZ3Expr(*LHS).AST,
542                             toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST)));
543  }
544
545  SMTExprRef mkFPSub(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
546    SMTExprRef RoundingMode = getFloatRoundingMode();
547    return newExprRef(
548        Z3Expr(Context,
549               Z3_mk_fpa_sub(Context.Context, toZ3Expr(*LHS).AST,
550                             toZ3Expr(*RHS).AST, toZ3Expr(*RoundingMode).AST)));
551  }
552
553  SMTExprRef mkFPLt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
554    return newExprRef(
555        Z3Expr(Context, Z3_mk_fpa_lt(Context.Context, toZ3Expr(*LHS).AST,
556                                     toZ3Expr(*RHS).AST)));
557  }
558
559  SMTExprRef mkFPGt(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
560    return newExprRef(
561        Z3Expr(Context, Z3_mk_fpa_gt(Context.Context, toZ3Expr(*LHS).AST,
562                                     toZ3Expr(*RHS).AST)));
563  }
564
565  SMTExprRef mkFPLe(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
566    return newExprRef(
567        Z3Expr(Context, Z3_mk_fpa_leq(Context.Context, toZ3Expr(*LHS).AST,
568                                      toZ3Expr(*RHS).AST)));
569  }
570
571  SMTExprRef mkFPGe(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
572    return newExprRef(
573        Z3Expr(Context, Z3_mk_fpa_geq(Context.Context, toZ3Expr(*LHS).AST,
574                                      toZ3Expr(*RHS).AST)));
575  }
576
577  SMTExprRef mkFPEqual(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
578    return newExprRef(
579        Z3Expr(Context, Z3_mk_fpa_eq(Context.Context, toZ3Expr(*LHS).AST,
580                                     toZ3Expr(*RHS).AST)));
581  }
582
583  SMTExprRef mkIte(const SMTExprRef &Cond, const SMTExprRef &T,
584                   const SMTExprRef &F) override {
585    return newExprRef(
586        Z3Expr(Context, Z3_mk_ite(Context.Context, toZ3Expr(*Cond).AST,
587                                  toZ3Expr(*T).AST, toZ3Expr(*F).AST)));
588  }
589
590  SMTExprRef mkBVSignExt(unsigned i, const SMTExprRef &Exp) override {
591    return newExprRef(Z3Expr(
592        Context, Z3_mk_sign_ext(Context.Context, i, toZ3Expr(*Exp).AST)));
593  }
594
595  SMTExprRef mkBVZeroExt(unsigned i, const SMTExprRef &Exp) override {
596    return newExprRef(Z3Expr(
597        Context, Z3_mk_zero_ext(Context.Context, i, toZ3Expr(*Exp).AST)));
598  }
599
600  SMTExprRef mkBVExtract(unsigned High, unsigned Low,
601                         const SMTExprRef &Exp) override {
602    return newExprRef(Z3Expr(Context, Z3_mk_extract(Context.Context, High, Low,
603                                                    toZ3Expr(*Exp).AST)));
604  }
605
606  /// Creates a predicate that checks for overflow in a bitvector addition
607  /// operation
608  SMTExprRef mkBVAddNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS,
609                               bool isSigned) override {
610    return newExprRef(Z3Expr(
611        Context, Z3_mk_bvadd_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
612                                         toZ3Expr(*RHS).AST, isSigned)));
613  }
614
615  /// Creates a predicate that checks for underflow in a signed bitvector
616  /// addition operation
617  SMTExprRef mkBVAddNoUnderflow(const SMTExprRef &LHS,
618                                const SMTExprRef &RHS) override {
619    return newExprRef(Z3Expr(
620        Context, Z3_mk_bvadd_no_underflow(Context.Context, toZ3Expr(*LHS).AST,
621                                          toZ3Expr(*RHS).AST)));
622  }
623
624  /// Creates a predicate that checks for overflow in a signed bitvector
625  /// subtraction operation
626  SMTExprRef mkBVSubNoOverflow(const SMTExprRef &LHS,
627                               const SMTExprRef &RHS) override {
628    return newExprRef(Z3Expr(
629        Context, Z3_mk_bvsub_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
630                                         toZ3Expr(*RHS).AST)));
631  }
632
633  /// Creates a predicate that checks for underflow in a bitvector subtraction
634  /// operation
635  SMTExprRef mkBVSubNoUnderflow(const SMTExprRef &LHS, const SMTExprRef &RHS,
636                                bool isSigned) override {
637    return newExprRef(Z3Expr(
638        Context, Z3_mk_bvsub_no_underflow(Context.Context, toZ3Expr(*LHS).AST,
639                                          toZ3Expr(*RHS).AST, isSigned)));
640  }
641
642  /// Creates a predicate that checks for overflow in a signed bitvector
643  /// division/modulus operation
644  SMTExprRef mkBVSDivNoOverflow(const SMTExprRef &LHS,
645                                const SMTExprRef &RHS) override {
646    return newExprRef(Z3Expr(
647        Context, Z3_mk_bvsdiv_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
648                                          toZ3Expr(*RHS).AST)));
649  }
650
651  /// Creates a predicate that checks for overflow in a bitvector negation
652  /// operation
653  SMTExprRef mkBVNegNoOverflow(const SMTExprRef &Exp) override {
654    return newExprRef(Z3Expr(
655        Context, Z3_mk_bvneg_no_overflow(Context.Context, toZ3Expr(*Exp).AST)));
656  }
657
658  /// Creates a predicate that checks for overflow in a bitvector multiplication
659  /// operation
660  SMTExprRef mkBVMulNoOverflow(const SMTExprRef &LHS, const SMTExprRef &RHS,
661                               bool isSigned) override {
662    return newExprRef(Z3Expr(
663        Context, Z3_mk_bvmul_no_overflow(Context.Context, toZ3Expr(*LHS).AST,
664                                         toZ3Expr(*RHS).AST, isSigned)));
665  }
666
667  /// Creates a predicate that checks for underflow in a signed bitvector
668  /// multiplication operation
669  SMTExprRef mkBVMulNoUnderflow(const SMTExprRef &LHS,
670                                const SMTExprRef &RHS) override {
671    return newExprRef(Z3Expr(
672        Context, Z3_mk_bvmul_no_underflow(Context.Context, toZ3Expr(*LHS).AST,
673                                          toZ3Expr(*RHS).AST)));
674  }
675
676  SMTExprRef mkBVConcat(const SMTExprRef &LHS, const SMTExprRef &RHS) override {
677    return newExprRef(
678        Z3Expr(Context, Z3_mk_concat(Context.Context, toZ3Expr(*LHS).AST,
679                                     toZ3Expr(*RHS).AST)));
680  }
681
682  SMTExprRef mkFPtoFP(const SMTExprRef &From, const SMTSortRef &To) override {
683    SMTExprRef RoundingMode = getFloatRoundingMode();
684    return newExprRef(Z3Expr(
685        Context,
686        Z3_mk_fpa_to_fp_float(Context.Context, toZ3Expr(*RoundingMode).AST,
687                              toZ3Expr(*From).AST, toZ3Sort(*To).Sort)));
688  }
689
690  SMTExprRef mkSBVtoFP(const SMTExprRef &From, const SMTSortRef &To) override {
691    SMTExprRef RoundingMode = getFloatRoundingMode();
692    return newExprRef(Z3Expr(
693        Context,
694        Z3_mk_fpa_to_fp_signed(Context.Context, toZ3Expr(*RoundingMode).AST,
695                               toZ3Expr(*From).AST, toZ3Sort(*To).Sort)));
696  }
697
698  SMTExprRef mkUBVtoFP(const SMTExprRef &From, const SMTSortRef &To) override {
699    SMTExprRef RoundingMode = getFloatRoundingMode();
700    return newExprRef(Z3Expr(
701        Context,
702        Z3_mk_fpa_to_fp_unsigned(Context.Context, toZ3Expr(*RoundingMode).AST,
703                                 toZ3Expr(*From).AST, toZ3Sort(*To).Sort)));
704  }
705
706  SMTExprRef mkFPtoSBV(const SMTExprRef &From, unsigned ToWidth) override {
707    SMTExprRef RoundingMode = getFloatRoundingMode();
708    return newExprRef(Z3Expr(
709        Context, Z3_mk_fpa_to_sbv(Context.Context, toZ3Expr(*RoundingMode).AST,
710                                  toZ3Expr(*From).AST, ToWidth)));
711  }
712
713  SMTExprRef mkFPtoUBV(const SMTExprRef &From, unsigned ToWidth) override {
714    SMTExprRef RoundingMode = getFloatRoundingMode();
715    return newExprRef(Z3Expr(
716        Context, Z3_mk_fpa_to_ubv(Context.Context, toZ3Expr(*RoundingMode).AST,
717                                  toZ3Expr(*From).AST, ToWidth)));
718  }
719
720  SMTExprRef mkBoolean(const bool b) override {
721    return newExprRef(Z3Expr(Context, b ? Z3_mk_true(Context.Context)
722                                        : Z3_mk_false(Context.Context)));
723  }
724
725  SMTExprRef mkBitvector(const llvm::APSInt Int, unsigned BitWidth) override {
726    const SMTSortRef Sort = getBitvectorSort(BitWidth);
727    return newExprRef(
728        Z3Expr(Context, Z3_mk_numeral(Context.Context, Int.toString(10).c_str(),
729                                      toZ3Sort(*Sort).Sort)));
730  }
731
732  SMTExprRef mkFloat(const llvm::APFloat Float) override {
733    SMTSortRef Sort =
734        getFloatSort(llvm::APFloat::semanticsSizeInBits(Float.getSemantics()));
735
736    llvm::APSInt Int = llvm::APSInt(Float.bitcastToAPInt(), false);
737    SMTExprRef Z3Int = mkBitvector(Int, Int.getBitWidth());
738    return newExprRef(Z3Expr(
739        Context, Z3_mk_fpa_to_fp_bv(Context.Context, toZ3Expr(*Z3Int).AST,
740                                    toZ3Sort(*Sort).Sort)));
741  }
742
743  SMTExprRef mkSymbol(const char *Name, SMTSortRef Sort) override {
744    return newExprRef(
745        Z3Expr(Context, Z3_mk_const(Context.Context,
746                                    Z3_mk_string_symbol(Context.Context, Name),
747                                    toZ3Sort(*Sort).Sort)));
748  }
749
750  llvm::APSInt getBitvector(const SMTExprRef &Exp, unsigned BitWidth,
751                            bool isUnsigned) override {
752    return llvm::APSInt(
753        llvm::APInt(BitWidth,
754                    Z3_get_numeral_string(Context.Context, toZ3Expr(*Exp).AST),
755                    10),
756        isUnsigned);
757  }
758
759  bool getBoolean(const SMTExprRef &Exp) override {
760    return Z3_get_bool_value(Context.Context, toZ3Expr(*Exp).AST) == Z3_L_TRUE;
761  }
762
763  SMTExprRef getFloatRoundingMode() override {
764    // TODO: Don't assume nearest ties to even rounding mode
765    return newExprRef(Z3Expr(Context, Z3_mk_fpa_rne(Context.Context)));
766  }
767
768  bool toAPFloat(const SMTSortRef &Sort, const SMTExprRef &AST,
769                 llvm::APFloat &Float, bool useSemantics) {
770    assert(Sort->isFloatSort() && "Unsupported sort to floating-point!");
771
772    llvm::APSInt Int(Sort->getFloatSortSize(), true);
773    const llvm::fltSemantics &Semantics =
774        getFloatSemantics(Sort->getFloatSortSize());
775    SMTSortRef BVSort = getBitvectorSort(Sort->getFloatSortSize());
776    if (!toAPSInt(BVSort, AST, Int, true)) {
777      return false;
778    }
779
780    if (useSemantics && !areEquivalent(Float.getSemantics(), Semantics)) {
781      assert(false && "Floating-point types don't match!");
782      return false;
783    }
784
785    Float = llvm::APFloat(Semantics, Int);
786    return true;
787  }
788
789  bool toAPSInt(const SMTSortRef &Sort, const SMTExprRef &AST,
790                llvm::APSInt &Int, bool useSemantics) {
791    if (Sort->isBitvectorSort()) {
792      if (useSemantics && Int.getBitWidth() != Sort->getBitvectorSortSize()) {
793        assert(false && "Bitvector types don't match!");
794        return false;
795      }
796
797      // FIXME: This function is also used to retrieve floating-point values,
798      // which can be 16, 32, 64 or 128 bits long. Bitvectors can be anything
799      // between 1 and 64 bits long, which is the reason we have this weird
800      // guard. In the future, we need proper calls in the backend to retrieve
801      // floating-points and its special values (NaN, +/-infinity, +/-zero),
802      // then we can drop this weird condition.
803      if (Sort->getBitvectorSortSize() <= 64 ||
804          Sort->getBitvectorSortSize() == 128) {
805        Int = getBitvector(AST, Int.getBitWidth(), Int.isUnsigned());
806        return true;
807      }
808
809      assert(false && "Bitwidth not supported!");
810      return false;
811    }
812
813    if (Sort->isBooleanSort()) {
814      if (useSemantics && Int.getBitWidth() < 1) {
815        assert(false && "Boolean type doesn't match!");
816        return false;
817      }
818
819      Int = llvm::APSInt(llvm::APInt(Int.getBitWidth(), getBoolean(AST)),
820                         Int.isUnsigned());
821      return true;
822    }
823
824    llvm_unreachable("Unsupported sort to integer!");
825  }
826
827  bool getInterpretation(const SMTExprRef &Exp, llvm::APSInt &Int) override {
828    Z3Model Model(Context, Z3_solver_get_model(Context.Context, Solver));
829    Z3_func_decl Func = Z3_get_app_decl(
830        Context.Context, Z3_to_app(Context.Context, toZ3Expr(*Exp).AST));
831    if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE)
832      return false;
833
834    SMTExprRef Assign = newExprRef(
835        Z3Expr(Context,
836               Z3_model_get_const_interp(Context.Context, Model.Model, Func)));
837    SMTSortRef Sort = getSort(Assign);
838    return toAPSInt(Sort, Assign, Int, true);
839  }
840
841  bool getInterpretation(const SMTExprRef &Exp, llvm::APFloat &Float) override {
842    Z3Model Model(Context, Z3_solver_get_model(Context.Context, Solver));
843    Z3_func_decl Func = Z3_get_app_decl(
844        Context.Context, Z3_to_app(Context.Context, toZ3Expr(*Exp).AST));
845    if (Z3_model_has_interp(Context.Context, Model.Model, Func) != Z3_L_TRUE)
846      return false;
847
848    SMTExprRef Assign = newExprRef(
849        Z3Expr(Context,
850               Z3_model_get_const_interp(Context.Context, Model.Model, Func)));
851    SMTSortRef Sort = getSort(Assign);
852    return toAPFloat(Sort, Assign, Float, true);
853  }
854
855  Optional<bool> check() const override {
856    Z3_lbool res = Z3_solver_check(Context.Context, Solver);
857    if (res == Z3_L_TRUE)
858      return true;
859
860    if (res == Z3_L_FALSE)
861      return false;
862
863    return Optional<bool>();
864  }
865
866  void push() override { return Z3_solver_push(Context.Context, Solver); }
867
868  void pop(unsigned NumStates = 1) override {
869    assert(Z3_solver_get_num_scopes(Context.Context, Solver) >= NumStates);
870    return Z3_solver_pop(Context.Context, Solver, NumStates);
871  }
872
873  bool isFPSupported() override { return true; }
874
875  /// Reset the solver and remove all constraints.
876  void reset() override { Z3_solver_reset(Context.Context, Solver); }
877
878  void print(raw_ostream &OS) const override {
879    OS << Z3_solver_to_string(Context.Context, Solver);
880  }
881}; // end class Z3Solver
882
883} // end anonymous namespace
884
885#endif
886
887llvm::SMTSolverRef llvm::CreateZ3Solver() {
888#if LLVM_WITH_Z3
889  return std::make_unique<Z3Solver>();
890#else
891  llvm::report_fatal_error("LLVM was not compiled with Z3 support, rebuild "
892                           "with -DLLVM_ENABLE_Z3_SOLVER=ON",
893                           false);
894  return nullptr;
895#endif
896}
897
898LLVM_DUMP_METHOD void SMTSort::dump() const { print(llvm::errs()); }
899LLVM_DUMP_METHOD void SMTExpr::dump() const { print(llvm::errs()); }
900LLVM_DUMP_METHOD void SMTSolver::dump() const { print(llvm::errs()); }
901