aboutsummaryrefslogtreecommitdiff
path: root/include/llvm/CodeGen/PBQP/Math.h
blob: 8b014ccbb07ba72655c1024b7184d8644a474140 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
//===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_CODEGEN_PBQP_MATH_H
#define LLVM_CODEGEN_PBQP_MATH_H

#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <cassert>
#include <functional>
#include <memory>

namespace llvm {
namespace PBQP {

/// Scalar type of the elements stored in PBQP Vectors and Matrices.
using PBQPNum = float;

/// PBQP Vector class.
///
/// A fixed-length, heap-allocated array of PBQPNum values. A moved-from
/// Vector is left with length 0 and no storage; the accessors below assert
/// against using such an invalid vector.
///
/// Note: because a move constructor is user-declared, the implicit copy
/// assignment operator is defined as deleted and no move assignment operator
/// is declared, so a Vector cannot be assigned to after construction.
class Vector {
  friend hash_code hash_value(const Vector &);

public:
  /// Construct a PBQP vector of the given size.
  explicit Vector(unsigned Length)
    : Length(Length), Data(llvm::make_unique<PBQPNum []>(Length)) {}

  /// Construct a PBQP vector of the given size with every element set to
  /// \p InitVal.
  Vector(unsigned Length, PBQPNum InitVal)
    : Length(Length), Data(llvm::make_unique<PBQPNum []>(Length)) {
    std::fill(Data.get(), Data.get() + Length, InitVal);
  }

  /// Copy construct a PBQP vector: allocates fresh storage and copies every
  /// element of \p V.
  Vector(const Vector &V)
    : Length(V.Length), Data(llvm::make_unique<PBQPNum []>(Length)) {
    std::copy(V.Data.get(), V.Data.get() + Length, Data.get());
  }

  /// Move construct a PBQP vector, taking over \p V's storage and leaving
  /// \p V with length 0.
  ///
  /// Marked noexcept (it only moves a unique_ptr and copies an integer) so
  /// that std::move_if_noexcept, and therefore standard-container
  /// reallocation, moves Vectors instead of falling back to the allocating
  /// copy constructor.
  Vector(Vector &&V) noexcept
    : Length(V.Length), Data(std::move(V.Data)) {
    V.Length = 0;
  }

  /// Return true if \p V has the same length as this vector and every pair
  /// of corresponding elements compares equal. Only this vector is asserted
  /// to be valid; a \p V of different length (e.g. a moved-from one)
  /// simply compares unequal.
  bool operator==(const Vector &V) const {
    assert(Length != 0 && Data && "Invalid vector");
    if (Length != V.Length)
      return false;
    return std::equal(Data.get(), Data.get() + Length, V.Data.get());
  }

  /// Return the length of the vector.
  unsigned getLength() const {
    assert(Length != 0 && Data && "Invalid vector");
    return Length;
  }

  /// Element access.
  PBQPNum& operator[](unsigned Index) {
    assert(Length != 0 && Data && "Invalid vector");
    assert(Index < Length && "Vector element access out of bounds.");
    return Data[Index];
  }

  /// Const element access.
  const PBQPNum& operator[](unsigned Index) const {
    assert(Length != 0 && Data && "Invalid vector");
    assert(Index < Length && "Vector element access out of bounds.");
    return Data[Index];
  }

  /// Add another vector of the same length to this one, element by element.
  Vector& operator+=(const Vector &V) {
    assert(Length != 0 && Data && "Invalid vector");
    assert(Length == V.Length && "Vector length mismatch.");
    std::transform(Data.get(), Data.get() + Length, V.Data.get(), Data.get(),
                   std::plus<PBQPNum>());
    return *this;
  }

