chinese直男口爆体育生外卖, 99久久er热在这里只有精品99, 又色又爽又黄18禁美女裸身无遮挡, gogogo高清免费观看日本电视,私密按摩师高清版在线,人妻视频毛茸茸,91论坛 兴趣闲谈,欧美 亚洲 精品 8区,国产精品久久久久精品免费

0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內(nèi)不再提示

PyTorch教程-9.6. 遞歸神經(jīng)網(wǎng)絡的簡潔實現(xiàn)

jf_pJlTbmA9 ? 來源:PyTorch ? 作者:PyTorch ? 2023-06-05 15:44 ? 次閱讀
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

與我們大多數(shù)從頭開始的實施一樣, 第 9.5 節(jié)旨在深入了解每個組件的工作原理。但是,當您每天使用 RNN 或編寫生產(chǎn)代碼時,您會希望更多地依賴于減少實現(xiàn)時間(通過為通用模型和函數(shù)提供庫代碼)和計算時間(通過優(yōu)化這些庫實現(xiàn))。本節(jié)將向您展示如何使用深度學習框架提供的高級 API 更有效地實現(xiàn)相同的語言模型。和以前一樣,我們首先加載時間機器數(shù)據(jù)集。

import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l

from mxnet import np, npx
from mxnet.gluon import nn, rnn
from d2l import mxnet as d2l

npx.set_np()

from flax import linen as nn
from jax import numpy as jnp
from d2l import jax as d2l

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

import tensorflow as tf
from d2l import tensorflow as d2l

9.6.1. 定義模型

我們使用由高級 API 實現(xiàn)的 RNN 定義以下類。

class RNN(d2l.Module): #@save
  """The RNN model implemented with high-level APIs."""
  def __init__(self, num_inputs, num_hiddens):
    super().__init__()
    self.save_hyperparameters()
    self.rnn = nn.RNN(num_inputs, num_hiddens)

  def forward(self, inputs, H=None):
    return self.rnn(inputs, H)

Specifically, to initialize the hidden state, we invoke the member method begin_state. This returns a list that contains an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list will also contain other information.

class RNN(d2l.Module): #@save
  """The RNN model implemented with high-level APIs."""
  def __init__(self, num_hiddens):
    super().__init__()
    self.save_hyperparameters()
    self.rnn = rnn.RNN(num_hiddens)

  def forward(self, inputs, H=None):
    if H is None:
      H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
    outputs, (H, ) = self.rnn(inputs, (H, ))
    return outputs, H

Flax does not provide an RNNCell for concise implementation of Vanilla RNNs as of today. There are more advanced variants of RNNs like LSTMs and GRUs which are available in the Flax linen API.

class RNN(nn.Module): #@save
  """The RNN model implemented with high-level APIs."""
  num_hiddens: int

  @nn.compact
  def __call__(self, inputs, H=None):
    raise NotImplementedError

class RNN(d2l.Module): #@save
  """The RNN model implemented with high-level APIs."""
  def __init__(self, num_hiddens):
    super().__init__()
    self.save_hyperparameters()
    self.rnn = tf.keras.layers.SimpleRNN(
      num_hiddens, return_sequences=True, return_state=True,
      time_major=True)

  def forward(self, inputs, H=None):
    outputs, H = self.rnn(inputs, H)
    return outputs, H

繼承自9.5 節(jié)RNNLMScratch中的類 ,下面的類定義了一個完整的基于 RNN 的語言模型。請注意,我們需要創(chuàng)建一個單獨的全連接輸出層。RNNLM

class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  def init_params(self):
    self.linear = nn.LazyLinear(self.vocab_size)

  def output_layer(self, hiddens):
    return self.linear(hiddens).swapaxes(0, 1)

class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  def init_params(self):
    self.linear = nn.Dense(self.vocab_size, flatten=False)
    self.initialize()
  def output_layer(self, hiddens):
    return self.linear(hiddens).swapaxes(0, 1)

class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  training: bool = True

  def setup(self):
    self.linear = nn.Dense(self.vocab_size)

  def output_layer(self, hiddens):
    return self.linear(hiddens).swapaxes(0, 1)

  def forward(self, X, state=None):
    embs = self.one_hot(X)
    rnn_outputs, _ = self.rnn(embs, state, self.training)
    return self.output_layer(rnn_outputs)

class RNNLM(d2l.RNNLMScratch): #@save
  """The RNN-based language model implemented with high-level APIs."""
  def init_params(self):
    self.linear = tf.keras.layers.Dense(self.vocab_size)

  def output_layer(self, hiddens):
    return tf.transpose(self.linear(hiddens), (1, 0, 2))

9.6.2. 訓練和預測

在訓練模型之前,讓我們使用隨機權(quán)重初始化的模型進行預測。鑒于我們還沒有訓練網(wǎng)絡,它會產(chǎn)生無意義的預測。

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)

'it hasgggggggggggggggggggg'

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)

