1/*
2 * extractExternal.cpp
3 */
4
5//===----------------------------------------------------------------------===//
6//
7// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
8// See https://llvm.org/LICENSE.txt for license information.
9// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
10//
11//===----------------------------------------------------------------------===//
12
#include <fstream>
#include <iostream>
#include <map>
#include <set>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <strstream>
20
21/* Given a set of n object files h ('external' object files) and a set of m
22   object files o ('internal' object files),
23   1. Determines r, the subset of h that o depends on, directly or indirectly
24   2. Removes the files in h - r from the file system
25   3. For each external symbol defined in some file in r, rename it in r U o
26      by prefixing it with "__kmp_external_"
27   Usage:
28   hide.exe <n> <filenames for h> <filenames for o>
29
30   Thus, the prefixed symbols become hidden in the sense that they now have a
31   special prefix.
32*/
33
34using namespace std;
35
// Prints <errorMsg> followed by a newline to standard output and terminates
// the process with exit status 1; never returns to the caller.
// The parameter is const: every caller passes a string literal, and binding a
// literal to a plain 'char *' is ill-formed since C++11.
void stop(const char *errorMsg) {
  printf("%s\n", errorMsg);
  exit(1);
}
40
// An entry in the symbol table of a .OBJ file.
// rstream::operator>>(Symbol &) fills an object of this class by reading 18
// raw bytes straight into it, and hideSymbols writes those 18 bytes back out,
// so the member order and types below ARE the on-disk layout: bytes 0-7 name,
// 8-11 value, 12-13 sectionNum, 14-15 type, 16 storageClass, 17 nAux
// (assuming 8/4/2/1-byte sizes for these types). Do not reorder or retype.
// NOTE(review): this appears to mirror the Microsoft COFF IMAGE_SYMBOL
// record -- confirm before relying on further field semantics.
class Symbol {
public:
  // 8-byte name field; StringTable::decode interprets it either as the name
  // itself or (first 4 bytes zero) as an offset into the string table.
  __int64 name;
  unsigned value;
  // sectionNum == 0 is treated by this tool as "not defined in this file".
  unsigned short sectionNum, type;
  // storageClass == 2 is treated as an external symbol; nAux is the number of
  // auxiliary 18-byte entries that immediately follow this one.
  char storageClass, nAux;
};
49
// Base class of rstream: an istrstream reading from a heap buffer that it
// owns. The buffer passed to the constructor must come from new[]; it is
// released with delete[] when the stream is destroyed.
class _rstream : public istrstream {
protected:
  _rstream(pair<const char *, streamsize> region)
      : istrstream(region.first, region.second), owned(region.first) {}
  ~_rstream() { delete[] owned; }

private:
  // start of the buffer being read; owned by this object
  const char *owned;
};
59
60// A stream encapsulating the content of a file or the content of a string,
61// overriding the >> operator to read various integer types in binary form,
62// as well as a symbol table entry.
63class rstream : public _rstream {
64private:
65  template <class T> inline rstream &doRead(T &x) {
66    read((char *)&x, sizeof(T));
67    return *this;
68  }
69  static pair<const char *, streamsize> getBuf(const char *fileName) {
70    ifstream raw(fileName, ios::binary | ios::in);
71    if (!raw.is_open())
72      stop("rstream.getBuf: Error opening file");
73    raw.seekg(0, ios::end);
74    streampos fileSize = raw.tellg();
75    if (fileSize < 0)
76      stop("rstream.getBuf: Error reading file");
77    char *buf = new char[fileSize];
78    raw.seekg(0, ios::beg);
79    raw.read(buf, fileSize);
80    return pair<const char *, streamsize>(buf, fileSize);
81  }
82
83public:
84  // construct from a string
85  rstream(const char *buf, streamsize size)
86      : _rstream(pair<const char *, streamsize>(buf, size)) {}
87  // construct from a file whole content is fully read once to initialize the
88  // content of this stream
89  rstream(const char *fileName) : _rstream(getBuf(fileName)) {}
90  rstream &operator>>(int &x) { return doRead(x); }
91  rstream &operator>>(unsigned &x) { return doRead(x); }
92  rstream &operator>>(short &x) { return doRead(x); }
93  rstream &operator>>(unsigned short &x) { return doRead(x); }
94  rstream &operator>>(Symbol &e) {
95    read((char *)&e, 18);
96    return *this;
97  }
98};
99
100// string table in a .OBJ file
101class StringTable {
102private:
103  map<string, unsigned> directory;
104  size_t length;
105  char *data;
106
107  // make <directory> from <length> bytes in <data>
108  void makeDirectory(void) {
109    unsigned i = 4;
110    while (i < length) {
111      string s = string(data + i);
112      directory.insert(make_pair(s, i));
113      i += s.size() + 1;
114    }
115  }
116  // initialize <length> and <data> with contents specified by the arguments
117  void init(const char *_data) {
118    unsigned _length = *(unsigned *)_data;
119
120    if (_length < sizeof(unsigned) || _length != *(unsigned *)_data)
121      stop("StringTable.init: Invalid symbol table");
122    if (_data[_length - 1]) {
123      // to prevent runaway strings, make sure the data ends with a zero
124      data = new char[length = _length + 1];
125      data[_length] = 0;
126    } else {
127      data = new char[length = _length];
128    }
129    *(unsigned *)data = length;
130    KMP_MEMCPY(data + sizeof(unsigned), _data + sizeof(unsigned),
131               length - sizeof(unsigned));
132    makeDirectory();
133  }
134
135public:
136  StringTable(rstream &f) {
137    // Construct string table by reading from f.
138    streampos s;
139    unsigned strSize;
140    char *strData;
141
142    s = f.tellg();
143    f >> strSize;
144    if (strSize < sizeof(unsigned))
145      stop("StringTable: Invalid string table");
146    strData = new char[strSize];
147    *(unsigned *)strData = strSize;
148    // read the raw data into <strData>
149    f.read(strData + sizeof(unsigned), strSize - sizeof(unsigned));
150    s = f.tellg() - s;
151    if (s < strSize)
152      stop("StringTable: Unexpected EOF");
153    init(strData);
154    delete[] strData;
155  }
156  StringTable(const set<string> &strings) {
157    // Construct string table from given strings.
158    char *p;
159    set<string>::const_iterator it;
160    size_t s;
161
162    // count required size for data
163    for (length = sizeof(unsigned), it = strings.begin(); it != strings.end();
164         ++it) {
165      size_t l = (*it).size();
166
167      if (l > (unsigned)0xFFFFFFFF)
168        stop("StringTable: String too long");
169      if (l > 8) {
170        length += l + 1;
171        if (length > (unsigned)0xFFFFFFFF)
172          stop("StringTable: Symbol table too long");
173      }
174    }
175    data = new char[length];
176    *(unsigned *)data = length;
177    // populate data and directory
178    for (p = data + sizeof(unsigned), it = strings.begin(); it != strings.end();
179         ++it) {
180      const string &str = *it;
181      size_t l = str.size();
182      if (l > 8) {
183        directory.insert(make_pair(str, p - data));
184        KMP_MEMCPY(p, str.c_str(), l);
185        p[l] = 0;
186        p += l + 1;
187      }
188    }
189  }
190  ~StringTable() { delete[] data; }
191  // Returns encoding for given string based on this string table. Error if
192  // string length is greater than 8 but string is not in the string table
193  // -- returns 0.
194  __int64 encode(const string &str) {
195    __int64 r;
196
197    if (str.size() <= 8) {
198      // encoded directly
199      ((char *)&r)[7] = 0;
200      KMP_STRNCPY_S((char *)&r, sizeof(r), str.c_str(), 8);
201      return r;
202    } else {
203      // represented as index into table
204      map<string, unsigned>::const_iterator it = directory.find(str);
205      if (it == directory.end())
206        stop("StringTable::encode: String now found in string table");
207      ((unsigned *)&r)[0] = 0;
208      ((unsigned *)&r)[1] = (*it).second;
209      return r;
210    }
211  }
212  // Returns string represented by x based on this string table. Error if x
213  // references an invalid position in the table--returns the empty string.
214  string decode(__int64 x) const {
215    if (*(unsigned *)&x == 0) {
216      // represented as index into table
217      unsigned &p = ((unsigned *)&x)[1];
218      if (p >= length)
219        stop("StringTable::decode: Invalid string table lookup");
220      return string(data + p);
221    } else {
222      // encoded directly
223      char *p = (char *)&x;
224      int i;
225
226      for (i = 0; i < 8 && p[i]; ++i)
227        ;
228      return string(p, i);
229    }
230  }
231  void write(ostream &os) { os.write(data, length); }
232};
233
234// for the named object file, determines the set of defined symbols and the set
235// of undefined external symbols and writes them to <defined> and <undefined>
236// respectively
237void computeExternalSymbols(const char *fileName, set<string> *defined,
238                            set<string> *undefined) {
239  streampos fileSize;
240  size_t strTabStart;
241  unsigned symTabStart, symNEntries;
242  rstream f(fileName);
243
244  f.seekg(0, ios::end);
245  fileSize = f.tellg();
246
247  f.seekg(8);
248  f >> symTabStart >> symNEntries;
249  // seek to the string table
250  f.seekg(strTabStart = symTabStart + 18 * (size_t)symNEntries);
251  if (f.eof()) {
252    printf("computeExternalSymbols: fileName='%s', fileSize = %lu, symTabStart "
253           "= %u, symNEntries = %u\n",
254           fileName, (unsigned long)fileSize, symTabStart, symNEntries);
255    stop("computeExternalSymbols: Unexpected EOF 1");
256  }
257  StringTable stringTable(f); // read the string table
258  if (f.tellg() != fileSize)
259    stop("computeExternalSymbols: Unexpected data after string table");
260
261  f.clear();
262  f.seekg(symTabStart); // seek to the symbol table
263
264  defined->clear();
265  undefined->clear();
266  for (int i = 0; i < symNEntries; ++i) {
267    // process each entry
268    Symbol e;
269
270    if (f.eof())
271      stop("computeExternalSymbols: Unexpected EOF 2");
272    f >> e;
273    if (f.fail())
274      stop("computeExternalSymbols: File read error");
275    if (e.nAux) { // auxiliary entry: skip
276      f.seekg(e.nAux * 18, ios::cur);
277      i += e.nAux;
278    }
279    // if symbol is extern and defined in the current file, insert it
280    if (e.storageClass == 2)
281      if (e.sectionNum)
282        defined->insert(stringTable.decode(e.name));
283      else
284        undefined->insert(stringTable.decode(e.name));
285  }
286}
287
288// For each occurrence of an external symbol in the object file named by
289// by <fileName> that is a member of <hide>, renames it by prefixing
290// with "__kmp_external_", writing back the file in-place
291void hideSymbols(char *fileName, const set<string> &hide) {
292  static const string prefix("__kmp_external_");
293  set<string> strings; // set of all occurring symbols, appropriately prefixed
294  streampos fileSize;
295  size_t strTabStart;
296  unsigned symTabStart, symNEntries;
297  int i;
298  rstream in(fileName);
299
300  in.seekg(0, ios::end);
301  fileSize = in.tellg();
302
303  in.seekg(8);
304  in >> symTabStart >> symNEntries;
305  in.seekg(strTabStart = symTabStart + 18 * (size_t)symNEntries);
306  if (in.eof())
307    stop("hideSymbols: Unexpected EOF");
308  StringTable stringTableOld(in); // read original string table
309
310  if (in.tellg() != fileSize)
311    stop("hideSymbols: Unexpected data after string table");
312
313  // compute set of occurring strings with prefix added
314  for (i = 0; i < symNEntries; ++i) {
315    Symbol e;
316
317    in.seekg(symTabStart + i * 18);
318    if (in.eof())
319      stop("hideSymbols: Unexpected EOF");
320    in >> e;
321    if (in.fail())
322      stop("hideSymbols: File read error");
323    if (e.nAux)
324      i += e.nAux;
325    const string &s = stringTableOld.decode(e.name);
326    // if symbol is extern and found in <hide>, prefix and insert into strings,
327    // otherwise, just insert into strings without prefix
328    strings.insert(
329        (e.storageClass == 2 && hide.find(s) != hide.end()) ? prefix + s : s);
330  }
331
332  ofstream out(fileName, ios::trunc | ios::out | ios::binary);
333  if (!out.is_open())
334    stop("hideSymbols: Error opening output file");
335
336  // make new string table from string set
337  StringTable stringTableNew = StringTable(strings);
338
339  // copy input file to output file up to just before the symbol table
340  in.seekg(0);
341  char *buf = new char[symTabStart];
342  in.read(buf, symTabStart);
343  out.write(buf, symTabStart);
344  delete[] buf;
345
346  // copy input symbol table to output symbol table with name translation
347  for (i = 0; i < symNEntries; ++i) {
348    Symbol e;
349
350    in.seekg(symTabStart + i * 18);
351    if (in.eof())
352      stop("hideSymbols: Unexpected EOF");
353    in >> e;
354    if (in.fail())
355      stop("hideSymbols: File read error");
356    const string &s = stringTableOld.decode(e.name);
357    out.seekp(symTabStart + i * 18);
358    e.name = stringTableNew.encode(
359        (e.storageClass == 2 && hide.find(s) != hide.end()) ? prefix + s : s);
360    out.write((char *)&e, 18);
361    if (out.fail())
362      stop("hideSymbols: File write error");
363    if (e.nAux) {
364      // copy auxiliary symbol table entries
365      int nAux = e.nAux;
366      for (int j = 1; j <= nAux; ++j) {
367        in >> e;
368        out.seekp(symTabStart + (i + j) * 18);
369        out.write((char *)&e, 18);
370      }
371      i += nAux;
372    }
373  }
374  // output string table
375  stringTableNew.write(out);
376}
377
// Returns true iff <a> and <b> have no common element. Both sets are walked
// in sorted order at the same time (a merge), so each step advances one
// iterator and the function stops at the first common element.
template <class T> bool isDisjoint(const set<T> &a, const set<T> &b) {
  // 'typename' is required: set<T>::const_iterator is a dependent type name,
  // and conforming compilers reject the declaration without it.
  typename set<T>::const_iterator ita = a.begin(), itb = b.begin();

  while (ita != a.end() && itb != b.end()) {
    if (*ita < *itb)
      ++ita;
    else if (*itb < *ita)
      ++itb;
    else
      return false;
  }
  return true;
}
393
394// PRE: <defined> and <undefined> are arrays with <nTotal> elements where
395// <nTotal> >= <nExternal>.  The first <nExternal> elements correspond to the
396// external object files and the rest correspond to the internal object files.
397// POST: file x is said to depend on file y if undefined[x] and defined[y] are
398// not disjoint. Returns the transitive closure of the set of internal object
399// files, as a set of file indexes, under the 'depends on' relation, minus the
400// set of internal object files.
401set<int> *findRequiredExternal(int nExternal, int nTotal, set<string> *defined,
402                               set<string> *undefined) {
403  set<int> *required = new set<int>;
404  set<int> fresh[2];
405  int i, cur = 0;
406  bool changed;
407
408  for (i = nTotal - 1; i >= nExternal; --i)
409    fresh[cur].insert(i);
410  do {
411    changed = false;
412    for (set<int>::iterator it = fresh[cur].begin(); it != fresh[cur].end();
413         ++it) {
414      set<string> &s = undefined[*it];
415
416      for (i = 0; i < nExternal; ++i) {
417        if (required->find(i) == required->end()) {
418          if (!isDisjoint(defined[i], s)) {
419            // found a new qualifying element
420            required->insert(i);
421            fresh[1 - cur].insert(i);
422            changed = true;
423          }
424        }
425      }
426    }
427    fresh[cur].clear();
428    cur = 1 - cur;
429  } while (changed);
430  return required;
431}
432
433int main(int argc, char **argv) {
434  int nExternal, nInternal, i;
435  set<string> *defined, *undefined;
436  set<int>::iterator it;
437
438  if (argc < 3)
439    stop("Please specify a positive integer followed by a list of object "
440         "filenames");
441  nExternal = atoi(argv[1]);
442  if (nExternal <= 0)
443    stop("Please specify a positive integer followed by a list of object "
444         "filenames");
445  if (nExternal + 2 > argc)
446    stop("Too few external objects");
447  nInternal = argc - nExternal - 2;
448  defined = new set<string>[argc - 2];
449  undefined = new set<string>[argc - 2];
450
451  // determine the set of defined and undefined external symbols
452  for (i = 2; i < argc; ++i)
453    computeExternalSymbols(argv[i], defined + i - 2, undefined + i - 2);
454
455  // determine the set of required external files
456  set<int> *requiredExternal =
457      findRequiredExternal(nExternal, argc - 2, defined, undefined);
458  set<string> hide;
459
460  // determine the set of symbols to hide--namely defined external symbols of
461  // the required external files
462  for (it = requiredExternal->begin(); it != requiredExternal->end(); ++it) {
463    int idx = *it;
464    set<string>::iterator it2;
465    // We have to insert one element at a time instead of inserting a range
466    // because the insert member function taking a range doesn't exist on
467    // Windows* OS, at least at the time of this writing.
468    for (it2 = defined[idx].begin(); it2 != defined[idx].end(); ++it2)
469      hide.insert(*it2);
470  }
471
472  // process the external files--removing those that are not required and hiding
473  //   the appropriate symbols in the others
474  for (i = 0; i < nExternal; ++i)
475    if (requiredExternal->find(i) != requiredExternal->end())
476      hideSymbols(argv[2 + i], hide);
477    else
478      remove(argv[2 + i]);
479  // hide the appropriate symbols in the internal files
480  for (i = nExternal + 2; i < argc; ++i)
481    hideSymbols(argv[i], hide);
482  return 0;
483}
484