  /// Return the index of the minimum value in this vector. If the minimum
  /// occurs more than once, the index of its first occurrence is returned.
  unsigned minIndex() const {
    assert(Length != 0 && Data && "Invalid vector");
    return std::min_element(Data.get(), Data.get() + Length) - Data.get();
  }

private:
  // Length must stay declared before Data: the copy constructor sizes Data
  // from the already-initialized Length member.
  unsigned Length;
  std::unique_ptr<PBQPNum []> Data;
};

/// Return a hash_code for the given vector.
///
/// The hash mixes the vector's length with the raw bit patterns of its
/// elements, viewed as a range of `unsigned` words.
///
/// NOTE(review): Vector::operator== compares elements as floating-point
/// values, so +0.0 and -0.0 compare equal yet hash differently here. Confirm
/// no hashed container relies on equal vectors always having equal hashes
/// for such values.
inline hash_code hash_value(const Vector &V) {
  PBQPNum *First = V.Data.get();
  PBQPNum *Last = First + V.Length;
  unsigned *WordsBegin = reinterpret_cast<unsigned *>(First);
  unsigned *WordsEnd = reinterpret_cast<unsigned *>(Last);
  hash_code ElemsHash = hash_combine_range(WordsBegin, WordsEnd);
  return hash_combine(V.Length, ElemsHash);
}

/// Write \p V to \p OS as a bracketed, comma-separated list of its elements,
/// e.g. "[ 1, 2, 3 ]", and return \p OS.
template <typename OStream>
OStream& operator<<(OStream &OS, const Vector &V) {
  assert((V.getLength() != 0) && "Zero-length vector badness.");

  OS << "[ ";
  for (unsigned Idx = 0, End = V.getLength(); Idx != End; ++Idx) {
    // Separator goes before every element except the first.
    if (Idx != 0)
      OS << ", ";
    OS << V[Idx];
  }
  OS << " ]";

  return OS;
}

/// PBQP Matrix class
///
/// A dense Rows x Cols matrix of PBQPNum values stored contiguously in
/// row-major order: element (R, C) lives at offset R * Cols + C. A
/// moved-from Matrix is left with zero rows and columns and no storage; the
/// accessors below assert against using such an invalid matrix.
///
/// Note: because a move constructor is user-declared, the implicit copy
/// assignment operator is defined as deleted and no move assignment operator
/// is declared, so a Matrix cannot be assigned to after construction.
class Matrix {
private:
  friend hash_code hash_value(const Matrix &);

public:
  /// Construct a PBQP Matrix with the given dimensions.
  Matrix(unsigned Rows, unsigned Cols) :
    Rows(Rows), Cols(Cols), Data(llvm::make_unique<PBQPNum []>(Rows * Cols)) {
  }

  /// Construct a PBQP Matrix with the given dimensions and every element set
  /// to \p InitVal.
  Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal)
    : Rows(Rows), Cols(Cols),
      Data(llvm::make_unique<PBQPNum []>(Rows * Cols)) {
    std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal);
  }

  /// Copy construct a PBQP matrix: allocates fresh storage and copies every
  /// element of \p M.
  Matrix(const Matrix &M)
    : Rows(M.Rows), Cols(M.Cols),
      Data(llvm::make_unique<PBQPNum []>(Rows * Cols)) {
    std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get());
  }

  /// Move construct a PBQP matrix, taking over \p M's storage and leaving
  /// \p M with zero rows and columns.
  ///
  /// Marked noexcept (it only moves a unique_ptr and copies integers) so
  /// that std::move_if_noexcept, and therefore standard-container
  /// reallocation, moves Matrices instead of falling back to the allocating
  /// copy constructor.
  Matrix(Matrix &&M) noexcept
    : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) {
    M.Rows = M.Cols = 0;
  }

  /// Return true if \p M has the same dimensions as this matrix and every
  /// pair of corresponding elements compares equal. Only this matrix is
  /// asserted to be valid.
  bool operator==(const Matrix &M) const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    if (Rows != M.Rows || Cols != M.Cols)
      return false;
    return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get());
  }

  /// Return the number of rows in this matrix.
  unsigned getRows() const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    return Rows;
  }

  /// Return the number of cols in this matrix.
  unsigned getCols() const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    return Cols;
  }

  /// Matrix element access: return a pointer to the first element of row
  /// \p R, so that M[R][C] names element (R, C). Only the row index is
  /// bounds-checked; the column index is not.
  PBQPNum* operator[](unsigned R) {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    assert(R < Rows && "Row out of bounds.");
    return Data.get() + (R * Cols);
  }

  /// Const matrix element access; see the non-const overload.
  const PBQPNum* operator[](unsigned R) const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    assert(R < Rows && "Row out of bounds.");
    return Data.get() + (R * Cols);
  }

  /// Returns a copy of row \p R as a vector of length getCols().
  Vector getRowAsVector(unsigned R) const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    // R is bounds-checked by operator[] below.
    Vector V(Cols);
    for (unsigned C = 0; C < Cols; ++C)
      V[C] = (*this)[R][C];
    return V;
  }

  /// Returns a copy of column \p C as a vector of length getRows().
  Vector getColAsVector(unsigned C) const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    // operator[] only checks the row index, so check the column here;
    // otherwise an out-of-range C silently reads from the following row (or
    // past the end of the storage for the last row).
    assert(C < Cols && "Column out of bounds.");
    Vector V(Rows);
    for (unsigned R = 0; R < Rows; ++R)
      V[R] = (*this)[R][C];
    return V;
  }

  /// Return the transpose of this matrix (a new getCols() x getRows()
  /// matrix).
  Matrix transpose() const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    Matrix M(Cols, Rows);
    for (unsigned r = 0; r < Rows; ++r)
      for (unsigned c = 0; c < Cols; ++c)
        M[c][r] = (*this)[r][c];
    return M;
  }

  /// Add the given matrix, which must have the same dimensions, to this one
  /// element by element.
  Matrix& operator+=(const Matrix &M) {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    assert(Rows == M.Rows && Cols == M.Cols &&
           "Matrix dimensions mismatch.");
    std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(),
                   Data.get(), std::plus<PBQPNum>());
    return *this;
  }

  /// Return the element-wise sum of this matrix and \p M without modifying
  /// either operand. Const-qualified so that a const left-hand operand can
  /// be added; dimensions are checked by operator+=.
  Matrix operator+(const Matrix &M) const {
    assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
    Matrix Tmp(*this);
    Tmp += M;
    return Tmp;
  }

