safejax.core.load
deserialize(path_or_buf, fs=None, freeze_dict=False, requires_unflattening=True, to_var_collection=False)
Deserialize JAX, Flax, Haiku, or Objax model params from either a `bytes` object or a file path, stored using `safetensors.flax.save_file` or directly saved using `safejax.save.serialize` with the `filename` parameter.
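The on-disk format can be made concrete with a minimal stdlib sketch that reads only the header of a safetensors payload, from either a `bytes` object or a file path. It assumes the publicly documented safetensors layout: an 8-byte little-endian header length, a JSON header, then raw tensor bytes, with optional string metadata under the `__metadata__` key. `read_safetensors_header` is a hypothetical helper, not part of safejax, and it does not decode any tensor data.

```python
import json
import os
import struct
from typing import Any, Dict, Tuple, Union


def read_safetensors_header(
    path_or_buf: Union[str, os.PathLike, bytes],
) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, str]]:
    """Hypothetical helper: return (tensor_index, metadata) from a safetensors header.

    Assumes the safetensors layout: <u64 little-endian N><N bytes of JSON><tensor data>.
    Tensor data is never read or decoded.
    """
    if isinstance(path_or_buf, bytes):
        # In-memory payload: slice the length prefix and the JSON header.
        (header_len,) = struct.unpack("<Q", path_or_buf[:8])
        header_bytes = path_or_buf[8 : 8 + header_len]
    else:
        # File path: read only the first 8 + N bytes of the file.
        with open(path_or_buf, "rb") as f:
            (header_len,) = struct.unpack("<Q", f.read(8))
            header_bytes = f.read(header_len)
    header = json.loads(header_bytes)
    # The optional "__metadata__" entry is a free-form str -> str mapping;
    # every other entry describes one tensor (dtype, shape, data_offsets).
    metadata = header.pop("__metadata__", {})
    return header, metadata
```

Passing a path such as `"model.safetensors"` instead of bytes returns the same pair while reading only the header portion of the file.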
Note
The default behavior of this function is to restore a `Dict[str, jnp.DeviceArray]` from a `bytes` object or a file path. If you are using `objax`, set `requires_unflattening` to `False` and `to_var_collection` to `True` to restore a `VarCollection`. If you are using `flax`, set `freeze_dict` to `True` to restore a `FrozenDict`. These are only usage tips: all of these frameworks are also compatible with the default behavior.
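One way to picture what `requires_unflattening=True` undoes: safetensors stores a flat mapping from tensor names to tensors, so nested params have to be rebuilt from compound keys. The sketch below assumes keys joined with a `.` separator; both `unflatten_dict` and that separator are illustrative assumptions, not necessarily what safejax does internally.

```python
from typing import Any, Dict


def unflatten_dict(flat: Dict[str, Any], sep: str = ".") -> Dict[str, Any]:
    """Hypothetical helper: turn {"a.b.c": x} into {"a": {"b": {"c": x}}}.

    The "." separator is an assumption for illustration. Assumes no key is
    both a leaf and the prefix of another key.
    """
    nested: Dict[str, Any] = {}
    for key, value in flat.items():
        *parents, leaf = key.split(sep)
        node = nested
        # Walk (and create on demand) one sub-dict per key component.
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested
```

A key without any separator, such as `"w"`, stays at the top level unchanged, which is also why a flat structure can be returned as-is when unflattening is turned off.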
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path_or_buf | Union[PathLike, bytes] | A `bytes` object or a file path containing the model params. | required |
fs | Union[AbstractFileSystem, None] | The filesystem to use to load the model params. Defaults to `None`. | None |
freeze_dict | bool | Whether to freeze the output into a `FrozenDict`. Defaults to `False`. | False |
requires_unflattening | bool | Whether the model params require unflattening or not. Defaults to `True`. | True |
to_var_collection | bool | Whether to convert the output into a `VarCollection`. Defaults to `False`. | False |
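For intuition about `freeze_dict`, the following sketch builds a recursively read-only view of a nested params dict. It swaps in the standard library's `types.MappingProxyType` as a stand-in for Flax's `FrozenDict`, so it only approximates the immutability that `freeze_dict=True` provides; `freeze_nested` is a hypothetical helper.

```python
from types import MappingProxyType
from typing import Any, Mapping


def freeze_nested(params: Mapping[str, Any]) -> Mapping[str, Any]:
    """Hypothetical helper: read-only view of params, with nested mappings frozen too.

    MappingProxyType is a stand-in for flax's FrozenDict here (assumption).
    """
    frozen_items = {
        key: freeze_nested(value) if isinstance(value, Mapping) else value
        for key, value in params.items()
    }
    # The proxy wraps a fresh dict that nothing else references, so the
    # result cannot be mutated through the returned object.
    return MappingProxyType(frozen_items)
```

Any attempt to assign into the frozen result, at any depth, raises `TypeError`, which is the behavior a frozen params tree is meant to guarantee.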
Returns:
Type | Description |
---|---|
Union[ParamsDictLike, Tuple[ParamsDictLike, Dict[str, str]]] | A `ParamsDictLike` with the model params or, in case …, a tuple containing the model params and the metadata (in that order). |
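Because the return type is a `Union`, callers that want a uniform shape can normalize it. This sketch assumes that a tuple result always means `(params, metadata)` and that anything else is the params alone, which matches the return type in the table above; `split_result` and the `ParamsLike` alias are hypothetical stand-ins, not safejax names.

```python
from typing import Any, Dict, Mapping, Tuple, Union

# Hypothetical stand-in for safejax's ParamsDictLike.
ParamsLike = Mapping[str, Any]


def split_result(
    result: Union[ParamsLike, Tuple[ParamsLike, Dict[str, str]]],
) -> Tuple[ParamsLike, Dict[str, str]]:
    """Hypothetical helper: normalize a params-or-(params, metadata) result."""
    if isinstance(result, tuple):
        params, metadata = result
        # Copy so callers can modify the metadata without side effects.
        return params, dict(metadata)
    # Only params were returned, so report empty metadata.
    return result, {}
```

Normalizing once at the call site keeps downstream code from branching on the return shape.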
Source code in safejax/core/load.py
Created: 2023-01-19