Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit bbb258a

Browse files
committed
test ci
1 parent e36c9f0 commit bbb258a

3 files changed

Lines changed: 133 additions & 4 deletions

File tree

python/mxnet/_ctypes/cached_op.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
7777
if not default_device:
7878
default_device = kwargs.pop('default_ctx', None)
7979
out = kwargs.pop('out', None)
80+
nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
8081
if kwargs:
8182
raise TypeError(
8283
"CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
9394
*args,
9495
type_id,
9596
device_id,
96-
*out_arg
97+
len(out_arg),
98+
*out_arg,
99+
len(nleaf_vars),
100+
*nleaf_vars
97101
)
98102
if out is not None:
99103
return out

python/mxnet/gluon/block.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@
3333
import json
3434
import numpy as np
3535

36-
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
36+
from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
37+
_as_list
3738
from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
3839
profiler as _profiler, device as _device
3940
from ..symbol.numpy import _symbol as np_symbol
4041
from ..symbol import Symbol, fromjson
4142
from ..ndarray import NDArray, get_dtype_name
42-
from .parameter import Parameter, DeferredInitializationError
43+
from .parameter import Parameter, DeferredInitializationError, Intermediate
4344
from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
4445
from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
4546
from .. import numpy_extension as _mx_npx
@@ -1091,6 +1092,7 @@ def __init__(self):
10911092
self._backend_opts = {}
10921093
self._partition_if_dynamic = True
10931094
self._first_forward = True
1095+
self._nleaf_vars = OrderedDict()
10941096

10951097
def __setattr__(self, name, value):
10961098
"""Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
13021304
args_without_none = [ele for ele in args if ele is not None]
13031305
cargs = [args_without_none[i] if is_arg else i.data()
13041306
for is_arg, name, i in self._cached_op_args]
1305-
out = self._cached_op(*cargs)
1307+
out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
13061308
if isinstance(out, NDArray):
13071309
out = [out]
13081310
return _regroup(out, self._out_format)
@@ -1678,6 +1680,92 @@ def reset_ctx(self, ctx):
16781680
self.reset_device(ctx)
16791681

16801682

1683+
def intermediate(self, names, var_arrays_inp, grad_req='write'):
    """Mark intermediate variables of the block.

    Parameters
    ----------
    names : str or tuple[str]
        Name(s) under which the intermediate variable(s) are registered.
    var_arrays_inp : ndarray or tuple[ndarray]
        The output(s) of the expression to be marked.
    grad_req : str, default 'write'
        Gradient request used when gradients are attached later
        ('write', 'add' or 'null').

    Returns
    -------
    ``var_arrays_inp`` unchanged, so the call can be used inline in an
    expression.
    """
    # Normalize once; both branches need the same list conversion.
    var_arrays = _as_list(var_arrays_inp)
    names = _as_list(names)
    if not self._active:
        self._nleaf_vars.update(
            {name: Intermediate(name, array, grad_req)
             for name, array in zip(names, var_arrays)})
    else:
        # While hybridized, the arrays must also be registered with the
        # deferred-compute graph; temporarily leave deferred-compute mode.
        prev_val = dc.set_deferred_compute(False)
        try:
            import ctypes
            # Pack the NDArray handles into a C array for the FFI call.
            var_handles = (ctypes.c_void_p * len(var_arrays))(
                *[arr.handle for arr in var_arrays])
            check_call(_LIB.MXNDArrayMarkDCVariables(
                var_handles, len(var_arrays), len(self._nleaf_vars)))
            self._nleaf_vars.update(
                {name: Intermediate(name, array, grad_req)
                 for name, array in zip(names, var_arrays)})
        finally:
            # Always restore the previous deferred-compute state, even if
            # the C API call raises.
            dc.set_deferred_compute(prev_val)
    return var_arrays_inp
1711+
1712+
def attach_grad_intermediate(self):
    """Attach a gradient buffer to every marked intermediate variable.

    Each variable's own ``grad_req`` is honored when attaching.
    """
    for intermediate in self._nleaf_vars.values():
        array = intermediate.data()
        array.attach_grad(grad_req=intermediate.grad_req)
1717+
1718+
def get_intermediate(self, names):
    """Return the marked intermediate variable(s) registered under *names*.

    A list of names yields a list of variables; a single name yields the
    single matching variable. Raises ``KeyError`` for an unknown name.
    """
    if not isinstance(names, list):
        return self._nleaf_vars[names]
    return [self._nleaf_vars[name] for name in names]
1725+
1726+
def intermediate(self, names, var_arrays_inp, grad_req='write'):
    """Mark intermediate variables of the block.

    Parameters
    ----------
    names : str or tuple[str]
        Name(s) under which the intermediate variable(s) are registered.
    var_arrays_inp : ndarray or tuple[ndarray]
        The output(s) of the expression to be marked.
    grad_req : str, default 'write'
        Gradient request used when gradients are attached later
        ('write', 'add' or 'null').

    Returns
    -------
    ``var_arrays_inp`` unchanged, so the call can be used inline in an
    expression.
    """
    # Normalize once; both branches need the same list conversion.
    var_arrays = _as_list(var_arrays_inp)
    names = _as_list(names)
    if not self._active:
        self._nleaf_vars.update(
            {name: Intermediate(name, array, grad_req)
             for name, array in zip(names, var_arrays)})
    else:
        # While hybridized, the arrays must also be registered with the
        # deferred-compute graph; temporarily leave deferred-compute mode.
        prev_val = dc.set_deferred_compute(False)
        try:
            import ctypes
            # Pack the NDArray handles into a C array for the FFI call.
            var_handles = (ctypes.c_void_p * len(var_arrays))(
                *[arr.handle for arr in var_arrays])
            check_call(_LIB.MXNDArrayMarkDCVariables(
                var_handles, len(var_arrays), len(self._nleaf_vars)))
            self._nleaf_vars.update(
                {name: Intermediate(name, array, grad_req)
                 for name, array in zip(names, var_arrays)})
        finally:
            # Always restore the previous deferred-compute state, even if
            # the C API call raises.
            dc.set_deferred_compute(prev_val)
    return var_arrays_inp
1754+
1755+
def attach_grad_intermediate(self):
    """Attach a gradient buffer to every marked intermediate variable.

    Each variable's own ``grad_req`` is honored when attaching.
    """
    for intermediate in self._nleaf_vars.values():
        array = intermediate.data()
        array.attach_grad(grad_req=intermediate.grad_req)
1760+
1761+
def get_intermediate(self, names):
    """Return the marked intermediate variable(s) registered under *names*.

    A list of names yields a list of variables; a single name yields the
    single matching variable. Raises ``KeyError`` for an unknown name.
    """
    if not isinstance(names, list):
        return self._nleaf_vars[names]
    return [self._nleaf_vars[name] for name in names]
1768+
16811769
class SymbolBlock(HybridBlock):
16821770
"""Construct block from symbol. This is useful for using pre-trained models
16831771
as feature extractors. For example, you may want to extract the output