private:
  // Rows and Cols must stay declared before Data: the copy constructor sizes
  // Data from the already-initialized Rows and Cols members.
  unsigned Rows, Cols;
  std::unique_ptr<PBQPNum []> Data;
};

/// Return a hash_code for the given matrix.
///
/// The hash mixes both dimensions with the raw bit patterns of all
/// Rows * Cols elements, viewed as a range of `unsigned` words. As with the
/// Vector overload, matrices that differ only in the sign of a zero element
/// compare equal but may hash differently.
inline hash_code hash_value(const Matrix &M) {
  const unsigned NumElems = M.Rows * M.Cols;
  PBQPNum *First = M.Data.get();
  unsigned *WordsBegin = reinterpret_cast<unsigned *>(First);
  unsigned *WordsEnd = reinterpret_cast<unsigned *>(First + NumElems);
  hash_code ElemsHash = hash_combine_range(WordsBegin, WordsEnd);
  return hash_combine(M.Rows, M.Cols, ElemsHash);
}

/// Write \p M to \p OS one row per line, each row formatted with the Vector
/// stream operator, and return \p OS.
template <typename OStream>
OStream& operator<<(OStream &OS, const Matrix &M) {
  assert((M.getRows() != 0) && "Zero-row matrix badness.");
  for (unsigned Row = 0, NumRows = M.getRows(); Row != NumRows; ++Row) {
    OS << M.getRowAsVector(Row);
    OS << "\n";
  }
  return OS;
}

/// A PBQP Vector paired with a Metadata object that is constructed from the
/// vector itself when the MDVector is created.
template <typename Metadata>
class MDVector : public Vector {
  // Initialized after the Vector base, so it is constructed from a fully
  // built vector.
  Metadata MD;

public:
  MDVector(const Vector &Other) : Vector(Other), MD(*this) {}
  MDVector(Vector &&Other) : Vector(std::move(Other)), MD(*this) {}

  /// Return the metadata built for this vector at construction.
  const Metadata &getMetadata() const { return MD; }
};

/// Hash an MDVector by its underlying vector only; the metadata is built
/// from that vector and does not participate.
template <typename Metadata>
inline hash_code hash_value(const MDVector<Metadata> &V) {
  const Vector &Base = V;
  return hash_value(Base);
}

/// A PBQP Matrix paired with a Metadata object that is constructed from the
/// matrix itself when the MDMatrix is created.
template <typename Metadata>
class MDMatrix : public Matrix {
  // Initialized after the Matrix base, so it is constructed from a fully
  // built matrix.
  Metadata MD;

public:
  MDMatrix(const Matrix &Other) : Matrix(Other), MD(*this) {}
  MDMatrix(Matrix &&Other) : Matrix(std::move(Other)), MD(*this) {}

  /// Return the metadata built for this matrix at construction.
  const Metadata &getMetadata() const { return MD; }
};

/// Hash an MDMatrix by its underlying matrix only; the metadata is built
/// from that matrix and does not participate.
template <typename Metadata>
inline hash_code hash_value(const MDMatrix<Metadata> &M) {
  const Matrix &Base = M;
  return hash_value(Base);
}

} // end namespace PBQP
} // end namespace llvm

#endif // LLVM_CODEGEN_PBQP_MATH_H