import numpy as np
import tensorflow as tf
from tensorflow import keras  # public API; `tensorflow.python` is private

# Reset Keras' global state (e.g. auto-generated layer-name counters) so
# re-running this script in one process starts from a clean slate.
keras.backend.clear_session()

# One 10-feature input feeding two independent linear heads.
# Dense(35) == 5 * 7 units and Dense(70) == 10 * 7 units, matching the
# (5, 7) and (10, 7) targets the losses reshape them to.
x = keras.Input(shape=(10,))  # shape must be a tuple: `(10)` is just the int 10
x_1 = keras.layers.Dense(35)(x)
x_2 = keras.layers.Dense(70)(x)
model = keras.Model(inputs=x, outputs=[x_1, x_2])
model.summary()
def _sigmoid_ce_sum(true, pred, shape):
    """Return the sigmoid cross-entropy between ``true`` and ``pred``, summed.

    ``pred`` holds raw logits. It is reshaped to ``(-1,) + shape`` so it lines
    up with ``true``. ``true`` is cast to ``pred``'s dtype because
    ``tf.nn.sigmoid_cross_entropy_with_logits`` requires labels and logits of
    the same dtype. In this script the targets come from ``np.random.rand``
    (float64) while the model emits float32 logits.

    Args:
        true: Targets; presumably probabilities in [0, 1] with shape
            ``(batch,) + shape`` -- confirm against the data pipeline.
        pred: Logits with ``batch * prod(shape)`` elements.
        shape: Per-sample target shape, e.g. ``(5, 7)``.

    Returns:
        A scalar tensor: the element-wise cross-entropy summed over the batch
        and all target positions (a sum, not a mean).
    """
    pred = tf.reshape(pred, (-1,) + tuple(shape))
    true = tf.cast(true, pred.dtype)
    return tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=true, logits=pred)
    )


def l_1(true, pred):
    """Summed sigmoid cross-entropy for the head whose targets are (5, 7)."""
    return _sigmoid_ce_sum(true, pred, (5, 7))


def l_2(true, pred):
    """Summed sigmoid cross-entropy for the head whose targets are (10, 7)."""
    return _sigmoid_ce_sum(true, pred, (10, 7))
# Synthetic, infinitely repeating training data in batches of 32.
# Each element is (features, (targets_1, targets_2)):
#   features  -> (10,)   one row of model input
#   targets_1 -> (5, 7)  labels for the first head / l_1
#   targets_2 -> (10, 7) labels for the second head / l_2
# from_tensor_slices accepts this nested structure directly, so no `map` is
# needed to regroup a flat tuple. Arrays are cast to float32 so the pipeline
# matches the model's dtype instead of emitting float64 tensors.
train_set = (
    tf.data.Dataset.from_tensor_slices(
        (
            np.random.rand(100, 10).astype(np.float32),
            (
                np.random.rand(100, 5, 7).astype(np.float32),
                np.random.rand(100, 10, 7).astype(np.float32),
            ),
        )
    )
    .repeat()  # infinite; fit() bounds each epoch via steps_per_epoch
    .batch(32)
)
# Reshape each dense head so the model's outputs already carry the
# per-sample target shapes: (5, 7) for l_1 and (10, 7) for l_2.
x_1 = keras.layers.Reshape(target_shape=(5, 7))(x_1)
x_2 = keras.layers.Reshape(target_shape=(10, 7))(x_2)

# NOTE: "warpper" is a typo of "wrapper"; the name is kept in case other code
# refers to this global.
model_warpper = keras.Model(inputs=x, outputs=[x_1, x_2])
model_warpper.summary()

# Losses are matched positionally to `outputs`: l_1 -> x_1, l_2 -> x_2.
model_warpper.compile(optimizer="adam", loss=[l_1, l_2])

# train_set repeats forever, so steps_per_epoch is what ends the epoch.
model_warpper.fit(train_set, steps_per_epoch=30)