alice/code/mda.py

86 lines
2.9 KiB
Python

import numpy as np
from typing import Any, Union, List, Tuple
class MultiDimArray:
    """
    A class to represent and manipulate multi-dimensional arrays.

    Attributes
    ----------
    mdary : numpy.ndarray
        A multi-dimensional array built from the input data via ``np.array``.
    shape : tuple
        The shape of ``mdary``, captured at construction time.

    Methods
    -------
    flatten(output_type="list") -> Union[List, Tuple, np.ndarray]:
        Returns the flattened version of the multi-dimensional array as a list,
        tuple, or NumPy array.
    foldout(vector, output_type="list") -> Union[List, Tuple, np.ndarray]:
        Reshapes a 1D vector back into the original shape of the
        multi-dimensional array, and returns it as a nested list, nested
        tuple, or NumPy array.
    """

    def __init__(self, mdary: Union[List, Tuple, np.ndarray]):
        self.mdary = np.array(mdary)
        self.shape = self.mdary.shape

    @staticmethod
    def _to_nested_tuple(obj: Any) -> Any:
        """Recursively turn nested lists (as from ``ndarray.tolist()``) into nested tuples."""
        if isinstance(obj, list):
            return tuple(MultiDimArray._to_nested_tuple(item) for item in obj)
        return obj

    @staticmethod
    def _convert(array: np.ndarray, output_type: str) -> Union[List, Tuple, np.ndarray]:
        """
        Convert ``array`` into the requested output container.

        Parameters
        ----------
        array : numpy.ndarray
            The array to convert.
        output_type : str
            One of 'list', 'tuple', or 'numpy'.

        Returns
        -------
        Union[List, Tuple, np.ndarray]
            A (nested) list or tuple of Python scalars, or ``array`` itself.

        Raises
        ------
        ValueError
            If ``output_type`` is not one of the supported values.
        """
        if output_type == "list":
            return array.tolist()
        if output_type == "tuple":
            # tolist() yields Python scalars at every depth; converting the
            # nested lists recursively handles arrays of any rank (1-D, 2-D, 3-D, ...).
            return MultiDimArray._to_nested_tuple(array.tolist())
        if output_type == "numpy":
            return array
        raise ValueError("Invalid output_type. Choose 'list', 'tuple', or 'numpy'")

    def flatten(self, output_type: str = "list") -> Union[List, Tuple, np.ndarray]:
        """
        Flatten the multi-dimensional array.

        Parameters
        ----------
        output_type : str, optional
            The output type of the flattened array, either 'list', 'tuple',
            or 'numpy' (default is 'list').

        Returns
        -------
        Union[List, Tuple, np.ndarray]
            The flattened version of the multi-dimensional array in the
            specified output type. The 'numpy' result is a copy, so modifying
            it does not affect ``mdary``.

        Raises
        ------
        ValueError
            If ``output_type`` is not 'list', 'tuple', or 'numpy'.
        """
        # ndarray.flatten() always returns a copy (unlike ravel()).
        return self._convert(self.mdary.flatten(), output_type)

    def foldout(self, vector: Union[List, Tuple, np.ndarray], output_type: str = "list") -> Union[List, Tuple, np.ndarray]:
        """
        Reshape a flat vector back into the original shape of ``mdary``.

        Parameters
        ----------
        vector : Union[List, Tuple, np.ndarray]
            A 1D sequence whose length equals ``mdary.size``.
        output_type : str, optional
            The output type of the folded array, either 'list', 'tuple',
            or 'numpy' (default is 'list').

        Returns
        -------
        Union[List, Tuple, np.ndarray]
            ``vector`` reshaped to ``self.shape`` in the specified output type.
            For 'numpy' with an ndarray input, the result may be a view that
            shares memory with ``vector``.

        Raises
        ------
        ValueError
            If ``len(vector)`` differs from ``mdary.size``, or if
            ``output_type`` is not 'list', 'tuple', or 'numpy'.
        """
        if len(vector) != self.mdary.size:
            raise ValueError("The input vector must have the same length as the flattened form of the multi-dimensional array")
        reshaped_array = np.reshape(vector, self.shape)
        return self._convert(reshaped_array, output_type)
if __name__ == "__main__":
    # Example usage: flatten a 3x2 array, then fold the flat list back into
    # the original shape.
    mda = MultiDimArray([[1, 2], [3, 4], [5, 6]])
    print(f"Input array: {mda.mdary.tolist()}")

    flat = mda.flatten(output_type="list")
    print(f"Flattened array: {flat}")  # [1, 2, 3, 4, 5, 6]

    folded = mda.foldout(flat, output_type="list")
    print(f"Folded back array: {folded}")
    # The folded-back array is numerically identical to the original mdary:
    # [[1, 2], [3, 4], [5, 6]]