公司新闻

tf.keras.optimizers.experimental.AdamW

实现 AdamW 算法的优化器。

继承自:

tf.keras.optimizers.experimental.AdamW(
    learning_rate=0.001,
    weight_decay=0.004,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=False,
    clipnorm=None,
    clipvalue=None,
    global_clipnorm=None,
    use_ema=False,
    ema_momentum=0.99,
    ema_overwrite_frequency=None,
    jit_compile=True,
    name='AdamW',
    **kwargs
)

AdamW 优化是一种随机梯度下降方法,基于一阶和二阶矩的自适应估计,并根据 Loshchilov, Hutter et al., 2019 论文“解耦权重衰减正则化”中讨论的技术添加了衰减权重的方法。

根据 Kingma et al., 2014 ,底层 Adam 方法是 "computationally efficient, has little memory requirement, invariant to diagonal rescaling of gradients, and is well suited for problems that are large in terms of data/parameters" 。

Args
、浮点值、 调度或不带参数并返回要使用的实际值的可调用函数。学习率。默认为 0.001。
,浮点值。重量衰减。默认为 0.004。
浮点值或常量浮点张量,或不带参数并返回要使用的实际值的可调用对象。第一时刻估计的指数衰减率。默认为 0.9。
浮点值或常量浮点张量,或不带参数并返回要使用的实际值的可调用对象。二阶矩估计的指数衰减率。默认为 0.999。
数值稳定性的小常数。这个epsilon是Kingma和Ba论文中的 "epsilon hat" (在2.1节之前的公式中),而不是论文算法1中的epsilon。默认为 1e-7。
布尔值。是否应用论文 "On the Convergence of Adam and beyond" 中该算法的 AMSGrad 变体。默认为 。
细绳。用于优化器创建的动量累加器权重的名称。
漂浮。如果设置,每个权重的梯度将被单独裁剪,使其范数不高于该值。
漂浮。如果设置,每个权重的梯度将被裁剪为不高于该值。
漂浮。如果设置,则所有权重的梯度都会被剪裁,以便它们的全局范数不高于该值。
布尔值,默认为 False。如果为 True,则应用指数移动平均线 (EMA)。EMA 包括计算模型权重的指数移动平均值(随着每个训练批次后权重值的变化),并定期用移动平均值覆盖权重。
浮动,默认为 0.99。仅在 时使用。这是计算模型权重 EMA 时使用的动量: 。
Int 或 None,默认为 None。仅在 时使用。每 迭代步骤,我们都会用其移动平均值覆盖模型变量。如果为 None,则优化器不会在训练过程中覆盖模型变量,并且您需要在训练结束时通过调用 (就地更新模型变量)显式覆盖变量。使用内置 训练循环时,这会在最后一个 epoch 之后自动发生,您无需执行任何操作。
布尔值,默认为 True 。如果是 True ,优化器将使用 XLA 编译。使用 训练时, 不能为 True 。此外,如果没有找到 GPU 设备,该标志将被忽略。
关键字参数仅用于向后兼容。

Reference:

Notes:

一般来说,epsilon 的默认值 1e-7 可能不是一个好的默认值。例如,在 ImageNet 上训练 Inception 网络时,当前较好的选择是 1.0 或 0.1。请注意,由于 Adam 使用的是 Kingma 和 Ba 论文 2.1 节之前的公式,而不是算法 1 中的公式,因此这里提到的 "epsilon" 在论文中是 "epsilon hat" 。

该算法的稀疏实现(当梯度是 IndexedSlices 对象时使用,通常是因为 或前向传递中的嵌入查找)确实将动量应用于变量切片,即使它们没有在前向传递中使用(意味着它们具有梯度为零)。动量衰减 (beta1) 也应用于整个动量累加器。这意味着稀疏行为等同于密集行为(与一些忽略动量的动量实现相反,除非实际使用了可变切片)。

Attributes
已运行的训练步骤数。

默认情况下,每次调用 时迭代都会增加 1。

View source

add_variable(
    shape, dtype=None, initializer='zeros', name=None
)

创建优化器变量。

Args
整数列表、整数元组或 int32 类型的一维张量。如果未指定,则默认为标量。
要创建的优化器变量的 DType。如果未指定,则默认为
字符串或可调用。初始化器实例。
要创建的优化器变量的名称。
Returns
优化器变量,格式为 tf.Variable。

View source

add_variable_from_reference(
    model_variable, variable_name, shape=None, initial_value=None
)

从模型变量创建优化器变量。

根据模型变量的信息创建优化器变量。例如,在 SGD 优化器动量中,对于每个模型变量,都会创建具有相同形状和数据类型的相应动量变量。

Args
tf.变量。要创建的优化器变量对应的模型变量。
细绳。要创建的优化器变量的名称前缀。创建变量名称将遵循模式 ,例如 。
列表或元组,默认为 None。要创建的优化器变量的形状。如果没有,创建的变量将具有与 相同的形状。
Tensor 或可转换为 Tensor 的 Python 对象默认为 None。优化器变量的初始值,如果为None,则初始值将默认为0。
Returns
优化器变量。

View source

aggregate_gradients(
    grads_and_vars
)

聚合所有设备上的梯度。

默认情况下,我们将跨设备执行梯度的reduce_sum。用户可以通过重写此方法来实现自己的聚合逻辑。

Args
(梯度、变量)对的列表。
Returns
(梯度、变量)对的列表。

View source

apply_gradients(
    grads_and_vars, skip_gradients_aggregation=False
)

将梯度应用于变量。

Args
(梯度、变量)对的列表。
如果为 true,则不会在优化器内部执行梯度聚合。当您在优化器外部编写聚合梯度的自定义代码时,通常此参数设置为 True 。
Returns
None
Raises
如果 畸形。
如果在跨副本上下文中调用。

View source

build(
    var_list
)

初始化优化器变量。

AdamW优化器有3种类型的变量:动量、速度和velocity_hat(仅在应用amsgrad时设置),

Args
用于构建 AdamW 变量的模型变量列表。

View source

compute_gradients(
    loss, var_list, tape=None
)

计算可训练变量的损失梯度。

Args
或可调用。如果是可调用的, 应该不带参数并返回要最小化的值。
要更新的 对象的列表或元组以最小化 。
(可选) 。如果 作为 提供,则必须提供计算 的磁带。
Returns
(梯度、变量)对的列表。变量始终存在,但梯度可以是 。

View source

finalize_variable_values(
    var_list
)

设置模型可训练变量的最终值。

有时在结束变量更新之前会执行一些额外的步骤,例如用平均值覆盖模型变量。

Args
模型变量列表。

View source

@classmethod
from_config(
    config
)

从其配置创建优化器。

此方法与 相反,能够从配置字典实例化相同的优化器。

Args
Python 字典,通常是 get_config 的输出。
Returns
优化器实例。

View source

get_config()

返回优化器的配置。

优化器配置是包含优化器配置的 Python 字典(可序列化)。稍后可以从此配置重新实例化相同的优化器(无需任何保存的状态)。

子类优化器应该重写此方法以包含其他超参数。

Returns
Python dictionary.

View source

minimize(
    loss, var_list, tape=None
)

通过更新 最小化 。

此方法仅使用 计算梯度并调用 。如果您想在应用之前处理渐变,请显式调用 和 ,而不是使用此函数。

Args
或可调用。如果是可调用的, 应该不带参数并返回要最小化的值。
要更新的 对象的列表或元组以最小化 。
(Optional) .
Returns
None

View source

update_step(
    gradient, variable
)

更新给定梯度和相关模型变量的步骤。

© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/keras/optimizers/experimental/AdamW

平台注册入口