blob: e373c1f5b992efee33d18e19cc5c326003c8a382 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
#pragma once

#include <algorithm>
#include <cstddef>
#include <type_traits>
/// Fixed-size tensor of rank `O` with element type `T` and per-axis extents `S...`.
///
/// All `element_size` values live inline in the `elements` array. Multi-indices
/// are mapped to a flat offset by `pos()`, which makes the last index vary
/// fastest (row-major layout).
///
/// @tparam O  Rank, i.e. the number of indices needed to address one element.
/// @tparam T  Element type.
/// @tparam S  Extent of each axis; exactly `O` values must be supplied.
template<std::size_t O, typename T, std::size_t ...S>
struct Tensor {
    static_assert(sizeof...(S) == O,
                  "Tensor: the number of extents in S... must equal the rank O");

    /// Total number of stored elements (product of all extents).
    static constexpr std::size_t element_size = (S * ...);

    /// SFINAE helper: valid only when exactly `O` arguments (one per axis) are given.
    template<typename ...Args>
    using EnableArgs = std::enable_if_t<sizeof...(Args) == O, bool>;

    /// SFINAE helper: valid only when exactly `element_size` arguments are given.
    template<typename ...Args>
    using EnableArgsPerElement = std::enable_if_t<sizeof...(Args) == element_size, bool>;

    /// Valid only when the number of extents matches the rank (kept for compatibility).
    using Enable = std::enable_if_t<sizeof...(S) == O, bool>;

    /// Value-initializes every element (zero for arithmetic `T`).
    Tensor() : elements{} {}

    /// Sets every element to `scalar`.
    explicit Tensor(T scalar) {
        std::fill(elements, elements + element_size, scalar);
    }

    /// Initializes the elements, in storage order, from exactly `element_size` values.
    ///
    /// Each argument must be implicitly convertible to `T`; the conversion is
    /// performed explicitly so that e.g. integer literals can initialize a
    /// floating-point tensor without tripping the narrowing rules of list
    /// initialization. The convertibility constraint also keeps this overload
    /// from hijacking copy construction when `element_size == 1`.
    template<typename ...Args,
             EnableArgsPerElement<Args...> = true,
             std::enable_if_t<(std::is_convertible_v<Args, T> && ...), bool> = true>
    Tensor(Args... args) : elements{ static_cast<T>(args)... } {}

    /// Returns a mutable reference to the element at the given multi-index.
    /// Exactly `O` indices are required; no bounds checking is performed.
    template<typename ...Args, EnableArgs<Args...> = true>
    T& operator()(Args... args) {
        return elements[pos(args...)];
    }

    /// Read-only access to the element at the given multi-index.
    /// Exactly `O` indices are required; no bounds checking is performed.
    template<typename ...Args, EnableArgs<Args...> = true>
    const T& operator()(Args... args) const {
        return elements[pos(args...)];
    }

    /// Converts a multi-index into a flat offset into `elements`.
    ///
    /// Uses Horner's scheme over the extents, so for extents (d0, d1, d2) and
    /// indices (i0, i1, i2) the result is (i0 * d1 + i1) * d2 + i2.
    /// Each index is cast to `std::size_t`; no range check is performed.
    template<typename ...Args, EnableArgs<Args...> = true>
    static constexpr std::size_t pos(Args... args) {
        const std::size_t positions[O] = {static_cast<std::size_t>(args)...};
        constexpr std::size_t dimensions[O] = {S...};
        std::size_t p = 0;
        for (std::size_t i = 0; i < O; ++i) {
            p = p * dimensions[i] + positions[i];
        }
        return p;
    }

    /// Contiguous element storage, laid out as described by `pos()`.
    T elements[element_size];
};
|