safejax.utils
cast_objax_variables(params)
Cast the `jnp.DeviceArray` values to their corresponding `objax.variable` types.
Note
This function may return the same params if no `objax.variable` types are found in the keys.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | JaxDeviceArrayDict | A `JaxDeviceArrayDict` with the params to cast. | required |
Raises:
Type | Description |
---|---|
ValueError | If the params were not serialized from a `VarCollection` object. |
Returns:
Type | Description |
---|---|
Union[JaxDeviceArrayDict, ObjaxDict] | A dictionary with the params cast to their corresponding `objax.variable` types, or the unchanged params if no `objax.variable` types are found in the keys. |
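To illustrate the casting described above, here is a minimal, hypothetical sketch in plain Python. It assumes that the variable type name is appended to each key after a `:` (for example `dense.w:TrainVar`); that key format, the helper `cast_variables_sketch`, and the stand-in `TrainVar`/`StateVar` classes are illustrative assumptions, not the actual `safejax` or `objax` API, and they let the example run without `objax` installed.

```python
from __future__ import annotations

from typing import Any, Dict


# Hypothetical stand-ins for objax.variable types (NOT the real objax classes).
class _Var:
    def __init__(self, value: Any) -> None:
        self.value = value


class TrainVar(_Var):
    pass


class StateVar(_Var):
    pass


# Assumption: the type name follows a ":" at the end of each serialized key.
VAR_TYPES: Dict[str, type] = {"TrainVar": TrainVar, "StateVar": StateVar}


def cast_variables_sketch(params: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap each value in the variable type encoded in its key."""
    if not any(":" in key for key in params):
        # No variable types in the keys: return the params unchanged.
        return params
    casted: Dict[str, Any] = {}
    for key, value in params.items():
        if ":" not in key:
            raise ValueError(
                f"Key {key!r} has no variable type; the params were not "
                "serialized from a VarCollection."
            )
        name, var_type = key.rsplit(":", 1)
        if var_type not in VAR_TYPES:
            raise ValueError(f"Unknown variable type {var_type!r} in key {key!r}.")
        # Strip the type suffix from the key and wrap the raw value.
        casted[name] = VAR_TYPES[var_type](value)
    return casted


if __name__ == "__main__":
    casted = cast_variables_sketch({"dense.w:TrainVar": [1.0, 2.0]})
    print(type(casted["dense.w"]).__name__, casted["dense.w"].value)
    # prints: TrainVar [1.0, 2.0]
```

Returning the input unchanged when no key carries a type mirrors the Note above; raising `ValueError` when only some keys carry a type is this sketch's own choice for detecting params that did not come from a `VarCollection`.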
flatten_dict(params, key_prefix=None, include_objax_variables=False)
Flatten a `Dict`, `FrozenDict`, or `VarCollection`. For more detailed information on the supported input types, check `safejax.typing.ParamsDictLike`.
Note
This function recurses into all the nested dictionaries and flattens the keys by joining them with the `.` character, so that they can later be un-nested using `.` as a separator.
Reference at https://gist.github.com/Narsil/d5b0d747e5c8c299eb6d82709e480e3d
Parameters:
Name | Type | Description | Default |
---|---|---|---|
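params | ParamsDictLike | A `Dict`, `FrozenDict`, or `VarCollection` with the params to flatten. | required |
key_prefix | Union[str, None] | A prefix to prepend to the keys of the flattened dictionary. | None |
include_objax_variables | bool | A boolean indicating whether to include the `objax.variable` types in the keys of the flattened dictionary. | False |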
Returns:
Type | Description |
---|---|
Union[NumpyArrayDict, JaxDeviceArrayDict] | A flat dictionary whose keys are the nested keys joined with the `.` character. |
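The recursive key-joining described in the Note can be sketched in a few lines of plain Python. This is a simplified illustration, not the `safejax` implementation: the name `flatten_dict_sketch` is hypothetical, it accepts any `Mapping` of nested dictionaries, and it omits the `VarCollection` and `include_objax_variables` handling.

```python
from __future__ import annotations

from collections.abc import Mapping
from typing import Any, Dict, Optional


def flatten_dict_sketch(
    params: Mapping[str, Any], key_prefix: Optional[str] = None
) -> Dict[str, Any]:
    """Recursively flatten nested mappings, joining keys with '.'."""
    flattened: Dict[str, Any] = {}
    for key, value in params.items():
        # Prepend the accumulated prefix, if any, using "." as the separator.
        full_key = f"{key_prefix}.{key}" if key_prefix else key
        if isinstance(value, Mapping):
            # Nested dictionary: recurse with the joined key as the new prefix.
            flattened.update(flatten_dict_sketch(value, key_prefix=full_key))
        else:
            # Leaf value (e.g. an array): store it under the joined key.
            flattened[full_key] = value
    return flattened


if __name__ == "__main__":
    print(flatten_dict_sketch({"dense": {"kernel": 1, "bias": 0}}))
    # prints: {'dense.kernel': 1, 'dense.bias': 0}
```

One consequence of using `.` as the separator is that a key which already contains a `.` cannot be told apart from a nested key when the dictionary is later unflattened.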
unflatten_dict(params)
Unflatten a `Dict` whose keys should be expanded using the `.` character as a separator.
Note
If the params were serialized from a `VarCollection` object, then the `objax.variable` types are included in the keys. Since this function just unflattens the dictionary without any `objax.variable` casting, the type information in those keys is ignored and the entries are unflattened like any other key. When deserializing `objax` models, use `safejax.objax.deserialize` instead, or pass the following params to `safejax.deserialize`: `requires_unflattening=False` and `to_var_collection=True`.
Reference at https://stackoverflow.com/a/63545677.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
params | Union[NumpyArrayDict, JaxDeviceArrayDict] | A flattened dictionary whose keys should be expanded using the `.` character as a separator. | required |
Returns:
Type | Description |
---|---|
Dict[str, Any] | An unflattened (nested) `Dict`, with the keys expanded using the `.` character. |
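A minimal sketch of the unflattening step splits each key on `.` and walks (or creates) one nested dictionary per intermediate part. The function name `unflatten_dict_sketch` is hypothetical and this is not the `safejax` source; like the function documented here, it performs no `objax.variable` casting.

```python
from __future__ import annotations

from typing import Any, Dict


def unflatten_dict_sketch(params: Dict[str, Any]) -> Dict[str, Any]:
    """Expand '.'-separated keys back into nested dictionaries."""
    unflattened: Dict[str, Any] = {}
    for key, value in params.items():
        *parents, leaf = key.split(".")
        node = unflattened
        for subkey in parents:
            # Reuse the nested dict for this part of the key, or create it.
            node = node.setdefault(subkey, {})
        # The last part of the key holds the actual value.
        node[leaf] = value
    return unflattened


if __name__ == "__main__":
    print(unflatten_dict_sketch({"dense.kernel": 1, "dense.bias": 0}))
    # prints: {'dense': {'kernel': 1, 'bias': 0}}
```

Because `dict.setdefault` returns the existing nested dictionary when the key is already present, keys that share a prefix (such as `params.dense.kernel` and `params.dense.bias`) end up in the same sub-dictionary.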
Created: 2023-01-19