'it hasxlxlxlxlxlxlxlxlxlxl'

data = d2l.TimeMachine(batch_size=1024, num_steps=32)
rnn = RNN(num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)

'it hasnvjdtagwbcsxvcjwuyby'

接下來,我們利用高級 API 訓練我們的模型。

trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

poYBAGR9NrKAA2V1ABG9IJKp_s8858.svg

trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
trainer.fit(model, data)

poYBAGR9NrmAC0QYABHpbt_PvZk929.svg

with d2l.try_gpu():
  trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
trainer.fit(model, data)

poYBAGR9NsGAZ5qbABHCG7mYLzs874.svg

與第 9.5 節(jié)相比,該模型實現(xiàn)了相當?shù)睦Щ蠖?,但由于實現(xiàn)優(yōu)化,運行速度更快。和以前一樣,我們可以在指定的前綴字符串之后生成預測標記。

model.predict('it has', 20, data.vocab, d2l.try_gpu())

'it has and the time trave '

model.predict('it has', 20, data.vocab, d2l.try_gpu())

'it has and the thi baid th'

model.predict('it has', 20, data.vocab)

'it has our in the time tim'

9.6.3. 概括

深度學習框架中的高級 API 提供標準 RNN 的實現(xiàn)。這些庫可幫助您避免浪費時間重新實現(xiàn)標準模型。此外,框架實施通常經(jīng)過高度優(yōu)化,與從頭開始實施相比,可顯著提高(計算)性能。

9.6.4. 練習

您能否使用高級 API 使 RNN 模型過擬合?

使用 RNN實現(xiàn)第 9.1 節(jié)的自回歸模型。

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學習之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 神經(jīng)網(wǎng)絡

    關(guān)注

    42

    文章

    4819

    瀏覽量

    106111
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    812

    瀏覽量

    14457
收藏 人收藏
加入交流群
微信小助手二維碼

掃碼添加小助手

