summary refs log tree commit diff
path: root/src/Math/Tensor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/Math/Tensor.hpp')
-rw-r--r--src/Math/Tensor.hpp43
1 files changed, 43 insertions, 0 deletions
diff --git a/src/Math/Tensor.hpp b/src/Math/Tensor.hpp
new file mode 100644
index 0000000..e373c1f
--- /dev/null
+++ b/src/Math/Tensor.hpp
@@ -0,0 +1,43 @@
+#pragma once
+
+template<size_t O, typename T, size_t ...S>
+struct Tensor {
+    static constexpr size_t element_size = (S*...);
+
+    template<typename ...Args>
+    using EnableArgs = std::enable_if_t<sizeof...(Args) == O, bool>;
+    template<typename ...Args>
+    using EnableArgsPerElement = std::enable_if_t<sizeof...(Args) == element_size, bool>;
+    using Enable = std::enable_if_t<sizeof...(S) == O, bool>;
+
+    template<Enable = true>
+    Tensor() : elements{} {}
+
+    template<Enable = true>
+    explicit Tensor(T scalar) {
+        std::fill(elements, elements + element_size, scalar);
+    }
+
+    template<Enable = true, typename ...Args, EnableArgsPerElement<Args...> = true>
+    Tensor(Args... args) : elements{ args... } {}
+
+    template<typename ...Args, EnableArgs<Args...> = true>
+    auto& operator()(Args... args) {
+        return elements[pos(args...)];
+    }
+
+    template<typename ...Args, EnableArgs<Args...> = true>
+    static constexpr size_t pos(Args... args) {
+        size_t positions[O] = {static_cast<size_t>(args)...};
+        size_t dimensions[O] = {S...};
+
+        size_t p = 0;
+        for (int i = 0; i < O; i++) {
+            p *= dimensions[i];
+            p += positions[i];
+        }
+        return p;
+    }
+
+    T elements[element_size];
+};