86 lines
2.9 KiB
Python
86 lines
2.9 KiB
Python
import numpy as np
|
|
from typing import Any, Union, List, Tuple
|
|
|
|
class MultiDimArray:
    """
    A class to represent and manipulate multi-dimensional arrays.

    Attributes
    ----------
    mdary : numpy.ndarray
        A multi-dimensional array containing a copy of the input data.
    shape : tuple
        The shape of the input multi-dimensional array.

    Methods
    -------
    flatten(output_type="list") -> Union[List, Tuple, np.ndarray]:
        Returns the flattened version of the multi-dimensional array as a
        list, tuple, or Numpy array.

    foldout(vector, output_type="list") -> Union[List, Tuple, np.ndarray]:
        Reshapes a 1D vector back into the original shape of the
        multi-dimensional array, and returns it as a list, tuple, or Numpy
        array.
    """

    def __init__(self, mdary: Union[List, Tuple, np.ndarray]):
        """
        Store the input as a NumPy array and remember its shape.

        Parameters
        ----------
        mdary : list, tuple or numpy.ndarray
            The (possibly nested) input data.
        """
        self.mdary = np.array(mdary)
        self.shape = self.mdary.shape

    @staticmethod
    def _nested_tuple(nested: Any, depth: int) -> Any:
        """Turn the outer ``depth`` levels of nested lists into nested tuples."""
        if depth == 0:
            # Below the array's own dimensions: an element, returned as-is.
            return nested
        return tuple(
            MultiDimArray._nested_tuple(item, depth - 1) for item in nested
        )

    @classmethod
    def _convert(
        cls, array: np.ndarray, output_type: str
    ) -> Union[List, Tuple, np.ndarray]:
        """Return ``array`` in the container type named by ``output_type``."""
        if output_type == "list":
            return array.tolist()
        if output_type == "tuple":
            # Recurse exactly ``array.ndim`` levels so the result is a proper
            # nested tuple for any dimensionality (0-d, 1-D, 2-D, 3-D, ...),
            # with the same native Python elements that the list form has.
            return cls._nested_tuple(array.tolist(), array.ndim)
        if output_type == "numpy":
            return array
        raise ValueError("Invalid output_type. Choose 'list', 'tuple', or 'numpy'")

    def flatten(self, output_type: str = "list") -> Union[List, Tuple, np.ndarray]:
        """
        Flatten the multi-dimensional array.

        Parameters
        ----------
        output_type : str, optional
            The output type of the flattened array, either 'list', 'tuple',
            or 'numpy' (default is 'list').

        Returns
        -------
        Union[List, Tuple, np.ndarray]
            The flattened version of the multi-dimensional array in the
            specified output type. The 'list' and 'tuple' forms contain
            native Python scalars.

        Raises
        ------
        ValueError
            If ``output_type`` is not 'list', 'tuple', or 'numpy'.
        """
        return self._convert(self.mdary.flatten(), output_type)

    def foldout(
        self,
        vector: Union[List, Tuple, np.ndarray],
        output_type: str = "list",
    ) -> Union[List, Tuple, np.ndarray]:
        """
        Reshape a 1D vector back into the shape of the stored array.

        Parameters
        ----------
        vector : list, tuple or numpy.ndarray
            A 1D sequence with as many elements as the stored array.
        output_type : str, optional
            The output type of the reshaped array, either 'list', 'tuple',
            or 'numpy' (default is 'list').

        Returns
        -------
        Union[List, Tuple, np.ndarray]
            ``vector`` reshaped to ``self.shape``: a nested list, a nested
            tuple (nested to the same depth as the array has dimensions),
            or a Numpy array.

        Raises
        ------
        ValueError
            If ``vector`` does not have the same length as the flattened
            array, or if ``output_type`` is not 'list', 'tuple', or 'numpy'.
        """
        if len(vector) != self.mdary.size:
            raise ValueError(
                "The input vector must have the same length as the flattened "
                "form of the multi-dimensional array"
            )

        reshaped = np.reshape(vector, self.shape)
        return self._convert(reshaped, output_type)
|
|
|
|
if __name__ == "__main__":
    # Example usage: flatten a 3x2 array, then fold the flat vector back
    # into the original shape.
    # A 1-D input such as MultiDimArray([1, 2, 3, 4, 5, 6]) can be used the
    # same way.
    mda = MultiDimArray([[1, 2], [3, 4], [5, 6]])
    print(f"Input array: {mda.mdary.tolist()}")

    flat = mda.flatten(output_type="list")
    print(f"Flattened array: {flat}")

    # At this point the flat array is [1, 2, 3, 4, 5, 6].
    folded = mda.foldout(flat, output_type="list")
    print(f"Folded back array: {folded}")

    # The folded back array should be numerically identical to the original
    # mdary: [[1, 2], [3, 4], [5, 6]]
|