/*****************************************************************************
 * $CAMITK_LICENCE_BEGIN$
 *
 * CamiTK - Computer Assisted Medical Intervention ToolKit
 * (c) 2001-2025 Univ. Grenoble Alpes, CNRS, Grenoble INP - UGA, TIMC, 38000 Grenoble, France
 *
 * Visit http://camitk.imag.fr for more information
 *
 * This file is part of CamiTK.
 *
 * CamiTK is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License version 3
 * only, as published by the Free Software Foundation.
 *
 * CamiTK is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License version 3 for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * version 3 along with CamiTK.  If not, see <http://www.gnu.org/licenses/>.
 *
 * $CAMITK_LICENCE_END$
 ****************************************************************************/

#include "numpy_utils.h"

#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <Application.h>
#include <vtkDoubleArray.h>
#include <vtkFloatArray.h>

#include <Log.h>

namespace camitk {

// -------------------- vtkImageDataToNumpyTemplated --------------------
template <typename T>
py::array_t<T> vtkImageDataToNumpyTemplated(vtkSmartPointer<vtkImageData> image) {
    int dims[3];
    image->GetDimensions(dims);
    const int numComponents = image->GetNumberOfScalarComponents();
    const T* vtkData = static_cast<const T*>(image->GetScalarPointer());

    if (!vtkData) {
        throw std::runtime_error("Invalid or null VTK image scalar pointer");
    }

    // Get VTK's actual increments (strides in elements, not bytes)
    vtkIdType increments[3];
    image->GetIncrements(increments);

    // NumPy shape and byte strides, outermost axis first:
    // - 2D, one component   : (y, x)
    // - 3D, one component   : (z, y, x)
    // - 2D, many components : (y, x, c)
    // - 3D, many components : (z, y, x, c)
    const py::ssize_t itemSize = static_cast<py::ssize_t>(sizeof(T));
    std::vector<py::ssize_t> shape;
    std::vector<py::ssize_t> strides;

    if (dims[2] != 1) {
        // 3D image: add the Z axis
        shape.push_back(dims[2]);
        strides.push_back(static_cast<py::ssize_t>(increments[2]) * itemSize);
    }
    shape.push_back(dims[1]);
    strides.push_back(static_cast<py::ssize_t>(increments[1]) * itemSize);
    shape.push_back(dims[0]);
    strides.push_back(static_cast<py::ssize_t>(increments[0]) * itemSize);

    if (numComponents != 1) {
        // Components (e.g. RGB) are interleaved: consecutive components are one element apart
        shape.push_back(numComponents);
        strides.push_back(itemSize);
    }

    // No base object is passed, so pybind11 copies the VTK buffer into a new NumPy-owned array.
    // The result therefore never aliases the VTK image memory, and no additional .copy() is needed
    // (an extra copy would double the peak memory used for large volumes).
    return py::array_t<T>(shape, strides, vtkData);
}

// -------------------- vtkImageDataToNumpy --------------------
py::array vtkImageDataToNumpy(vtkSmartPointer<vtkImageData> image) {
    if (!image) {
        throw std::runtime_error("Invalid or null VTK image");
    }

    // Fixed-width equivalents of the VTK scalar types whose size/signedness is platform dependent.
    // vtkImageDataToNumpyTemplated only hands the raw scalar buffer to NumPy, so any type of the
    // same size and signedness describes the data correctly.
    // - VTK_CHAR stores plain `char`, whose signedness is implementation-defined
    using CharEquivalent = std::conditional<std::is_signed<char>::value, int8_t, uint8_t>::type;
    // - VTK_LONG / VTK_UNSIGNED_LONG are 32 bits on LLP64 (Windows) and 64 bits on LP64 (Linux, macOS)
    using LongEquivalent = std::conditional<sizeof(long) == 8, int64_t, int32_t>::type;
    using ULongEquivalent = std::conditional<sizeof(unsigned long) == 8, uint64_t, uint32_t>::type;
    // - VTK_ID_TYPE depends on how VTK was built (32 or 64 bit ids)
    using IdTypeEquivalent = std::conditional<sizeof(vtkIdType) == 8, int64_t, int32_t>::type;

    py::array numpyArray;

    int vtkType = image->GetScalarType();
    switch (vtkType) {
        case VTK_TYPE_INT8:
            numpyArray = vtkImageDataToNumpyTemplated<int8_t>(image);
            break;
        case VTK_CHAR:
            numpyArray = vtkImageDataToNumpyTemplated<CharEquivalent>(image);
            break;
        case VTK_TYPE_UINT8:
            numpyArray = vtkImageDataToNumpyTemplated<uint8_t>(image);
            break;
        case VTK_TYPE_INT16:
            numpyArray = vtkImageDataToNumpyTemplated<int16_t>(image);
            break;
        case VTK_TYPE_UINT16:
            numpyArray = vtkImageDataToNumpyTemplated<uint16_t>(image);
            break;
        case VTK_TYPE_INT32:
            numpyArray = vtkImageDataToNumpyTemplated<int32_t>(image);
            break;
        case VTK_TYPE_UINT32:
            numpyArray = vtkImageDataToNumpyTemplated<uint32_t>(image);
            break;
        case VTK_TYPE_INT64:
            numpyArray = vtkImageDataToNumpyTemplated<int64_t>(image);
            break;
        case VTK_TYPE_UINT64:
            numpyArray = vtkImageDataToNumpyTemplated<uint64_t>(image);
            break;
        case VTK_LONG:
            numpyArray = vtkImageDataToNumpyTemplated<LongEquivalent>(image);
            break;
        case VTK_UNSIGNED_LONG:
            numpyArray = vtkImageDataToNumpyTemplated<ULongEquivalent>(image);
            break;
        case VTK_ID_TYPE:
            numpyArray = vtkImageDataToNumpyTemplated<IdTypeEquivalent>(image);
            break;
        case VTK_FLOAT:
            numpyArray = vtkImageDataToNumpyTemplated<float>(image);
            break;
        case VTK_DOUBLE:
            numpyArray = vtkImageDataToNumpyTemplated<double>(image);
            break;
        default:
            throw std::runtime_error("Unsupported VTK scalar type");
    }

    return numpyArray;
}

// -------------------- getVtkImageDataSpacing --------------------
py::array getVtkImageDataSpacing(vtkSmartPointer<vtkImageData> image) {
    // Voxel size along x, y and z, returned as a 1D float64 NumPy array of length 3.
    // The values are copied (no base object is given), so the result does not alias the VTK image.
    double voxelSize[3];
    image->GetSpacing(voxelSize);
    return py::array_t<double>(static_cast<py::ssize_t>(3), voxelSize);
}

// -------------------- vtkPointSetToNumpy --------------------
py::array vtkPointSetToNumpy(vtkSmartPointer<vtkPointSet> pointSet) {
    // A point set (or a null one) may have no vtkPoints at all: report it instead of dereferencing null
    vtkPoints* points = pointSet ? pointSet->GetPoints() : nullptr;
    if (!points) {
        throw std::runtime_error("vtkPointSet has no points");
    }

    vtkDataArray* dataArray = points->GetData();
    if (!dataArray) {
        throw std::runtime_error("vtkPoints has no data array");
    }

    // Determine the number of points and components
    vtkIdType numPoints = dataArray->GetNumberOfTuples();
    int numComponents = dataArray->GetNumberOfComponents();

    // Get pointer to the raw data (tuples stored contiguously, components interleaved)
    void* dataPtr = dataArray->GetVoidPointer(0);

    // Determine the data type and corresponding format descriptor
    int vtkType = dataArray->GetDataType();
    std::string format;
    size_t itemSize;

    switch (vtkType) {
        case VTK_FLOAT:
            format = py::format_descriptor<float>::format();
            itemSize = sizeof(float);
            break;
        case VTK_DOUBLE:
            format = py::format_descriptor<double>::format();
            itemSize = sizeof(double);
            break;
        // Add more cases as needed
        default:
            throw std::runtime_error("Unsupported VTK data type");
    }

    // Define shape (numPoints, numComponents) and strides (in bytes)
    std::vector<py::ssize_t> shape = { numPoints, numComponents };
    std::vector<py::ssize_t> strides = { static_cast<py::ssize_t>(itemSize * numComponents), static_cast<py::ssize_t>(itemSize) };

    // Create buffer_info
    py::buffer_info bufferInfo(
        dataPtr,        // Pointer to buffer
        itemSize,       // Size of one scalar
        format,         // Python struct-style format descriptor
        2,              // Number of dimensions
        shape,          // Buffer dimensions
        strides         // Strides (in bytes) for each index
    );

    // Create numpy array from buffer_info. No base object is given, so pybind11 copies the data
    // and the result does not alias the VTK points memory.
    return py::array(bufferInfo);
}

// -------------------- vtkType --------------------
/// Compile-time mapping from a fixed-width C++ arithmetic type (int8_t ... uint64_t, float, double)
/// to the matching VTK scalar type id, e.g. vtkType<uint16_t>() == VTK_TYPE_UINT16.
/// Integer types map by their fixed size (8, 16, 32, 64 bits) and signedness.
/// Using an unsupported type is a compile-time error (VtkScalarTypeId<T> has no definition).
namespace {
template<typename T> struct VtkScalarTypeId;  // intentionally left undefined

template<> struct VtkScalarTypeId<int8_t>   { static constexpr int value = VTK_TYPE_INT8; };
template<> struct VtkScalarTypeId<uint8_t>  { static constexpr int value = VTK_TYPE_UINT8; };
template<> struct VtkScalarTypeId<int16_t>  { static constexpr int value = VTK_TYPE_INT16; };
template<> struct VtkScalarTypeId<uint16_t> { static constexpr int value = VTK_TYPE_UINT16; };
template<> struct VtkScalarTypeId<int32_t>  { static constexpr int value = VTK_TYPE_INT32; };
template<> struct VtkScalarTypeId<uint32_t> { static constexpr int value = VTK_TYPE_UINT32; };
template<> struct VtkScalarTypeId<int64_t>  { static constexpr int value = VTK_TYPE_INT64; };
template<> struct VtkScalarTypeId<uint64_t> { static constexpr int value = VTK_TYPE_UINT64; };
template<> struct VtkScalarTypeId<float>    { static constexpr int value = VTK_FLOAT; };
template<> struct VtkScalarTypeId<double>   { static constexpr int value = VTK_DOUBLE; };
} // anonymous namespace

template<typename T> constexpr int vtkType() {
    return VtkScalarTypeId<T>::value;
}

// -------------------- numpyToVTKImageDataTemplated --------------------
template<typename T>
vtkSmartPointer<vtkImageData> numpyToVTKImageDataTemplated(const pybind11::array_t<T>& array) {
    pybind11::buffer_info info = array.request();
    const int ndim = static_cast<int>(info.shape.size());

    if (ndim < 2 || ndim > 4) {
        throw std::runtime_error("Unsupported NumPy array shape. Expected 2D grayscale/color or 3D/4D volumetric image.");
    }

    // Size (in elements) of each VTK axis; axes the array does not have keep size 1
    py::ssize_t zSize = 1;
    py::ssize_t ySize = 1;
    py::ssize_t xSize = 1;
    py::ssize_t components = 1;

    // Stride of each VTK axis, in BYTES, taken as-is from NumPy. They are NOT divided by the item size:
    // NumPy strides may be negative (reversed views such as a[::-1]), zero (broadcast arrays),
    // or not a multiple of the item size. Axes the array does not have keep stride 0
    // (their size is 1 so the stride is never used).
    py::ssize_t zStride = 0;
    py::ssize_t yStride = 0;
    py::ssize_t xStride = 0;
    py::ssize_t componentStride = 0;

    // Determine values depending on ndim
    if (ndim == 2) {
        // 2D grayscale: (H, W)
        ySize = info.shape[0];
        xSize = info.shape[1];
        yStride = info.strides[0];
        xStride = info.strides[1];
    }
    else if (ndim == 3) {
        // Distinguish 2D RGB vs 3D grayscale
        if (info.shape[2] == 3 || info.shape[2] == 4) {
            // 2D RGB/RGBA: (H, W, C)
            // WARNING a 3D grayscale volume whose last axis (width) has only 3 or 4 voxels
            // will be considered as a 2D RGB/RGBA image
            ySize = info.shape[0];
            xSize = info.shape[1];
            components = info.shape[2];
            yStride = info.strides[0];
            xStride = info.strides[1];
            componentStride = info.strides[2];
        }
        else {
            // 3D grayscale: (D, H, W) → VTK (W, H, D)
            zSize = info.shape[0];
            ySize = info.shape[1];
            xSize = info.shape[2];
            zStride = info.strides[0];
            yStride = info.strides[1];
            xStride = info.strides[2];
        }
    }
    else {
        // ndim == 4
        // 3D with channels: (D, H, W, C) → VTK (W, H, D, C)
        zSize = info.shape[0];
        ySize = info.shape[1];
        xSize = info.shape[2];
        components = info.shape[3];
        zStride = info.strides[0];
        yStride = info.strides[1];
        xStride = info.strides[2];
        componentStride = info.strides[3];
    }

    // Create VTK image
    vtkSmartPointer<vtkImageData> image = vtkSmartPointer<vtkImageData>::New();
    image->SetDimensions(static_cast<int>(xSize), static_cast<int>(ySize), static_cast<int>(zSize));
    image->AllocateScalars(vtkType<T>(), static_cast<int>(components));
    if (image->GetScalarSize() != sizeof(T)) {
        throw std::runtime_error("Unable to convert from NumPy array type to VTK scalar. Size does not match.");
    }

    const py::ssize_t totalValues = xSize * ySize * zSize * components;
    if (totalValues == 0) {
        // Empty image: nothing to copy (and the buffers may legitimately be null)
        return image;
    }

    const char* src = static_cast<const char*>(info.ptr);
    T* dst = static_cast<T*>(image->GetScalarPointer());
    if (dst == nullptr) {
        throw std::runtime_error("Unable to allocate the VTK image scalars.");
    }

    // VTK memory order, fastest varying axis first: component, x, y, z.
    // The NumPy buffer can be copied in one block if it is laid out exactly the same way (C-contiguous).
    // The stride of an axis of length 1 is never used, so NumPy may report any value for it: ignore it.
    const py::ssize_t axisSizes[4] = { components, xSize, ySize, zSize };
    const py::ssize_t axisStrides[4] = { componentStride, xStride, yStride, zStride };
    bool canUseFastCopy = true;
    py::ssize_t expectedStride = static_cast<py::ssize_t>(sizeof(T));
    for (int axis = 0; axis < 4; ++axis) {
        if (axisSizes[axis] != 1 && axisStrides[axis] != expectedStride) {
            canUseFastCopy = false;
            break;
        }
        expectedStride *= axisSizes[axis];
    }

    if (canUseFastCopy) {
        // Fast copy for contiguous data
        std::memcpy(dst, src, static_cast<size_t>(totalValues) * sizeof(T));
    }
    else {
        // Element-by-element copy using signed byte offsets (valid for negative/zero strides)
        // and memcpy (valid for strides that are not aligned on sizeof(T)).
        // Loops are ordered so that dst is filled sequentially in VTK order (interleaved components).
        T* out = dst;
        for (py::ssize_t z = 0; z < zSize; ++z) {
            for (py::ssize_t y = 0; y < ySize; ++y) {
                for (py::ssize_t x = 0; x < xSize; ++x) {
                    const char* voxel = src + z * zStride + y * yStride + x * xStride;
                    for (py::ssize_t comp = 0; comp < components; ++comp) {
                        std::memcpy(out, voxel + comp * componentStride, sizeof(T));
                        ++out;
                    }
                }
            }
        }
    }

    return image;
}

// -------------------- numpyToVTKImageData --------------------
// non templated dispatcher
vtkSmartPointer<vtkImageData> numpyToVTKImageData(const pybind11::array& numpyArray) {
    const pybind11::dtype dtype = numpyArray.dtype();

    // dtype.equal() uses NumPy's dtype equality (==), whereas dtype.is() compares Python object
    // identity. NumPy can describe the same element type with distinct dtype objects (e.g. 'l' and 'q'
    // are both int64 on LP64 platforms, 'i' and 'l' are both int32 on Windows), which an identity
    // check would wrongly reject. Byte-swapped (non-native) dtypes still do not match and are rejected.
    if (dtype.equal(py::dtype::of<int8_t>())) {
        return numpyToVTKImageDataTemplated<int8_t>(numpyArray.cast<py::array_t<int8_t>>());
    }
    if (dtype.equal(py::dtype::of<uint8_t>())) {
        return numpyToVTKImageDataTemplated<uint8_t>(numpyArray.cast<py::array_t<uint8_t>>());
    }
    if (dtype.equal(py::dtype::of<int16_t>())) {
        return numpyToVTKImageDataTemplated<int16_t>(numpyArray.cast<py::array_t<int16_t>>());
    }
    if (dtype.equal(py::dtype::of<uint16_t>())) {
        return numpyToVTKImageDataTemplated<uint16_t>(numpyArray.cast<py::array_t<uint16_t>>());
    }
    if (dtype.equal(py::dtype::of<int32_t>())) {
        return numpyToVTKImageDataTemplated<int32_t>(numpyArray.cast<py::array_t<int32_t>>());
    }
    if (dtype.equal(py::dtype::of<uint32_t>())) {
        return numpyToVTKImageDataTemplated<uint32_t>(numpyArray.cast<py::array_t<uint32_t>>());
    }
    if (dtype.equal(py::dtype::of<int64_t>())) {
        return numpyToVTKImageDataTemplated<int64_t>(numpyArray.cast<py::array_t<int64_t>>());
    }
    if (dtype.equal(py::dtype::of<uint64_t>())) {
        return numpyToVTKImageDataTemplated<uint64_t>(numpyArray.cast<py::array_t<uint64_t>>());
    }
    if (dtype.equal(py::dtype::of<float>())) {
        return numpyToVTKImageDataTemplated<float>(numpyArray.cast<py::array_t<float>>());
    }
    if (dtype.equal(py::dtype::of<double>())) {
        return numpyToVTKImageDataTemplated<double>(numpyArray.cast<py::array_t<double>>());
    }

    throw std::runtime_error("Unsupported numpy array data type");
}

// -------------------- newImageComponentFromNumpy --------------------
// build a new image component from a numpy array (check for unique name)
ImageComponent* newImageComponentFromNumpy(const py::array& numpyArray, const std::string& name, py::object spacingObj) {
    vtkSmartPointer<vtkImageData> imgData = numpyToVTKImageData(numpyArray);

    QString uniqueName = Application::getUniqueComponentName(name.c_str());

    // Set spacing if provided. Any Python sequence of 3 numbers is accepted (tuple, list, NumPy array...).
    // A sequence of another length is ignored and the default unit spacing is kept.
    double spacing[3] = {1.0, 1.0, 1.0};
    if (!spacingObj.is_none()) {
        py::sequence spacingSeq = spacingObj.cast<py::sequence>();
        if (spacingSeq.size() == 3) {
            for (size_t axis = 0; axis < 3; ++axis) {
                spacing[axis] = spacingSeq[axis].cast<double>();
            }
        }
    }
    imgData->SetSpacing(spacing);

    ImageComponent* img = new ImageComponent(imgData, uniqueName);
    return img;
}

// -------------------- numpyToVtkPoints --------------------
vtkSmartPointer<vtkPoints> numpyToVtkPoints(py::array_t < double, py::array::c_style | py::array::forcecast > pointsArray) {
    // Each row is read as 3 consecutive doubles: any other shape would read outside the buffer
    if (pointsArray.ndim() != 2 || pointsArray.shape(1) != 3) {
        throw std::runtime_error("Points array must have shape (N,3)");
    }

    const vtkIdType numPoints = static_cast<vtkIdType>(pointsArray.shape(0));

    // Create and fill vtkPoints
    vtkSmartPointer<vtkPoints> points = vtkSmartPointer<vtkPoints>::New();
    points->SetDataTypeToDouble(); // Force double for simplicity
    points->SetNumberOfPoints(numPoints);

    // c_style | forcecast guarantees a C-contiguous double buffer
    const double* pointsPtr = pointsArray.data();

    for (vtkIdType i = 0; i < numPoints; ++i) {
        points->SetPoint(i, pointsPtr + 3 * i);
    }

    return points;
}

// -------------------- numpyToVtkPointSet --------------------
vtkSmartPointer<vtkPointSet> numpyToVtkPointSet(py::array_t < double, py::array::c_style | py::array::forcecast > pointsArray,
        py::array_t < vtkIdType, py::array::c_style | py::array::forcecast > polysArray) {

    //--1. Convert array to vtkPoints
    if (pointsArray.ndim() != 2 || pointsArray.shape(1) != 3) {
        throw std::runtime_error("Points array must have shape (N,3)");
    }

    vtkSmartPointer<vtkPoints> points = numpyToVtkPoints(pointsArray);
    const vtkIdType numPoints = points->GetNumberOfPoints();

    // No polys? => return a point cloud only (vtkPolyData)
    if (!polysArray || polysArray.size() == 0) {
        vtkSmartPointer<vtkPolyData> polydata = vtkSmartPointer<vtkPolyData>::New();
        polydata->SetPoints(points);

        // create a cell with all the points
        vtkNew<vtkCellArray> verts;
        verts->InsertNextCell(numPoints);
        for (vtkIdType i = 0; i < numPoints; ++i) {
            verts->InsertCellPoint(i);
        }
        polydata->SetVerts(verts);

        return polydata;
    }

    //--2. Convert poly to vtkUnstructuredGrid
    // Each row is [numVerts, id_0, ..., id_(numVerts-1), <ignored padding>]: all rows share the same
    // width, so rows with fewer points than the widest one must be padded by the caller.
    if (polysArray.ndim() != 2 || polysArray.shape(1) < 2) {
        throw std::runtime_error("Polys array must have shape (M, N+1) where first column is the point count for each poly.");
    }

    const py::ssize_t numCells = polysArray.shape(0);
    const py::ssize_t rowWidth = polysArray.shape(1);
    const py::ssize_t maxVertsPerRow = rowWidth - 1;

    // We will use vtkUnstructuredGrid if general cells (quad, hex, etc.)
    auto grid = vtkSmartPointer<vtkUnstructuredGrid>::New();
    grid->SetPoints(points);

    // c_style | forcecast guarantees a C-contiguous vtkIdType buffer
    const vtkIdType* polysPtr = polysArray.data();

    // Reused for every cell (InsertNextCell copies the ids)
    vtkSmartPointer<vtkIdList> idList = vtkSmartPointer<vtkIdList>::New();

    for (py::ssize_t i = 0; i < numCells; ++i) {
        const vtkIdType* cellData = polysPtr + i * rowWidth;

        const vtkIdType numVerts = cellData[0];
        const vtkIdType* pointIds = cellData + 1;

        int cellType;
        switch (numVerts) {
            case 1: // 1 point id → vertex
                cellType = VTK_VERTEX;
                break;
            case 2: // 2 point ids → line
                cellType = VTK_LINE;
                break;
            case 3: // 3 point ids → triangle
                cellType = VTK_TRIANGLE;
                break;
            case 4: // 4 points ids → quad (or VTK_TETRA! but not implemented yet)
                cellType = VTK_QUAD;
                break;
            // 5 = VTK_PYRAMID
            // 6 = VTK_WEDGE
            // 8 = VTK_HEXAHEDRON
            default:
                throw std::runtime_error("Unsupported cell with " + std::to_string(numVerts) + " points.");
        }

        // The declared point count must fit in the row, otherwise the ids would be read past its end
        if (numVerts > maxVertsPerRow) {
            throw std::runtime_error("Poly " + std::to_string(i) + " has " + std::to_string(numVerts)
                                     + " points but its row only holds " + std::to_string(maxVertsPerRow) + " point ids.");
        }

        idList->Reset();
        for (vtkIdType j = 0; j < numVerts; ++j) {
            const vtkIdType pointId = pointIds[j];
            // Out-of-range ids would make VTK access invalid points later on
            if (pointId < 0 || pointId >= numPoints) {
                throw std::runtime_error("Poly " + std::to_string(i) + " references point id " + std::to_string(pointId)
                                         + " which is out of range [0, " + std::to_string(numPoints) + ").");
            }
            idList->InsertNextId(pointId);
        }

        grid->InsertNextCell(cellType, idList);
    }

    return grid;
}

// -------------------- newMeshComponentFromNumpy --------------------
MeshComponent* newMeshComponentFromNumpy(const std::string& name, py::array_t < double, py::array::c_style | py::array::forcecast > points_array, py::array_t < vtkIdType, py::array::c_style | py::array::forcecast > polys_array) {
    // Build the geometry first: malformed arrays throw before any component name is requested
    vtkSmartPointer<vtkPointSet> geometry = numpyToVtkPointSet(points_array, polys_array);

    // Name not already used by another component of the application
    const QString componentName = Application::getUniqueComponentName(name.c_str());

    return new MeshComponent(geometry, componentName);
}

// -------------------- vtkDataArrayToNumpy --------------------
// Copy a contiguous VTK buffer (tuples stored one after the other, components interleaved)
// into a new (nTuples, nComps) NumPy array of element type T.
// No base object is given to pybind11, so the data is copied and the result does not alias VTK memory.
template <typename T>
static py::array vtkBufferToNumpyCopy(const void* rawPtr, py::ssize_t nTuples, py::ssize_t nComps) {
    return py::array_t<T>({nTuples, nComps}, static_cast<const T*>(rawPtr));
}

py::array vtkDataArrayToNumpy(vtkSmartPointer<vtkDataArray> array) {
    if (!array) {
        throw std::runtime_error("Invalid or null VTK data array");
    }

    // vtkIdType counts may exceed the int range: keep them in a 64-bit signed type
    const py::ssize_t nTuples = static_cast<py::ssize_t>(array->GetNumberOfTuples());
    const py::ssize_t nComps = static_cast<py::ssize_t>(array->GetNumberOfComponents());

    const void* rawPtr = array->GetVoidPointer(0);

    // Determine VTK type and copy into a NumPy array of the matching element type
    switch (array->GetDataType()) {
        case VTK_DOUBLE:
            return vtkBufferToNumpyCopy<double>(rawPtr, nTuples, nComps);
        case VTK_FLOAT:
            return vtkBufferToNumpyCopy<float>(rawPtr, nTuples, nComps);
        case VTK_INT:
            return vtkBufferToNumpyCopy<int>(rawPtr, nTuples, nComps);
        case VTK_UNSIGNED_INT:
            return vtkBufferToNumpyCopy<unsigned int>(rawPtr, nTuples, nComps);
        case VTK_SHORT:
            return vtkBufferToNumpyCopy<short>(rawPtr, nTuples, nComps);
        case VTK_UNSIGNED_SHORT:
            return vtkBufferToNumpyCopy<unsigned short>(rawPtr, nTuples, nComps);
        case VTK_SIGNED_CHAR:
            return vtkBufferToNumpyCopy<signed char>(rawPtr, nTuples, nComps);
        case VTK_UNSIGNED_CHAR:
            return vtkBufferToNumpyCopy<unsigned char>(rawPtr, nTuples, nComps);
        case VTK_LONG_LONG:
            return vtkBufferToNumpyCopy<long long>(rawPtr, nTuples, nComps);
        case VTK_UNSIGNED_LONG_LONG:
            return vtkBufferToNumpyCopy<unsigned long long>(rawPtr, nTuples, nComps);
        case VTK_ID_TYPE:
            return vtkBufferToNumpyCopy<vtkIdType>(rawPtr, nTuples, nComps);
        default:
            throw std::runtime_error("Unsupported VTK data type");
    }
}

// -------------------- numpyToVtkDataArray --------------------
// Deep-copy a 1D (N,) or 2D (N, C) NumPy buffer of element type T into a new VtkArrayType
// (N tuples of C components). Strides are honoured in bytes, so non-contiguous views
// (slices, transposes, reversed arrays) are copied correctly. memcpy avoids unaligned reads.
template <typename VtkArrayType, typename T>
static vtkSmartPointer<vtkDataArray> copyNumpyBufferToVtkArray(const py::buffer_info& info, vtkIdType nTuples, int nComps) {
    vtkSmartPointer<VtkArrayType> vtkArray = vtkSmartPointer<VtkArrayType>::New();
    vtkArray->SetNumberOfComponents(nComps);
    vtkArray->SetNumberOfTuples(nTuples);

    const char* src = static_cast<const char*>(info.ptr);
    const py::ssize_t tupleStride = info.strides[0];
    const py::ssize_t componentStride = (info.ndim > 1) ? info.strides[1] : 0;

    T* dst = vtkArray->GetPointer(0);
    for (vtkIdType t = 0; t < nTuples; ++t) {
        for (int c = 0; c < nComps; ++c) {
            std::memcpy(dst, src + t * tupleStride + c * componentStride, sizeof(T));
            ++dst;
        }
    }
    return vtkArray;
}

vtkSmartPointer<vtkDataArray> numpyToVtkDataArray(py::array array) {
    py::buffer_info info = array.request();

    if (info.ndim < 1 || info.ndim > 2 || (info.ndim == 2 && info.shape[1] < 1)) {
        throw std::runtime_error("Unsupported NumPy array shape. Expected a 1D (N,) or 2D (N, C) array with C >= 1.");
    }

    const vtkIdType nTuples = static_cast<vtkIdType>(info.shape[0]);
    const int nComps = (info.ndim > 1) ? static_cast<int>(info.shape[1]) : 1;

    // The data is COPIED into VTK-owned memory: the NumPy buffer is only guaranteed to be alive
    // while this function runs (the caller may pass a temporary), so VTK must not borrow it.
    if (info.format == py::format_descriptor<double>::format()) {
        return copyNumpyBufferToVtkArray<vtkDoubleArray, double>(info, nTuples, nComps);
    }
    else if (info.format == py::format_descriptor<float>::format()) {
        return copyNumpyBufferToVtkArray<vtkFloatArray, float>(info, nTuples, nComps);
    }
    else if (info.format == py::format_descriptor<int>::format()) {
        return copyNumpyBufferToVtkArray<vtkIntArray, int>(info, nTuples, nComps);
    }
    else {
        throw std::runtime_error("Unsupported NumPy data type");
    }
}

// -------------------- numpyToVtkTransform --------------------
vtkSmartPointer<vtkTransform> numpyToVtkTransform(py::array array) {
    py::buffer_info info = array.request();

    // Check input shape (must be 4x4) and type (float64)
    if (info.ndim != 2 || info.shape[0] != 4 || info.shape[1] != 4 || info.format != py::format_descriptor<double>::format()) {
        throw std::runtime_error("numpyToVtkTransform: array must be a 4x4 array of float64.");
    }

    // Address elements through the NumPy byte strides: the array is not necessarily C-contiguous
    // (e.g. m.T is Fortran-ordered, m[::2, ::2] or a view into a larger array has custom strides).
    const char* base = static_cast<const char*>(info.ptr);
    const py::ssize_t rowStride = info.strides[0];
    const py::ssize_t colStride = info.strides[1];

    vtkSmartPointer<vtkMatrix4x4> mat = vtkSmartPointer<vtkMatrix4x4>::New();

    // NumPy element [r, c] → VTK element (r, c): both are indexed (row, column), no transpose needed
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            double value;
            std::memcpy(&value, base + r * rowStride + c * colStride, sizeof(double));
            mat->SetElement(r, c, value);
        }
    }

    // Create a vtkTransform from the matrix
    vtkSmartPointer<vtkTransform> tr = vtkSmartPointer<vtkTransform>::New();
    tr->SetMatrix(mat);

    return tr;
}

} // namespace camitk
