blob: 95c454801a2dfebb731f621df5ca5078dc66b4b8 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
#pragma once
/// Fixed-size tensor with a compile-time shape, stored as one flat array.
///
/// Elements are laid out so that the last index varies fastest (see pos()).
///
/// @tparam O  Number of dimensions; must equal the number of extents in S.
/// @tparam T  Element type.
/// @tparam S  Extent of each dimension, in the same order as the indices
///            passed to operator() and pos().
template<uint O, typename T, uint ...S>
struct Tensor {
    // Gives a readable diagnostic for an order/extent mismatch. The Enable
    // alias below rejects the same instantiations, but only through an
    // opaque enable_if failure.
    static_assert(sizeof...(S) == O,
                  "Tensor: the number of extents in S must equal the order O");

    /// Total number of stored elements: the product of all extents.
    static constexpr USize element_size = (S*...);

    /// Enabled only when exactly one argument per dimension is supplied.
    template<typename ...Args>
    using EnableArgs = std::enable_if_t<sizeof...(Args) == O, Bool>;

    /// Enabled only when exactly one argument per stored element is supplied.
    template<typename ...Args>
    using EnableArgsPerElement = std::enable_if_t<sizeof...(Args) == element_size, Bool>;

    /// Well-formed only when the number of extents matches the order O.
    using Enable = std::enable_if_t<sizeof...(S) == O, Bool>;

    /// Value-initializes every element.
    template<Enable = true>
    Tensor() : elements{} {}

    /// Sets every element to @p scalar.
    template<Enable = true>
    explicit Tensor(T scalar) {
        std::fill(elements, elements + element_size, scalar);
    }

    /// Initializes the elements, in storage order, from exactly
    /// element_size arguments; each argument is static_cast to T.
    template<Enable = true, typename ...Args, EnableArgsPerElement<Args...> = true>
    Tensor(Args... args) : elements{ static_cast<T>(args)... } {}

    /// Mutable access to the element addressed by one index per dimension.
    /// Indices are not range-checked.
    template<typename ...Args, EnableArgs<Args...> = true>
    auto& operator()(Args... args) {
        return elements[pos(args...)];
    }

    /// Read-only access to the element addressed by one index per dimension,
    /// so that const tensors can be inspected. Indices are not range-checked.
    template<typename ...Args, EnableArgs<Args...> = true>
    const auto& operator()(Args... args) const {
        return elements[pos(args...)];
    }

    /// Converts one index per dimension into an offset into `elements`.
    ///
    /// The offset is accumulated as offset = offset * extent[i] + index[i]
    /// over the dimensions in order, so the last index is contiguous in
    /// memory. No bounds checking is performed on the supplied indices.
    ///
    /// @param args  One index per dimension; each is converted to USize.
    /// @return      Flat offset of the addressed element.
    template<typename ...Args, EnableArgs<Args...> = true>
    static constexpr USize pos(Args... args) {
        const USize positions[O] = {static_cast<USize>(args)...};
        const USize dimensions[O] = {S...};
        USize p = 0;
        // The index uses the same unsigned type as O to avoid a
        // signed/unsigned comparison.
        for (uint i = 0; i < O; ++i) {
            p *= dimensions[i];
            p += positions[i];
        }
        return p;
    }

    /// Flat element storage; see pos() for the index-to-offset mapping.
    T elements[element_size];
};
|