Skip to content

Instantly share code, notes, and snippets.

@robertknight
Last active January 24, 2024 07:52
Show Gist options
  • Save robertknight/ad54cc02a79d0824e6e576401d3d433e to your computer and use it in GitHub Desktop.
Save robertknight/ad54cc02a79d0824e6e576401d3d433e to your computer and use it in GitHub Desktop.

Revisions

  1. robertknight revised this gist Jan 24, 2024. 1 changed file with 1 addition and 1 deletion.
    2 changes: 1 addition & 1 deletion rten_ndarray_conv.rs
    Original file line number Diff line number Diff line change
    @@ -15,7 +15,7 @@ where
    {
    view.to_slice().map(|slice| {
    let shape: [usize; N] = view.shape().try_into().unwrap();
    NdTensorView::from_slice(slice, shape, None).expect("incorrect slice length")
    NdTensorView::from_data(shape, slice)
    })
    }

  2. robertknight revised this gist Jan 6, 2024. 1 changed file with 33 additions and 15 deletions.
    48 changes: 33 additions & 15 deletions rten_ndarray_conv.rs
    Original file line number Diff line number Diff line change
    @@ -1,32 +1,46 @@
    use ndarray::{
    Array2, ArrayView, Dim, Dimension, Ix, StrideShape,
    };
    use ndarray::{Array, Array2, ArrayView, Dim, Dimension, Ix, StrideShape};

    use rten_tensor::prelude::*;
    use rten_tensor::NdTensorView;
    use rten_tensor::{NdTensor, NdTensorView};

    /// Convert an N-dimensional ndarray view to an [NdTensorView].
    ///
    /// This requires that the ndarray view is contiguous.
    fn to_rten_view<'a, T, const N: usize>(
    /// Returns `None` if the view is not in the standard layout (see
    /// [ArrayView::is_standard_layout]).
    fn as_ndtensor_view<'a, T, const N: usize>(
    view: ArrayView<'a, T, Dim<[Ix; N]>>,
    ) -> NdTensorView<'a, T, N>
    ) -> Option<NdTensorView<'a, T, N>>
    where
    Dim<[Ix; N]>: Dimension,
    {
    let shape: [usize; N] = view.shape().try_into().unwrap();
    NdTensorView::from_slice(view.to_slice().unwrap(), shape, None).unwrap()
    view.to_slice().map(|slice| {
    let shape: [usize; N] = view.shape().try_into().unwrap();
    NdTensorView::from_slice(slice, shape, None).expect("incorrect slice length")
    })
    }

    /// Convert an N-dimensional [NdTensorView] into an ndarray view.
    fn to_ndarray_view<'a, T, const N: usize>(
    ///
    /// Returns `None` if the view is not in "standard layout" (see
    /// [ArrayView::is_standard_layout]).
    fn as_array_view<'a, T, const N: usize>(
    view: NdTensorView<'a, T, N>,
    ) -> ArrayView<'a, T, Dim<[Ix; N]>>
    ) -> Option<ArrayView<'a, T, Dim<[Ix; N]>>>
    where
    Dim<[Ix; N]>: Dimension,
    [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>,
    {
    view.data()
    .map(|data| ArrayView::from_shape(view.shape(), data).unwrap())
    }

    /// Convert an N-dimensional [NdTensor] into an ndarray.
    fn into_array<T, const N: usize>(tensor: NdTensor<T, N>) -> Array<T, Dim<[Ix; N]>>
    where
    Dim<[Ix; N]>: Dimension,
    [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>,
    {
    ArrayView::from_shape(view.shape(), view.data().unwrap()).unwrap()
    Array::from_shape_vec(tensor.shape(), tensor.into_data()).unwrap()
    }

    fn main() {
    @@ -37,14 +51,18 @@ fn main() {
    array[[1, 0]] = 3.;
    array[[1, 1]] = 4.;

    let view = to_rten_view(array.view());
    let view = as_ndtensor_view(array.view()).expect("non-contiguous view");

    for (idx, el) in view.indices().zip(view.iter()) {
    println!("index {:?} element {}", idx, el);
    }

    // NdTensor => ArrayView
    let permuted_owned = view.permuted([1, 0]).to_tensor();
    let ndarray_view = to_ndarray_view(permuted_owned.view());
    println!("new array {:?}", ndarray_view);
    let ndarray_view = as_array_view(permuted_owned.view()).expect("non-contiguous view");
    println!("ndarray_view {:?}", ndarray_view);

    // Ndtensor => Array
    let ndarray = into_array(permuted_owned);
    println!("ndarray {:?}", ndarray);
    }
  3. robertknight created this gist Jan 6, 2024.
    50 changes: 50 additions & 0 deletions rten_ndarray_conv.rs
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,50 @@
    use ndarray::{
    Array2, ArrayView, Dim, Dimension, Ix, StrideShape,
    };

    use rten_tensor::prelude::*;
    use rten_tensor::NdTensorView;

    /// Convert an N-dimensional ndarray view to an [NdTensorView].
    ///
    /// This requires that the ndarray view is contiguous.
    fn to_rten_view<'a, T, const N: usize>(
    view: ArrayView<'a, T, Dim<[Ix; N]>>,
    ) -> NdTensorView<'a, T, N>
    where
    Dim<[Ix; N]>: Dimension,
    {
    let shape: [usize; N] = view.shape().try_into().unwrap();
    NdTensorView::from_slice(view.to_slice().unwrap(), shape, None).unwrap()
    }

    /// Convert an N-dimensional [NdTensorView] into an ndarray view.
    fn to_ndarray_view<'a, T, const N: usize>(
    view: NdTensorView<'a, T, N>,
    ) -> ArrayView<'a, T, Dim<[Ix; N]>>
    where
    Dim<[Ix; N]>: Dimension,
    [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>,
    {
    ArrayView::from_shape(view.shape(), view.data().unwrap()).unwrap()
    }

    fn main() {
    // Owned ndarray => NdTensorView
    let mut array: Array2<f32> = Array2::zeros([2, 2]);
    array[[0, 0]] = 1.;
    array[[0, 1]] = 2.;
    array[[1, 0]] = 3.;
    array[[1, 1]] = 4.;

    let view = to_rten_view(array.view());

    for (idx, el) in view.indices().zip(view.iter()) {
    println!("index {:?} element {}", idx, el);
    }

    // NdTensor => ArrayView
    let permuted_owned = view.permuted([1, 0]).to_tensor();
    let ndarray_view = to_ndarray_view(permuted_owned.view());
    println!("new array {:?}", ndarray_view);
    }