Last active
January 24, 2024 07:52
-
-
Save robertknight/ad54cc02a79d0824e6e576401d3d433e to your computer and use it in GitHub Desktop.
Revisions
-
robertknight revised this gist
Jan 24, 2024 . 1 changed file with 1 addition and 1 deletion.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -15,7 +15,7 @@ where { view.to_slice().map(|slice| { let shape: [usize; N] = view.shape().try_into().unwrap(); NdTensorView::from_data(shape, slice) }) } -
robertknight revised this gist
Jan 6, 2024 . 1 changed file with 33 additions and 15 deletions.There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -1,32 +1,46 @@ use ndarray::{Array, Array2, ArrayView, Dim, Dimension, Ix, StrideShape}; use rten_tensor::prelude::*; use rten_tensor::{NdTensor, NdTensorView}; /// Convert an N-dimensional ndarray view to an [NdTensorView]. /// /// Returns `None` if the view is not in the standard layout (see /// [ArrayView::is_standard_layout]). fn as_ndtensor_view<'a, T, const N: usize>( view: ArrayView<'a, T, Dim<[Ix; N]>>, ) -> Option<NdTensorView<'a, T, N>> where Dim<[Ix; N]>: Dimension, { view.to_slice().map(|slice| { let shape: [usize; N] = view.shape().try_into().unwrap(); NdTensorView::from_slice(slice, shape, None).expect("incorrect slice length") }) } /// Convert an N-dimensional [NdTensorView] into an ndarray view. /// /// Returns `None` if the view is not in "standard layout" (see /// [ArrayView::is_standard_layout]). fn as_array_view<'a, T, const N: usize>( view: NdTensorView<'a, T, N>, ) -> Option<ArrayView<'a, T, Dim<[Ix; N]>>> where Dim<[Ix; N]>: Dimension, [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>, { view.data() .map(|data| ArrayView::from_shape(view.shape(), data).unwrap()) } /// Convert an N-dimensional [NdTensor] into an ndarray. fn into_array<T, const N: usize>(tensor: NdTensor<T, N>) -> Array<T, Dim<[Ix; N]>> where Dim<[Ix; N]>: Dimension, [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>, { Array::from_shape_vec(tensor.shape(), tensor.into_data()).unwrap() } fn main() { @@ -37,14 +51,18 @@ fn main() { array[[1, 0]] = 3.; array[[1, 1]] = 4.; let view = as_ndtensor_view(array.view()).expect("non-contiguous view"); for (idx, el) in view.indices().zip(view.iter()) { println!("index {:?} element {}", idx, el); } // NdTensor => ArrayView let permuted_owned = view.permuted([1, 0]).to_tensor(); let ndarray_view = as_array_view(permuted_owned.view()).expect("non-contiguous view"); println!("ndarray_view {:?}", ndarray_view); // Ndtensor => Array let ndarray = into_array(permuted_owned); println!("ndarray {:?}", ndarray); } -
robertknight created this gist
Jan 6, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,50 @@ use ndarray::{ Array2, ArrayView, Dim, Dimension, Ix, StrideShape, }; use rten_tensor::prelude::*; use rten_tensor::NdTensorView; /// Convert an N-dimensional ndarray view to an [NdTensorView]. /// /// This requires that the ndarray view is contiguous. fn to_rten_view<'a, T, const N: usize>( view: ArrayView<'a, T, Dim<[Ix; N]>>, ) -> NdTensorView<'a, T, N> where Dim<[Ix; N]>: Dimension, { let shape: [usize; N] = view.shape().try_into().unwrap(); NdTensorView::from_slice(view.to_slice().unwrap(), shape, None).unwrap() } /// Convert an N-dimensional [NdTensorView] into an ndarray view. fn to_ndarray_view<'a, T, const N: usize>( view: NdTensorView<'a, T, N>, ) -> ArrayView<'a, T, Dim<[Ix; N]>> where Dim<[Ix; N]>: Dimension, [usize; N]: Into<StrideShape<Dim<[Ix; N]>>>, { ArrayView::from_shape(view.shape(), view.data().unwrap()).unwrap() } fn main() { // Owned ndarray => NdTensorView let mut array: Array2<f32> = Array2::zeros([2, 2]); array[[0, 0]] = 1.; array[[0, 1]] = 2.; array[[1, 0]] = 3.; array[[1, 1]] = 4.; let view = to_rten_view(array.view()); for (idx, el) in view.indices().zip(view.iter()) { println!("index {:?} element {}", idx, el); } // NdTensor => ArrayView let permuted_owned = view.permuted([1, 0]).to_tensor(); let ndarray_view = to_ndarray_view(permuted_owned.view()); println!("new array {:?}", ndarray_view); }