python/mxnet/gluon/parameter.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -773,3 +773,40 @@ def grad_req(self, req):
773773
warnings.warn('Constant parameter "{}" does not support '
774774
'grad_req other than "null", and new value "{}" '
775775
'is ignored.'.format(self.name, req))
776+
777+
class Intermediate:
    """A container holding a marked intermediate variable of a Block.

    Parameters
    ----------
    name : str
        Name of this variable. It is used to retrieve the marked variable.
    data : NDArray, optional
        The marked intermediate array itself.
    grad_req : {'write', 'add', 'null'}, default 'write'
        Specifies how to update gradient to grad arrays.

        - ``'write'`` means everytime gradient is written to grad :py:class:`NDArray`.
        - ``'add'`` means everytime gradient is added to the grad :py:class:`NDArray`. You need
          to manually call ``zero_grad()`` to clear the gradient buffer before each
          iteration when using this option.
        - ``'null'`` means gradient is not requested for this parameter. Gradient arrays
          will not be allocated.
    """
    def __init__(self, name, data=None, grad_req='write'):
        self._name = name
        self._data = data
        self._grad_req = grad_req

    def __repr__(self):
        s = 'Intermediate name={name}'
        return s.format(name=self._name)

    def data(self):
        """Return the marked array (``None`` if no array was supplied)."""
        return self._data

    @property
    def name(self):
        """Name of this intermediate variable."""
        return self._name

    @property
    def grad_req(self):
        """Gradient request used when attaching a gradient buffer."""
        return self._grad_req

0 commit comments

Comments
 (0)