1223013Sdim//===- SetTheory.cpp - Generate ordered sets from DAG expressions ---------===//
2223013Sdim//
3223013Sdim//                     The LLVM Compiler Infrastructure
4223013Sdim//
5223013Sdim// This file is distributed under the University of Illinois Open Source
6223013Sdim// License. See LICENSE.TXT for details.
7223013Sdim//
8223013Sdim//===----------------------------------------------------------------------===//
9223013Sdim//
10223013Sdim// This file implements the SetTheory class that computes ordered sets of
11223013Sdim// Records from DAG expressions.
12223013Sdim//
13223013Sdim//===----------------------------------------------------------------------===//
14223013Sdim
15223013Sdim#include "SetTheory.h"
16249423Sdim#include "llvm/Support/Format.h"
17226633Sdim#include "llvm/TableGen/Error.h"
18226633Sdim#include "llvm/TableGen/Record.h"
19223013Sdim
20223013Sdimusing namespace llvm;
21223013Sdim
22223013Sdim// Define the standard operators.
23223013Sdimnamespace {
24223013Sdim
25223013Sdimtypedef SetTheory::RecSet RecSet;
26223013Sdimtypedef SetTheory::RecVec RecVec;
27223013Sdim
28223013Sdim// (add a, b, ...) Evaluate and union all arguments.
29223013Sdimstruct AddOp : public SetTheory::Operator {
30263508Sdim  virtual void apply(SetTheory &ST, DagInit *Expr, RecSet &Elts,
31263508Sdim                     ArrayRef<SMLoc> Loc) {
32243830Sdim    ST.evaluate(Expr->arg_begin(), Expr->arg_end(), Elts, Loc);
33223013Sdim  }
34223013Sdim};
35223013Sdim
36223013Sdim// (sub Add, Sub, ...) Set difference.
37223013Sdimstruct SubOp : public SetTheory::Operator {
38263508Sdim  virtual void apply(SetTheory &ST, DagInit *Expr, RecSet &Elts,
39263508Sdim                     ArrayRef<SMLoc> Loc) {
40223013Sdim    if (Expr->arg_size() < 2)
41243830Sdim      PrintFatalError(Loc, "Set difference needs at least two arguments: " +
42243830Sdim        Expr->getAsString());
43223013Sdim    RecSet Add, Sub;
44243830Sdim    ST.evaluate(*Expr->arg_begin(), Add, Loc);
45243830Sdim    ST.evaluate(Expr->arg_begin() + 1, Expr->arg_end(), Sub, Loc);
46223013Sdim    for (RecSet::iterator I = Add.begin(), E = Add.end(); I != E; ++I)
47223013Sdim      if (!Sub.count(*I))
48223013Sdim        Elts.insert(*I);
49223013Sdim  }
50223013Sdim};
51223013Sdim
52223013Sdim// (and S1, S2) Set intersection.
53223013Sdimstruct AndOp : public SetTheory::Operator {
54263508Sdim  virtual void apply(SetTheory &ST, DagInit *Expr, RecSet &Elts,
55263508Sdim                     ArrayRef<SMLoc> Loc) {
56223013Sdim    if (Expr->arg_size() != 2)
57243830Sdim      PrintFatalError(Loc, "Set intersection requires two arguments: " +
58243830Sdim        Expr->getAsString());
59223013Sdim    RecSet S1, S2;
60243830Sdim    ST.evaluate(Expr->arg_begin()[0], S1, Loc);
61243830Sdim    ST.evaluate(Expr->arg_begin()[1], S2, Loc);
62223013Sdim    for (RecSet::iterator I = S1.begin(), E = S1.end(); I != E; ++I)
63223013Sdim      if (S2.count(*I))
64223013Sdim        Elts.insert(*I);
65223013Sdim  }
66223013Sdim};
67223013Sdim
68223013Sdim// SetIntBinOp - Abstract base class for (Op S, N) operators.
69223013Sdimstruct SetIntBinOp : public SetTheory::Operator {
70223013Sdim  virtual void apply2(SetTheory &ST, DagInit *Expr,
71223013Sdim                     RecSet &Set, int64_t N,
72243830Sdim                     RecSet &Elts, ArrayRef<SMLoc> Loc) =0;
73223013Sdim
74263508Sdim  virtual void apply(SetTheory &ST, DagInit *Expr, RecSet &Elts,
75263508Sdim                     ArrayRef<SMLoc> Loc) {
76223013Sdim    if (Expr->arg_size() != 2)
77243830Sdim      PrintFatalError(Loc, "Operator requires (Op Set, Int) arguments: " +
78243830Sdim        Expr->getAsString());
79223013Sdim    RecSet Set;
80243830Sdim    ST.evaluate(Expr->arg_begin()[0], Set, Loc);
81243830Sdim    IntInit *II = dyn_cast<IntInit>(Expr->arg_begin()[1]);
82223013Sdim    if (!II)
83243830Sdim      PrintFatalError(Loc, "Second argument must be an integer: " +
84243830Sdim        Expr->getAsString());
85243830Sdim    apply2(ST, Expr, Set, II->getValue(), Elts, Loc);
86223013Sdim  }
87223013Sdim};
88223013Sdim
89223013Sdim// (shl S, N) Shift left, remove the first N elements.
90223013Sdimstruct ShlOp : public SetIntBinOp {
91263508Sdim  virtual void apply2(SetTheory &ST, DagInit *Expr,
92263508Sdim                      RecSet &Set, int64_t N,
93263508Sdim                      RecSet &Elts, ArrayRef<SMLoc> Loc) {
94223013Sdim    if (N < 0)
95243830Sdim      PrintFatalError(Loc, "Positive shift required: " +
96243830Sdim        Expr->getAsString());
97223013Sdim    if (unsigned(N) < Set.size())
98223013Sdim      Elts.insert(Set.begin() + N, Set.end());
99223013Sdim  }
100223013Sdim};
101223013Sdim
102223013Sdim// (trunc S, N) Truncate after the first N elements.
103223013Sdimstruct TruncOp : public SetIntBinOp {
104263508Sdim  virtual void apply2(SetTheory &ST, DagInit *Expr,
105263508Sdim                      RecSet &Set, int64_t N,
106263508Sdim                      RecSet &Elts, ArrayRef<SMLoc> Loc) {
107223013Sdim    if (N < 0)
108243830Sdim      PrintFatalError(Loc, "Positive length required: " +
109243830Sdim        Expr->getAsString());
110223013Sdim    if (unsigned(N) > Set.size())
111223013Sdim      N = Set.size();
112223013Sdim    Elts.insert(Set.begin(), Set.begin() + N);
113223013Sdim  }
114223013Sdim};
115223013Sdim
116223013Sdim// Left/right rotation.
117223013Sdimstruct RotOp : public SetIntBinOp {
118223013Sdim  const bool Reverse;
119223013Sdim
120223013Sdim  RotOp(bool Rev) : Reverse(Rev) {}
121223013Sdim
122263508Sdim  virtual void apply2(SetTheory &ST, DagInit *Expr,
123263508Sdim                      RecSet &Set, int64_t N,
124263508Sdim                      RecSet &Elts, ArrayRef<SMLoc> Loc) {
125223013Sdim    if (Reverse)
126223013Sdim      N = -N;
127223013Sdim    // N > 0 -> rotate left, N < 0 -> rotate right.
128223013Sdim    if (Set.empty())
129223013Sdim      return;
130223013Sdim    if (N < 0)
131223013Sdim      N = Set.size() - (-N % Set.size());
132223013Sdim    else
133223013Sdim      N %= Set.size();
134223013Sdim    Elts.insert(Set.begin() + N, Set.end());
135223013Sdim    Elts.insert(Set.begin(), Set.begin() + N);
136223013Sdim  }
137223013Sdim};
138223013Sdim
139223013Sdim// (decimate S, N) Pick every N'th element of S.
140223013Sdimstruct DecimateOp : public SetIntBinOp {
141263508Sdim  virtual void apply2(SetTheory &ST, DagInit *Expr,
142263508Sdim                      RecSet &Set, int64_t N,
143263508Sdim                      RecSet &Elts, ArrayRef<SMLoc> Loc) {
144223013Sdim    if (N <= 0)
145243830Sdim      PrintFatalError(Loc, "Positive stride required: " +
146243830Sdim        Expr->getAsString());
147223013Sdim    for (unsigned I = 0; I < Set.size(); I += N)
148223013Sdim      Elts.insert(Set[I]);
149223013Sdim  }
150223013Sdim};
151223013Sdim
152234353Sdim// (interleave S1, S2, ...) Interleave elements of the arguments.
153234353Sdimstruct InterleaveOp : public SetTheory::Operator {
154263508Sdim  virtual void apply(SetTheory &ST, DagInit *Expr, RecSet &Elts,
155263508Sdim                     ArrayRef<SMLoc> Loc) {
156234353Sdim    // Evaluate the arguments individually.
157234353Sdim    SmallVector<RecSet, 4> Args(Expr->getNumArgs());
158234353Sdim    unsigned MaxSize = 0;
159234353Sdim    for (unsigned i = 0, e = Expr->getNumArgs(); i != e; ++i) {
160243830Sdim      ST.evaluate(Expr->getArg(i), Args[i], Loc);
161234353Sdim      MaxSize = std::max(MaxSize, unsigned(Args[i].size()));
162234353Sdim    }
163234353Sdim    // Interleave arguments into Elts.
164234353Sdim    for (unsigned n = 0; n != MaxSize; ++n)
165234353Sdim      for (unsigned i = 0, e = Expr->getNumArgs(); i != e; ++i)
166234353Sdim        if (n < Args[i].size())
167234353Sdim          Elts.insert(Args[i][n]);
168234353Sdim  }
169234353Sdim};
170234353Sdim
171223013Sdim// (sequence "Format", From, To) Generate a sequence of records by name.
172223013Sdimstruct SequenceOp : public SetTheory::Operator {
173263508Sdim  virtual void apply(SetTheory &ST, DagInit *Expr, RecSet &Elts,
174263508Sdim                     ArrayRef<SMLoc> Loc) {
175239462Sdim    int Step = 1;
176239462Sdim    if (Expr->arg_size() > 4)
177243830Sdim      PrintFatalError(Loc, "Bad args to (sequence \"Format\", From, To): " +
178243830Sdim        Expr->getAsString());
179239462Sdim    else if (Expr->arg_size() == 4) {
180243830Sdim      if (IntInit *II = dyn_cast<IntInit>(Expr->arg_begin()[3])) {
181239462Sdim        Step = II->getValue();
182239462Sdim      } else
183243830Sdim        PrintFatalError(Loc, "Stride must be an integer: " +
184243830Sdim          Expr->getAsString());
185239462Sdim    }
186239462Sdim
187223013Sdim    std::string Format;
188243830Sdim    if (StringInit *SI = dyn_cast<StringInit>(Expr->arg_begin()[0]))
189223013Sdim      Format = SI->getValue();
190223013Sdim    else
191243830Sdim      PrintFatalError(Loc,  "Format must be a string: " + Expr->getAsString());
192223013Sdim
193223013Sdim    int64_t From, To;
194243830Sdim    if (IntInit *II = dyn_cast<IntInit>(Expr->arg_begin()[1]))
195223013Sdim      From = II->getValue();
196223013Sdim    else
197243830Sdim      PrintFatalError(Loc, "From must be an integer: " + Expr->getAsString());
198224145Sdim    if (From < 0 || From >= (1 << 30))
199243830Sdim      PrintFatalError(Loc, "From out of range");
200224145Sdim
201243830Sdim    if (IntInit *II = dyn_cast<IntInit>(Expr->arg_begin()[2]))
202223013Sdim      To = II->getValue();
203223013Sdim    else
204243830Sdim      PrintFatalError(Loc, "From must be an integer: " + Expr->getAsString());
205224145Sdim    if (To < 0 || To >= (1 << 30))
206243830Sdim      PrintFatalError(Loc, "To out of range");
207223013Sdim
208223013Sdim    RecordKeeper &Records =
209243830Sdim      cast<DefInit>(Expr->getOperator())->getDef()->getRecords();
210223013Sdim
211239462Sdim    Step *= From <= To ? 1 : -1;
212239462Sdim    while (true) {
213239462Sdim      if (Step > 0 && From > To)
214239462Sdim        break;
215239462Sdim      else if (Step < 0 && From < To)
216239462Sdim        break;
217223013Sdim      std::string Name;
218223013Sdim      raw_string_ostream OS(Name);
219224145Sdim      OS << format(Format.c_str(), unsigned(From));
220223013Sdim      Record *Rec = Records.getDef(OS.str());
221223013Sdim      if (!Rec)
222243830Sdim        PrintFatalError(Loc, "No def named '" + Name + "': " +
223243830Sdim          Expr->getAsString());
224223013Sdim      // Try to reevaluate Rec in case it is a set.
225223013Sdim      if (const RecVec *Result = ST.expand(Rec))
226223013Sdim        Elts.insert(Result->begin(), Result->end());
227223013Sdim      else
228223013Sdim        Elts.insert(Rec);
229239462Sdim
230239462Sdim      From += Step;
231223013Sdim    }
232223013Sdim  }
233223013Sdim};
234223013Sdim
235223013Sdim// Expand a Def into a set by evaluating one of its fields.
236223013Sdimstruct FieldExpander : public SetTheory::Expander {
237223013Sdim  StringRef FieldName;
238223013Sdim
239223013Sdim  FieldExpander(StringRef fn) : FieldName(fn) {}
240223013Sdim
241263508Sdim  virtual void expand(SetTheory &ST, Record *Def, RecSet &Elts) {
242243830Sdim    ST.evaluate(Def->getValueInit(FieldName), Elts, Def->getLoc());
243223013Sdim  }
244223013Sdim};
245223013Sdim} // end anonymous namespace
246223013Sdim
247263508Sdim// Pin the vtables to this file.
248263508Sdimvoid SetTheory::Operator::anchor() {}
249263508Sdimvoid SetTheory::Expander::anchor() {}
250234353Sdim
251234353Sdim
252223013SdimSetTheory::SetTheory() {
253223013Sdim  addOperator("add", new AddOp);
254223013Sdim  addOperator("sub", new SubOp);
255223013Sdim  addOperator("and", new AndOp);
256223013Sdim  addOperator("shl", new ShlOp);
257223013Sdim  addOperator("trunc", new TruncOp);
258223013Sdim  addOperator("rotl", new RotOp(false));
259223013Sdim  addOperator("rotr", new RotOp(true));
260223013Sdim  addOperator("decimate", new DecimateOp);
261234353Sdim  addOperator("interleave", new InterleaveOp);
262223013Sdim  addOperator("sequence", new SequenceOp);
263223013Sdim}
264223013Sdim
265223013Sdimvoid SetTheory::addOperator(StringRef Name, Operator *Op) {
266223013Sdim  Operators[Name] = Op;
267223013Sdim}
268223013Sdim
269223013Sdimvoid SetTheory::addExpander(StringRef ClassName, Expander *E) {
270223013Sdim  Expanders[ClassName] = E;
271223013Sdim}
272223013Sdim
273223013Sdimvoid SetTheory::addFieldExpander(StringRef ClassName, StringRef FieldName) {
274223013Sdim  addExpander(ClassName, new FieldExpander(FieldName));
275223013Sdim}
276223013Sdim
277243830Sdimvoid SetTheory::evaluate(Init *Expr, RecSet &Elts, ArrayRef<SMLoc> Loc) {
278223013Sdim  // A def in a list can be a just an element, or it may expand.
279243830Sdim  if (DefInit *Def = dyn_cast<DefInit>(Expr)) {
280223013Sdim    if (const RecVec *Result = expand(Def->getDef()))
281223013Sdim      return Elts.insert(Result->begin(), Result->end());
282223013Sdim    Elts.insert(Def->getDef());
283223013Sdim    return;
284223013Sdim  }
285223013Sdim
286223013Sdim  // Lists simply expand.
287243830Sdim  if (ListInit *LI = dyn_cast<ListInit>(Expr))
288243830Sdim    return evaluate(LI->begin(), LI->end(), Elts, Loc);
289223013Sdim
290223013Sdim  // Anything else must be a DAG.
291243830Sdim  DagInit *DagExpr = dyn_cast<DagInit>(Expr);
292223013Sdim  if (!DagExpr)
293243830Sdim    PrintFatalError(Loc, "Invalid set element: " + Expr->getAsString());
294243830Sdim  DefInit *OpInit = dyn_cast<DefInit>(DagExpr->getOperator());
295223013Sdim  if (!OpInit)
296243830Sdim    PrintFatalError(Loc, "Bad set expression: " + Expr->getAsString());
297223013Sdim  Operator *Op = Operators.lookup(OpInit->getDef()->getName());
298223013Sdim  if (!Op)
299243830Sdim    PrintFatalError(Loc, "Unknown set operator: " + Expr->getAsString());
300243830Sdim  Op->apply(*this, DagExpr, Elts, Loc);
301223013Sdim}
302223013Sdim
303223013Sdimconst RecVec *SetTheory::expand(Record *Set) {
304223013Sdim  // Check existing entries for Set and return early.
305223013Sdim  ExpandMap::iterator I = Expansions.find(Set);
306223013Sdim  if (I != Expansions.end())
307223013Sdim    return &I->second;
308223013Sdim
309223013Sdim  // This is the first time we see Set. Find a suitable expander.
310243830Sdim  const std::vector<Record*> &SC = Set->getSuperClasses();
311243830Sdim  for (unsigned i = 0, e = SC.size(); i != e; ++i) {
312243830Sdim    // Skip unnamed superclasses.
313243830Sdim    if (!dyn_cast<StringInit>(SC[i]->getNameInit()))
314243830Sdim      continue;
315243830Sdim    if (Expander *Exp = Expanders.lookup(SC[i]->getName())) {
316243830Sdim      // This breaks recursive definitions.
317243830Sdim      RecVec &EltVec = Expansions[Set];
318243830Sdim      RecSet Elts;
319243830Sdim      Exp->expand(*this, Set, Elts);
320243830Sdim      EltVec.assign(Elts.begin(), Elts.end());
321243830Sdim      return &EltVec;
322243830Sdim    }
323223013Sdim  }
324223013Sdim
325223013Sdim  // Set is not expandable.
326223013Sdim  return 0;
327223013Sdim}
328223013Sdim
329