加入工程師交流群

    評論

    相關(guān)推薦
    熱點推薦

    PyTorch教程之從零開始的遞歸神經(jīng)網(wǎng)絡實現(xiàn)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程之從零開始的遞歸神經(jīng)網(wǎng)絡實現(xiàn).pdf》資料免費下載
    發(fā)表于 06-05 09:55 ?0次下載
    <b class='flag-5'>PyTorch</b>教程之從零開始的<b class='flag-5'>遞歸</b><b class='flag-5'>神經(jīng)網(wǎng)絡</b><b class='flag-5'>實現(xiàn)</b>

    PyTorch教程9.6遞歸神經(jīng)網(wǎng)絡簡潔實現(xiàn)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程9.6遞歸神經(jīng)網(wǎng)絡簡潔實現(xiàn).pdf》資料免費下載
    發(fā)表于 06-05 09:56 ?0次下載
    <b class='flag-5'>PyTorch</b>教程<b class='flag-5'>9.6</b>之<b class='flag-5'>遞歸</b><b class='flag-5'>神經(jīng)網(wǎng)絡</b>的<b class='flag-5'>簡潔</b><b class='flag-5'>實現(xiàn)</b>

    PyTorch教程10.3之深度遞歸神經(jīng)網(wǎng)絡

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程10.3之深度遞歸神經(jīng)網(wǎng)絡.pdf》資料免費下載
    發(fā)表于 06-05 15:12 ?0次下載
    <b class='flag-5'>PyTorch</b>教程10.3之深度<b class='flag-5'>遞歸</b><b class='flag-5'>神經(jīng)網(wǎng)絡</b>

    PyTorch教程10.4之雙向遞歸神經(jīng)網(wǎng)絡

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程10.4之雙向遞歸神經(jīng)網(wǎng)絡.pdf》資料免費下載
    發(fā)表于 06-05 15:13 ?0次下載
    <b class='flag-5'>PyTorch</b>教程10.4之雙向<b class='flag-5'>遞歸</b><b class='flag-5'>神經(jīng)網(wǎng)絡</b>

    PyTorch教程16.2之情感分析:使用遞歸神經(jīng)網(wǎng)絡

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程16.2之情感分析:使用遞歸神經(jīng)網(wǎng)絡.pdf》資料免費下載
    發(fā)表于 06-05 10:55 ?0次下載
    <b class='flag-5'>PyTorch</b>教程16.2之情感分析:使用<b class='flag-5'>遞歸</b><b class='flag-5'>神經(jīng)網(wǎng)絡</b>

    使用PyTorch構(gòu)建神經(jīng)網(wǎng)絡

    PyTorch是一個流行的深度學習框架,它以其簡潔的API和強大的靈活性在學術(shù)界和工業(yè)界得到了廣泛應用。在本文中,我們將深入探討如何使用PyTorch構(gòu)建神經(jīng)網(wǎng)絡,包括從基礎概念到高級
    的頭像 發(fā)表于 07-02 11:31 ?1238次閱讀

    遞歸神經(jīng)網(wǎng)絡是循環(huán)神經(jīng)網(wǎng)絡

    遞歸神經(jīng)網(wǎng)絡(Recurrent Neural Network,簡稱RNN)和循環(huán)神經(jīng)網(wǎng)絡(Recurrent Neural Network,簡稱RNN)實際上是同一個概念,只是不同的翻譯方式
    的頭像 發(fā)表于 07-04 14:54 ?1814次閱讀

    遞歸神經(jīng)網(wǎng)絡主要應用于哪種類型數(shù)據(jù)

    處理(NLP) 自然語言處理是遞歸神經(jīng)網(wǎng)絡最重要的應用領域之一。在NLP中,遞歸神經(jīng)網(wǎng)絡可以用于以下任務: 1.1 語言模型(Language Modeling) 語言模型是預測給定詞
    的頭像 發(fā)表于 07-04 14:58 ?1413次閱讀

    遞歸神經(jīng)網(wǎng)絡與循環(huán)神經(jīng)網(wǎng)絡一樣嗎

    遞歸神經(jīng)網(wǎng)絡(Recursive Neural Network,RvNN)和循環(huán)神經(jīng)網(wǎng)絡(Recurrent Neural Network,RNN)是兩種不同類型的神經(jīng)網(wǎng)絡結(jié)構(gòu),它們在
    的頭像 發(fā)表于 07-05 09:28 ?1846次閱讀

    遞歸神經(jīng)網(wǎng)絡結(jié)構(gòu)形式主要分為

    結(jié)構(gòu)形式。 Elman網(wǎng)絡 Elman網(wǎng)絡是一種基本的遞歸神經(jīng)網(wǎng)絡結(jié)構(gòu),由Elman于1990年提出。其結(jié)構(gòu)主要包括輸入層、隱藏層和輸出層,其中隱藏層具有時間延遲單元,可以存儲前一時刻
    的頭像 發(fā)表于 07-05 09:32 ?1118次閱讀

    rnn是遞歸神經(jīng)網(wǎng)絡還是循環(huán)神經(jīng)網(wǎng)絡

    RNN(Recurrent Neural Network)是循環(huán)神經(jīng)網(wǎng)絡,而非遞歸神經(jīng)網(wǎng)絡。循環(huán)神經(jīng)網(wǎng)絡是一種具有時間序列特性的神經(jīng)網(wǎng)絡,能
    的頭像 發(fā)表于 07-05 09:52 ?1336次閱讀

    PyTorch神經(jīng)網(wǎng)絡模型構(gòu)建過程

    PyTorch,作為一個廣泛使用的開源深度學習庫,提供了豐富的工具和模塊,幫助開發(fā)者構(gòu)建、訓練和部署神經(jīng)網(wǎng)絡模型。在神經(jīng)網(wǎng)絡模型中,輸出層是尤為關(guān)鍵的部分,它負責將模型的預測結(jié)果以合適的形式輸出。以下將詳細解析
    的頭像 發(fā)表于 07-10 14:57 ?1170次閱讀

    遞歸神經(jīng)網(wǎng)絡實現(xiàn)方法

    (Recurrent Neural Network,通常也簡稱為RNN,但在此處為區(qū)分,我們將循環(huán)神經(jīng)網(wǎng)絡稱為Recurrent RNN)不同,遞歸神經(jīng)網(wǎng)絡更側(cè)重于處理樹狀或圖結(jié)構(gòu)的數(shù)據(jù),如句法分析樹、自然語言的語法結(jié)構(gòu)等。以下
    的頭像 發(fā)表于 07-10 17:02 ?994次閱讀

    遞歸神經(jīng)網(wǎng)絡和循環(huán)神經(jīng)網(wǎng)絡的模型結(jié)構(gòu)

    遞歸神經(jīng)網(wǎng)絡是一種旨在處理分層結(jié)構(gòu)的神經(jīng)網(wǎng)絡,使其特別適合涉及樹狀或嵌套數(shù)據(jù)的任務。這些網(wǎng)絡明確地模擬了層次結(jié)構(gòu)中的關(guān)系和依賴關(guān)系,例如語言中的句法結(jié)構(gòu)或圖像中的層次表示。它使用
    的頭像 發(fā)表于 07-10 17:21 ?1592次閱讀
    <b class='flag-5'>遞歸</b><b class='flag-5'>神經(jīng)網(wǎng)絡</b>和循環(huán)<b class='flag-5'>神經(jīng)網(wǎng)絡</b>的模型結(jié)構(gòu)

    pytorch中有神經(jīng)網(wǎng)絡模型嗎

    當然,PyTorch是一個廣泛使用的深度學習框架,它提供了許多預訓練的神經(jīng)網(wǎng)絡模型。 PyTorch中的神經(jīng)網(wǎng)絡模型 1. 引言 深度學習是一種基于人工
    的頭像 發(fā)表于 07-11 09:59 ?2403次閱讀