summary refs log tree commit diff
path: root/src/Math/Tensor.hpp
blob: e373c1f5b992efee33d18e19cc5c326003c8a382 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#pragma once

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <type_traits>

/// @brief Fixed-shape tensor of rank @p O with element type @p T, stored inline.
///
/// @tparam O  Rank (number of indices). Must equal the number of extents in @p S.
/// @tparam T  Element type.
/// @tparam S  Extent of each dimension; storage holds the product of all extents.
///
/// Elements live in the `elements` array member and are addressed in row-major
/// order (the last index varies fastest). A rank-0 instantiation
/// (`Tensor<0, T>`) holds exactly one element, accessed as `t()`.
template<size_t O, typename T, size_t ...S>
struct Tensor {
    static_assert(sizeof...(S) == O, "Tensor: number of extents S must equal the rank O");

    /// Total number of stored elements: the product of all extents (1 for rank 0).
    /// A binary fold with identity 1 is used so an empty extent pack is well-formed.
    static constexpr size_t element_size = (size_t{1} * ... * S);

    /// SFINAE guard: enabled when exactly O index arguments are passed.
    template<typename ...Args>
    using EnableArgs = std::enable_if_t<sizeof...(Args) == O, bool>;
    /// SFINAE guard: enabled when exactly one argument per element is passed.
    template<typename ...Args>
    using EnableArgsPerElement = std::enable_if_t<sizeof...(Args) == element_size, bool>;
    /// Shape-consistency guard used by the constructors. It does not depend on the
    /// constructors' own template parameters, so a mismatch is a hard error when the
    /// class is instantiated (the static_assert above reports it readably).
    using Enable = std::enable_if_t<sizeof...(S) == O, bool>;

    /// Value-initializes every element (zero for arithmetic T).
    template<Enable = true>
    Tensor() : elements{} {}

    /// Sets every element to @p scalar.
    template<Enable = true>
    explicit Tensor(T scalar) {
        std::fill(elements, elements + element_size, scalar);
    }

    /// Initializes the elements, in row-major order, from exactly element_size values.
    /// Each value is converted with static_cast<T>. Brace-initializing from the raw
    /// arguments would reject e.g. `Tensor<2, float, 2, 2>(1, 0, 0, 1)` as a
    /// narrowing conversion, unlike the scalar constructor, which accepts any
    /// value convertible to T.
    template<Enable = true, typename ...Args, EnableArgsPerElement<Args...> = true>
    Tensor(Args... args) : elements{ static_cast<T>(args)... } {}

    /// Mutable element access by exactly O indices.
    /// Out-of-range indices trip an assert in debug builds; with NDEBUG they are unchecked.
    template<typename ...Args, EnableArgs<Args...> = true>
    T& operator()(Args... args) {
        assert(in_bounds(args...) && "Tensor index out of range");
        return elements[pos(args...)];
    }

    /// Read-only element access for const tensors; same contract as the mutable overload.
    template<typename ...Args, EnableArgs<Args...> = true>
    const T& operator()(Args... args) const {
        assert(in_bounds(args...) && "Tensor index out of range");
        return elements[pos(args...)];
    }

    /// True when every index, converted to size_t, is below its dimension's extent.
    /// Negative indices wrap to very large size_t values and therefore report false.
    template<typename ...Args, EnableArgs<Args...> = true>
    static constexpr bool in_bounds(Args... args) {
        return ((static_cast<size_t>(args) < S) && ...);
    }

    /// Row-major flat offset of the given indices into `elements`. Does not check bounds.
    template<typename ...Args, EnableArgs<Args...> = true>
    static constexpr size_t pos(Args... args) {
        size_t p = 0;
        // Horner-style accumulation, left to right: p = ((i0 * S1 + i1) * S2 + i2) * ...
        // The built-in comma fold sequences each step after the previous one.
        ((p = p * S + static_cast<size_t>(args)), ...);
        return p;
    }

    /// Row-major element storage.
    T elements[element_size];
};