{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Copyright 2022 Google LLC\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\");\n", "you may not use this file except in compliance with the License.\n", "You may obtain a copy of the License at\n", "\n", " https://www.apache.org/licenses/LICENSE-2.0\n", "\n", "Unless required by applicable law or agreed to in writing, software\n", "distributed under the License is distributed on an \"AS IS\" BASIS,\n", "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "See the License for the specific language governing permissions and\n", "limitations under the License." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Adversarial training\n", "\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jaxopt/blob/main/docs/notebooks/deep_learning/adversarial_training.ipynb)\n", "\n", "\n", "The following code trains a convolutional neural network (CNN) to be robust\n", "to adversarial examples generated with the projected gradient descent (PGD) method.\n", "\n", "PGD is a simple yet effective method to generate adversarial images. At each\n", "iteration, it adds a small perturbation in the direction of the sign of the\n", "gradient of the loss with respect to the input, followed by a projection onto\n", "the L-infinity ball of radius epsilon centered at the original image. The sign\n", "of the gradient is the direction of steepest ascent of the loss under the\n", "L-infinity norm, so each step locally increases the loss, while the projection\n", "keeps the perturbation inside the L-infinity ball.\n", "\n", "## References\n", "\n", " Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. \"Explaining\n", " and harnessing adversarial examples.\", https://arxiv.org/abs/1412.6572\n", "\n", " Madry, Aleksander, et al. 
\"Towards deep learning models resistant to\n", " adversarial attacks.\", https://arxiv.org/pdf/1706.06083.pdf" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "%%capture\n", "%pip install jaxopt flax" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import datetime\n", "import collections\n", "\n", "# activate TPUs if available\n", "try:\n", " import jax.tools.colab_tpu\n", " jax.tools.colab_tpu.setup_tpu()\n", "except KeyError:\n", " print(\"TPU not found, continuing without it.\")\n", "from flax import linen as nn\n", "import jax\n", "from jax import numpy as jnp\n", "from jaxopt import loss\n", "from jaxopt import OptaxSolver\n", "from jaxopt import tree_util\n", "\n", "from matplotlib import pyplot as plt\n", "plt.rcParams.update({'font.size': 22})\n", "\n", "import optax\n", "import tensorflow as tf\n", "import tensorflow_datasets as tfds" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Show on which platform JAX is running. The code below should take around 3 min to run on GPU but might take longer on CPUs." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "JAX running on GPU\n" ] } ], "source": [ "print(\"JAX running on\", jax.devices()[0].platform.upper())" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "Flags = collections.namedtuple(\n", " \"Flags\",\n", " [\n", " \"l2reg\", # amount of L2 regularization in the objective\n", " \"learning_rate\", # learning rate for the Adam optimizer\n", " \"epochs\", # number of passes over the dataset\n", " \"dataset\", # one of \"mnist\", \"kmnist\", \"emnist\", \"fashion_mnist\", \"cifar10\", \"cifar100\"\n", " \"epsilon\", # Adversarial perturbations lie within the L-infinity ball of radius epsilon.\n", " \"train_batch_size\", # Batch size at train time\n", " \"test_batch_size\" # Batch size at test time\n", " ])\n", "\n", "FLAGS = Flags(\n", " l2reg=0.0001,\n", " learning_rate=0.001,\n", " epochs=10,\n", " dataset=\"mnist\",\n", " epsilon=0.01,\n", " train_batch_size=128,\n", " test_batch_size=128)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def load_dataset(split, *, is_training, batch_size):\n", " \"\"\"Load dataset using tensorflow_datasets.\"\"\"\n", " version = 3\n", " ds, ds_info = tfds.load(\n", " f\"{FLAGS.dataset}:{version}.*.*\",\n", " as_supervised=True, # yield (image, label) tuples instead of feature dicts\n", " split=split,\n", " with_info=True)\n", " ds = ds.cache().repeat()\n", " if is_training:\n", " ds = ds.shuffle(10 * batch_size, seed=0)\n", " ds = ds.batch(batch_size)\n", " return iter(tfds.as_numpy(ds)), ds_info\n", "\n", "\n", "class CNN(nn.Module):\n", " \"\"\"A simple CNN model.\"\"\"\n", " num_classes: int\n", "\n", " @nn.compact\n", " def __call__(self, x):\n", " x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n", " x = nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n", " x = 
nn.relu(x)\n", " x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", " x = x.reshape((x.shape[0], -1)) # flatten\n", " x = nn.Dense(features=256)(x)\n", " x = nn.relu(x)\n", " x = nn.Dense(features=self.num_classes)(x)\n", " return x" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make\n", "# it unavailable to JAX.\n", "tf.config.experimental.set_visible_devices([], \"GPU\")\n", "train_ds, ds_info = load_dataset(\"train\", is_training=True,\n", " batch_size=FLAGS.train_batch_size)\n", "test_ds, _ = load_dataset(\"test\", is_training=False,\n", " batch_size=FLAGS.test_batch_size)\n", "input_shape = (1,) + ds_info.features[\"image\"].shape\n", "num_classes = ds_info.features[\"label\"].num_classes\n", "iter_per_epoch_train = ds_info.splits['train'].num_examples // FLAGS.train_batch_size\n", "iter_per_epoch_test = ds_info.splits['test'].num_examples // FLAGS.test_batch_size\n", "\n", "\n", "net = CNN(num_classes)\n", "\n", "@jax.jit\n", "def accuracy(params, data):\n", " inputs, labels = data\n", " logits = net.apply({\"params\": params}, inputs)\n", " return jnp.mean(jnp.argmax(logits, axis=-1) == labels)\n", "\n", "logistic_loss = jax.vmap(loss.multiclass_logistic_loss)\n", "\n", "@jax.jit\n", "def loss_fun(params, l2reg, data):\n", " \"\"\"Compute the mean multiclass logistic loss plus L2 regularization.\"\"\"\n", " inputs, labels = data\n", " x = inputs.astype(jnp.float32)\n", " logits = net.apply({\"params\": params}, x)\n", " sqnorm = tree_util.tree_l2_norm(params, squared=True)\n", " loss_value = jnp.mean(logistic_loss(labels, logits))\n", " return loss_value + 0.5 * l2reg * sqnorm\n", "\n", "@jax.jit\n", "def pgd_attack(image, label, params, epsilon=0.1, maxiter=10):\n", " \"\"\"PGD attack on the L-infinity ball with radius epsilon.\n", "\n", " Args:\n", " image: array-like, batch of input images for the CNN\n", " label: integer array, class labels corresponding to image\n", " params: 
tree, parameters of the model to attack\n", " epsilon: float, radius of the L-infinity ball.\n", " maxiter: int, number of PGD iterations.\n", "\n", " Returns:\n", " perturbed_image: Adversarial images inside the L-infinity ball of radius\n", " epsilon centered at image, with pixel values clipped to [0, 1].\n", "\n", " Notes:\n", " PGD attack is described in (Madry et al. 2017),\n", " https://arxiv.org/pdf/1706.06083.pdf\n", " \"\"\"\n", " image_perturbation = jnp.zeros_like(image)\n", " def adversarial_loss(perturbation):\n", " return loss_fun(params, 0, (image + perturbation, label))\n", "\n", " grad_adversarial = jax.grad(adversarial_loss)\n", " for _ in range(maxiter):\n", " # sign of the gradient of the loss with respect to the perturbation\n", " sign_grad = jnp.sign(grad_adversarial(image_perturbation))\n", "\n", " # heuristic step size: 2 * epsilon / maxiter\n", " image_perturbation += (2 * epsilon / maxiter) * sign_grad\n", " # projection step onto the L-infinity ball centered at image\n", " image_perturbation = jnp.clip(image_perturbation, - epsilon, epsilon)\n", "\n", " # clip the image to ensure pixels are between 0 and 1\n", " return jnp.clip(image + image_perturbation, 0, 1)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/envs/jaxopt/lib/python3.10/site-packages/jax/_src/tree_util.py:188: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.\n", " warnings.warn('jax.tree_util.tree_multimap() is deprecated. 
Please use jax.tree_util.tree_map() '\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 0 out of 10\n", "Accuracy on train set: 0.982\n", "Accuracy on test set: 0.982\n", "Adversarial accuracy on train set: 0.979\n", "Adversarial accuracy on test set: 0.979\n", "Time elapsed: 0:00:14\n", "\n", "Epoch 1 out of 10\n", "Accuracy on train set: 0.987\n", "Accuracy on test set: 0.986\n", "Adversarial accuracy on train set: 0.984\n", "Adversarial accuracy on test set: 0.982\n", "Time elapsed: 0:00:21\n", "\n", "Epoch 2 out of 10\n", "Accuracy on train set: 0.989\n", "Accuracy on test set: 0.987\n", "Adversarial accuracy on train set: 0.986\n", "Adversarial accuracy on test set: 0.984\n", "Time elapsed: 0:00:29\n", "\n", "Epoch 3 out of 10\n", "Accuracy on train set: 0.992\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.990\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:00:36\n", "\n", "Epoch 4 out of 10\n", "Accuracy on train set: 0.994\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.992\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:00:43\n", "\n", "Epoch 5 out of 10\n", "Accuracy on train set: 0.994\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.992\n", "Adversarial accuracy on test set: 0.987\n", "Time elapsed: 0:00:51\n", "\n", "Epoch 6 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.991\n", "Adversarial accuracy on train set: 0.994\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:00:58\n", "\n", "Epoch 7 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.993\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:01:05\n", "\n", "Epoch 8 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.989\n", "Adversarial accuracy on train set: 0.993\n", "Adversarial accuracy on test set: 0.987\n", "Time 
elapsed: 0:01:13\n", "\n", "Epoch 9 out of 10\n", "Accuracy on train set: 0.995\n", "Accuracy on test set: 0.990\n", "Adversarial accuracy on train set: 0.993\n", "Adversarial accuracy on test set: 0.988\n", "Time elapsed: 0:01:20\n", "\n" ] } ], "source": [ "# Initialize solver and parameters.\n", "solver = OptaxSolver(\n", " opt=optax.adam(FLAGS.learning_rate),\n", " fun=loss_fun,\n", " maxiter=FLAGS.epochs * iter_per_epoch_train)\n", "key = jax.random.PRNGKey(0)\n", "params = net.init(key, jnp.zeros(input_shape))[\"params\"]\n", "\n", "state = solver.init_state(params)\n", "start = datetime.datetime.now().replace(microsecond=0)\n", "jitted_update = jax.jit(solver.update)\n", "\n", "accuracy_train = []\n", "accuracy_test = []\n", "adversarial_accuracy_train = []\n", "adversarial_accuracy_test = []\n", "for it in range(solver.maxiter):\n", " # training loop\n", " images, labels = next(train_ds)\n", " # convert images to floats in [0, 1], since the attack takes gradients with respect to them\n", " images = images.astype(jnp.float32) / 255\n", "\n", " adversarial_images_train = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " # train on adversarial images\n", " params, state = jitted_update(\n", " params=params,\n", " state=state,\n", " l2reg=FLAGS.l2reg,\n", " data=(adversarial_images_train, labels))\n", "\n", " # Once per epoch evaluate the model on the train and test sets.\n", " if state.iter_num % iter_per_epoch_train == iter_per_epoch_train - 1:\n", "\n", " # compute train set accuracy, both on clean and adversarial images\n", " adversarial_accuracy_train_sample = 0.\n", " accuracy_train_sample = 0.\n", " for _ in range(iter_per_epoch_train):\n", " images, labels = next(train_ds)\n", " images = images.astype(jnp.float32) / 255\n", " accuracy_train_sample += jnp.mean(accuracy(params, (images, labels))) / iter_per_epoch_train\n", " adversarial_images_train = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " 
adversarial_accuracy_train_sample += jnp.mean(\n", " accuracy(params, (adversarial_images_train, labels))) / iter_per_epoch_train\n", " accuracy_train.append(accuracy_train_sample)\n", " adversarial_accuracy_train.append(adversarial_accuracy_train_sample)\n", "\n", " # compute test set accuracy, both on clean and adversarial images\n", " adversarial_accuracy_test_sample = 0.\n", " accuracy_test_sample = 0.\n", " for _ in range(iter_per_epoch_test):\n", " images, labels = next(test_ds)\n", " images = images.astype(jnp.float32) / 255\n", " accuracy_test_sample += jnp.mean(accuracy(params, (images, labels))) / iter_per_epoch_test\n", " adversarial_images_test = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " adversarial_accuracy_test_sample += jnp.mean(\n", " accuracy(params, (adversarial_images_test, labels))) / iter_per_epoch_test\n", " accuracy_test.append(accuracy_test_sample)\n", " adversarial_accuracy_test.append(adversarial_accuracy_test_sample)\n", "\n", "\n", " time_elapsed = (datetime.datetime.now().replace(microsecond=0) - start)\n", " print(f\"Epoch {it // iter_per_epoch_train} out of {FLAGS.epochs}\")\n", " print(f\"Accuracy on train set: {accuracy_train[-1]:.3f}\")\n", " print(f\"Accuracy on test set: {accuracy_test[-1]:.3f}\")\n", " print(\n", " f\"Adversarial accuracy on train set: {adversarial_accuracy_train[-1]:.3f}\"\n", " )\n", " print(\n", " f\"Adversarial accuracy on test set: {adversarial_accuracy_test[-1]:.3f}\"\n", " )\n", " print(f\"Time elapsed: {time_elapsed}\")\n", " print()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAA/cAAAHQCAYAAADksJpoAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAD390lEQVR4nOzdd5iTVfbA8e+ZTu9FmjRBRCkCYgUUG7bFsvbeu669rPtzXXfV1bVg710sqFixIUVR6SiiKJ2hDp0Zhunn98d9M8mEJDOZZCZTzud58iRvuW9uMiU5t5wrqooxxhhjjDHGGGNqr6REV8AYY4wxxhhjjDGxseDeGGOMMcYYY4yp5Sy4N8YYY4wxxhhjajkL7o0xxhhjjDHGmFrOgntjjDHGGGOMMaaWs+DeGGOMMcYYY4yp5Sy4N8aYGkZEXhcR9W4fxPG6d3vXXB6va5rQRGSy916/UgXXHhHw+9E13tdPBBHpGvCaRlTRc/iuf35VXN84Qb+fKiIzKlhubFC580Occ37QOeeUc83zI/2tiMgrvuMRriEicoKIvCsiy0QkV0QKRGSjiMwUkadE5DQRaRdQZnJQPaO93V2R98wYY4JZcG+MMTWIiDQGTgzYdayItExUfUzdY408ppoNEZHekU4QkSbAXypx7X+ISHLlqlU+EWkGfA18BPwV6Ao0AFKBVsBg4ArgbeDFqqqHMcZUlAX3xhhTs5wMNArYTgNOT1BdjDEmFtu9+4g97LjAuUHA+RXVEzg32kpFYRww0nv8Ja4BojcusO8JnAQ8C2wIKjcKaBLmttI75/sI5/ynSl6NMabOS0l0BYwxxpTh+6K6FMgH+nj7nkpYjUzUVHVEoutQm6jqckCq+Dmq9PompHHAhcDZInKXqoYb/u4L/t8DLqrgtZcC3YG7ROQNVS2MraplichhwOHe5qOq+regUzYDS4APReQ6YKDvgKrujHBd33tQrKo5cayyMcZYz70xxtQUItIZGOFtvu7dAIaKSK+EVMoYYyrvA2AHsDtwSKgTRKQLMBxQ4I0orn2Pd98NuCCGOoZzeMDjByOdqKr5qvpTFdTBGGOiYsG9McbUHGfj/7/8hnfz9fKUN6wVEUkRketEZI6I7BCRzSLyvYicV065Gd4c7C8r8BwveedmisgunyEi0kBErheRKSKywUs8tU5ExovIMRGuW2YeuIjsJSIveAms8kVka9D5B4rIGyKyVER2ekmuVojIjyJyv4gMCfEcGSJyrIg8KyLzRSQ7oH6fi8iZoV5TQPkySfJEZJSIfCwia0SkSETGhzs3xLX2EZG/i8hU730qFJGtIjJbRO4VkTbh6lFZ4iU6A/7P27V7iERek4PP925dRaSFiPxbRH713jsVkQEB5/cUkRtE5CvvPSkQke3e+Y9KhOR/Uk5CPRFZHphoTFyCs6+89y5PRP7w6tY0wnNEStQW/LM9xPudXef9/i0TkcclIGlamOcQEbnI+z3c7t1micg1IpIsAQneIl2nPN7vzwsissT7/d8uIvO896B1hHK+BHKTve1+3t/RKu91rvbO6RFL/QLswAX4EH74/Nm4URvfAcujuPYU4Fvv8Z0iklaZCkYQ+D5mx/naxhhTJWxYvjHG1By+AP5HVV0MLugADsUNa/1HuGGtItIImEDZ3rGGwEHAQSIyEjeMNZTXgSHASBFpr6rrwjxHBi4nAMBbqloSdLwf8DGuly5QO9xc1b+IyMvAJapaHKYuiMhfcAmqMgJ27ww4fhOhe9K6eLf9gb2B44KO3wdcH6JcO9wc2VHAWSJykqrmh6ufV4f7gNsinROhbH9gXohDzYB9vdslInKMqs6uzHNUgR7Ay0DnUAfFJR5bFOJQKtDXu10kIn9V1S9iqYiIPMKuP8dewB24BJQHxzLcWUT+BjxE2Q6QrsDVwAkicqCqrg5RLhU3rDw4Mdwg73YC8G5l6xXwPDcBDwTVLwPo792uEJHRqjq1nOucBrwKpAfs7gCch3udw1V
1fqz1xf1/OQc4RUSuVtW8oOPnBJwXrbuAw3B/9xcT3+lLmwMeHw58GMdrG2NMlbCee2OMqQHE9TT38TYDv+T6HncFhkW4xLP4A/s3cMFEa1w257dwX6DD9Zy9DRQBycAZEZ7jeMDXM1pm+Ky4obWTcIH9Ety82R5AS2Af4H9ACW747D8jPEcL3GtegmtI2A3ohAs4EDc94X7v3InA0d5ztsANzz0WGANsCnHtbcDzwKm496Wjd/39vPrtBI4B/hWhfuC+6N8GfIJ7z9vgkms9Xk45H/Xqfi3uZ7oHLkFXX1yAshBoC7wvIg0qeM2K+A6XrOs+b3sluybyGhWm7Ku4RI/X4OY5t8UlGlsbcM4M4BZcY9SeuN+/3riEkDOAxsDbIrJbDK/hHFxg/zyuQaoV7u/mGe94f1yQX1nDcL8LHwMH415Dd9zvhOKCyHBDtP+DP7D/FDjQK9/Xu+ZIKtkg5CMiZ3rPnwT8imswaIf7G7gK2IL7W/hMRLpHuFRP3M90OnAk7ufZGffe5nvXeDaWugaYCKzBNV6dEPR6huB+V/JwDSNRUdUfcInuAO4QkfRI50dpYsDj50XkkkgjQ4wxpkZQVbvZzW52s1uCb7jAUHFfrFsG7G8C5HrHXgxTdrB3XIHnwpzzUsA5y0Mc/9Q7NjtCHcd758wLcexj79gSoHmY8pcGvMYOQcfuDqjfH0CzMNe4xjtnPZAW55/B0d61c4AmIY5PDqjj24BEuJbv3FcqUY/GwGKv/IUhjo8IqEfXSlzf917v8nsQ4XnygQExvLcpuOzgCtwT4njXgOcaEeL48oDjd4Z5Dt/v4Nowx33lzy/nZxvub2hMwHvRNOhYZ1wDmeKWTdvldwO4NeA5tBLvYbr3e6/A78F18M4ZiAuUFfggxPFXAuowAUgJcc4NAefsWYl6Bv7ejPD2PehtfxJ0ru//3rshfg9C/ZzOD/7dxzXO+fZdV9754d6PMK9lQuDPDCgEZuEaky4Idc0KvD++3+XJlf17spvd7Ga3cDfruTfGmATzhvP6lrv7XFVLh4OqajYuqAY3rDVUT+753n0eLoAI5WZcUBKOryd+XxHZM0QdW+Lv1Q3ute+Ofwj8Vaq6NcxzPI+bGpCGW/oqnH+o6rYwx3zTyTaoakGEa0RN3XDxDbge6gMjnFoM3KCqMc2bjlCPHPzzlI+oiueohJdUdV5lC6tqEW4ECcT2mjLxj9wI9rJ3315ccsrKyMWNPoh0/TTcCIFAZ+JGvgDcGOZ34yFc/SvreFwPO8CtqrrLsnGqOhd/j/sJEjl3w3XezyXYKwGPd8ldUUmvefdH++oU9H/vtZClKkBVZ+AaJwFui/Nol5NwQ/1971MKblTUZbgG02Ui8p2IDI/jcxpjTKVZcG+MMYl3DP7kTaHmnfq++DZl1/m84IYPg+sJ2hLqCVR1Ey4BVTgf4U8adXaI43/FBTUl+IM0n5G4hFj5wEwRaRzqhguaf/bKDA5TD1+PYjhzvfu+4hLntYpw7i5EpI2I3Ol9IfclsitNKIcbYg9uDnc481R1TTTPG6IeIiJ/FZH3vWRtuUH1uLkC9ahOn1XkJBE5SkTeFJE/RSQn6DU96Z0Wy2v6WsPna/gj4HH7Sl7/pwiNU5Gu72sM+k29fBnBvHpX6H0Mw/d3ngt8HuE83/D2ZMI3Ui1V1T9DHfAaF33rtlf2fQy+5nzc334K/oB+FO7/3gYgpjwMwD9w/zvaA1fGeK1SqrpTVa/CTc24BfgKN70n0MHAJHHL4RljTEJZcG+MMYnnmwu/hdBf/r8G1gWdG6ird7+wnOf5PdwBdesyv+9tnikiwWuC+wL+SSEC297efTqwEddIEO52onduuB7FjaF6JAPqORn/SIZbgfUi8pOIPCQug3qjcGVF5BDce3Qv/vnU4RLLNgt3HcInJqwQr6FjIi652km4n1+43sZI9ahOEV+zuJUa3sIFaWfi8giE+1nE8poiNarkBjxuGO/rq2qk63f17v8
gsvL+RiPZ3bv/M0yPu8+CEGWCldc45XutlX0fQ/E1XPr+h/kS6b1dzusplzdiYby3eUuk/wOVvH6mqj6oqkfh8hH0xuU48CUcFOBhEQnXaGmMMdXCgntjjEkgEWmBf0j7NKCPiAwIvOES0vkyXx8puy7H1di7Ly9DeHnHfV++u+Gy7PvquHvAdqiRBZUJ1jLC7M8Nsz/Qqbie7aW43smhwI240QdZ4pYsK5P4ysvm/gEuwV8WrmFgf1x28Gb4E8r5hk1HWk2mInWM5BFc0jlwQ72PwSUfbBVQD9/Q85qyqk15r/lW/MkYx+MacXrhGlB8r+kK73hycOEohF1lIUhw41RVX98XTO4op1yls/jj3sOKXCNw2bYmYc6p6vcxlDe95x0sIgfgphlADEPyg/wfrve+LS43R5VQ509VfQqX4+Bt71AS/t9xY4xJiJrypcEYY+qr03HD3cEF+cHLtwVLBs4CHg7Yl4MLUBuHLOFX3vHJwCpcdvqzcQnQ8J5PcNnkPwhRzhdsZKlqxHXA40FVC3Hzlx8SkZ7AAbgs58fhhuVeDewvIgcE9Aieggs0S4BDVfW3UNeu6mzYXo+ir8fyflW9Pcx58Zw3XB0u9+7fVtWQKy54SynWVb6gvrwe4/L+BiPxBe3R/J3XmPXZVXWdiHwDHIXL25EOLFTVWXG6/nwReQ/X+HeTiDxZXpk4PGexiFyDf6rBvlX9nMYYE4n13BtjTGKFW54umjLLvftdEuEF6RPpoLp1633z6f/qJbwCF9wDfOQl+AvmG7Ld2ushrzaqulhVX1fVS3AZy8d4hwZTtqHElwDtlwiBfWeqfhh8b/zrir8d4bx9qrgeceMlW+zkbdaJ11QJK7z78vIJ9C7neCTLfc8hIpE6Z/qGKFNT+Eb+dA/ajpe7cQ14rYBqmQOvqhtxo4EgvtMYjDEmahbcG2NMgojIHrih4eB6cSXSDbjJO7e/iAQGSr4e9hEi0jzMc7UCKpLR2ZcJvyVwjIgMBPYKOhbsa+8+CddDnhBeL/3dAbsCGzN8AXWkIeFnRTgWL4HrcIesi4h0wo1EqCqFkZ6/EirymhoBo+P0fDXRNO++r4j0CHWCiCQDx8bwHL6/84a4ZRvD8f0NFgM/xvB8VeFD/CN9lPD/UypFVX8HxnqbN1INOSu83+3m3mZMiTaNMSZWFtwbY0ziBPbAB2egD+UdXK9UcNlXvPsM4L9hyj5I2SAspICs1uCG5vsC3g3Al2HKLMSfCPB+EYnYeykibb1cA1ETkT1EJNJnV2BgtSng8TLvfk9vKH/wdfcE7qhMnaK0PODx8cEHvR7Z56jaaXO+96VNOT3AFbUB/7D0XV6T5xFcg1Fd9Rb+eewPhUhICfA3oEsMz/Ep/h7iB0Rkl/n0ItIf/7zvj1R1Q/A5ieQlJdwb1/DWS1VXVsHT/BP3s2hODHPvReRiETmlnP83AHfhn1r1daQTjTGmqllwb4wxCeB9+fdloP/VC6ojUtVVwHfe5pleTyDenNU3vf2XiMhrIjJQRFqKyL4i8iZwAf4Atzy+3rTj8Af35WW0vhIXeLTGLYf3TxEZJCKtRKS1iPQVkbNF5B1gJWWD8GjcCSwRkftE5AgR6SwizUWkh4hcgD8nwA7gk4By7+MaRlKBz7zM+u1FpIuIXIF7X3OBzZWsV4Wo6lr8P8M7vGX5ennv0UhcFv1RRFjZIA5me/fpwD0i0kFEUr2M91H35nu/F773/XwRedj7ebcSkQNF5APgEqr2NSWUqmYCj3mbo4HxIrK/9ze4p4g8CDxADCstqGo+roEA3Gia70XkOHHLO3YWkctxvz/puN7xWyr7XFVJVVeo6sJwSwbG4fqL8A/3r+z/GXDTnN7DrWX/H2+Zx91FpJmIdBKRY0VkPC6ZJLhe+yqf52+MMZFYcG+MMYkxDP/yWWMjnBfMd24H4PCA/ZfhDxrPAebgemhn45Yme5OKZ6V+CxcIZ+Bf5zri8FmvB244LoBrilt3ehZ
uabwNwK+4L9yn4oKPwtBXqpCuwG24NadX4pYQXAy8hJt3vxM42wukffVbBPzd2+yFy6y/FjdX+ilc0H8q1ZOA7Aqvzmm4Zfn+wL1H3+B+Lx7BLZNXJVR1JvCDt3k7sBoowP1MJlbysrfgn3f+N9zPeyNuuPqJuMaVhyp57dridlzvOsAJuCHxm3B/EzcBk3ABPlQ8W30ZqvoWbqWIEqAfrgErC/d38DRurvkW4FhVXVKpV1E33ENs/2MAfEtydsH9bL/AjbzZiltV41PgL945C4GRqrotxuc0xpiYWHBvjDGJETisPprg/j38X1pLr6GqO4DDcIHVXFwv9DZcgHGRqp5NBXnr2AcGeX+q6owKlFuICzjOwwUda3BBYz7uy/CXuC/Je6jqz+GuU45bcY0XrwDzgPVAES4on4ubfrCnqo4PUb/7cIHmFO/8PGAJLigaqKpTg8tUBVVdAAzCvYa1uJ/nelzwMFpVb6iGahyDe69+wzWGxERV1wFDcAkNV+Be00bce32Bqp6Cf0pJnaSqBbhg7xJgOq733Pd7eQNunrxvakylG5FU9SHcEmwv4Ubj5HnP9QvwH9xw92r5Xa6pVHUZbpnJWK5xDy4B4vW4/7u/435uxbj3ezEwDjcCq5/3/88YYxJKVDXRdTDGGGOMqfNE5DHgWtxUnLq8eoAxxpgEsJ57Y4wxxpgq5uXZ8GXLnx3pXGOMMaYyLLg3xhhjjImRiKSJSOMIp9yCP8FbleVUMMYYU39V5VI7xhhjjDH1RUvgZxF5Fvgcl88B3Lzti4Dzve1puPwKxhhjTFzZnHtjjDHGmBiJSHtcgsRIFgBHqerqaqiSMcaYesaCe2OMMcaYGIlIKm7ZyaOAAUA73LKQW4H5uOUAX/DWqzfGGGPizoJ7Y4wxxhhjjDGmlrOEesYYY4wxxhhjTC1nwb0xxhhjjDHGGFPLWXBvjDHGGGOMMcbUchbcG2OMMcYYY4wxtZwF98YYY4wxxhhjTC1nwb0xxhhjaiQRaSwig0SkbaLrYowxxtR0FtwbY4wxJmFE5FAReUpEBgbtPx9YD8wAVovIvYmonzHGGFNb2Dr3xhhjjEkYEXkTOBnYTVW3ePu6AX8AKcAqYDdch8SRqjoxUXU1xhhjajLruTfGGGNMIu0H/OwL7D3n4AL7W1W1C3AAoMCVCaifMcYYUytYcG+MMcaYRGqD650PdBiQBzwBoKqzgB+A/tVbNWOMMab2sODeGGOMMYnUECj0bYhIEjAYmKGqOwPOy8QNzzfGGGNMCBbcG2OMMSaRsoCeAdv74wL+aUHnpQM7McYYY0xIFtwbY4wxJpF+BAaKyKki0hS4Eze//uug8/oAa6q7csYYY0xtYdnyjTHGGJMwIrIf8B0ugR6AAHNUdXDAOZ2AlcArqnph9dfSGGOMqfms594YY4wxCaOqM4DjgCnA78ArwLFBp50GbGPX3nxjjDHGeKznvo5p3bq1du3atdLld+zYQaNGjeJXoTrG3p/I7P2JzN6fyOri+zN79uyNqtom0fUwJhb23aLq2XsUmb0/kdn7E1lde38ifbdICbXT1F5du3Zl1qxZlS4/efJkRowYEb8K1TH2/kRm709k9v5EVhffHxFZkeg6GBMr+25R9ew9iszen8js/Ymsrr0/kb5bWHBvjDHGmIQTkWbA2cABQBtgoqr+1zvWG9gd+C5oeTxjjDHGeCy4N8YYY0xCicjRwJtAc1xCPQVWB5zSCxgPnAm8U83VM8YYY2oFS6hnjDHGmIQRkb2BD4AmwFO45HkSdNoXQC7wl+qtnTHGGFN7WM+9McYYYxLpDiAdOFFVPwYQkTK986paKCJzgf4JqJ8xxhhTK1jPvTHGGGMSaQQw1xfYR7Aa2K3qq2OMMcbUThbcG2OMMSaRWgGLK3BeGtCgiutijDHG1FoW3BtjjDEmkbYAnSpwXg9gfRXXxRhjjKm1LLg3xhhjTCLNAIaIyB7hThCRIUA/YFq11coYY4ypZSy
4N8aYBMjansffP5zPMY99l+iqGJNoTwKpwDhvPfsyRKQ78BJuebynq7luxhhTo2Vtz+M/03eSlZ2X6KqYGsCy5RtjTDXK2p7HmImLeG/2KkpUKSzWRFfJmIRS1S9F5HHgGuA3EVmAC+QPF5HpwEDc95WHVfX7BFbVGGNqnDETF7FoSwljJi7m3tF7J7o6JsGs594YY6pB1vY8Xl2QzyH/ncQ7szLJLyqxwN4Yj6peB1yJm1O/N26d+07AEGAbcL2q3pS4GhpjTM2TtT2P92avQoFxszKt995Yz70xxlSlwJ76ouISLJ43JjRVfUZEngMGAN2BZCATmKGqRYmsmzHG1ERjJi6iWN0Xi2JV6703FtwbY0xVKBPUl5RQXBL+3C8XrKNHm0Z0admItBQbUGXqL1UtAeZ4N2OMMWGs27qTsTMzKS5xwX1hsfLGTyv4Y9129mzflG6tG9G9TSO6t25MxxYNSE6SBNfYVIcaF9yLyJnAFbisuMnAQuBl4GnvQz+aa7UEbgZGA12BPGA+8Lyqvh7PciJyN/B/EaqTr6oZEZ4zbq/bGJN4V4+dy8xlm6lIR/1lr88GIDlJ6NKyId1bN6JH28Z0b92I7m0a071NI1o1SkPEPpiNMcaY+m5nQTGnPvdjaWAfaObyLcxcvqXMvrTkJHZv1dAL+H3fLxrRrXUjWtr3izqlRgX3IvIkbs5dHjARKARGAk8AI0XklIoGul523W+B3XFz+L4CmgFDgUNEZCRwgapqPMoF+BmYF2J/YYS6xu11G2MSr7hEObB7K2Yv3xzVMPziEmXZxh0s27iDiQuzyhxr1iC1tAW+R1vvvk0jdm8Vn95+30iDOSu38vl1h8R8PWMqSkTOreCpBcBGYK6qbqrCKhljTI2VtT2P816ewcrNOytcpqC4hEVZOSzKysGFN37NGqQG9PK74L9baxf4Z6QmV6p+V4+dyxNnDqRtk7D9mqaK1JjgXkROxgW464BhqrrI298OmASciMuk+1gFLzkWF6CPA85T1Vzven2ACcB5uPVyn49TOZ/xqnp3BetYFa/bGJNAi9Znc8v7vzB35dYy+5OTBFRDBvuH7NGapRt2sHpr+A/qbTsLmbty6y7XTRJcb7/XEh/Y49+6cfmt8Za939QAr0CFBrn4qIh8ClyjqplVUyVjjKl5FqzZxsWvzmLttvCJ85KThP6dmtG7fROWbtjB0o072JCdH/b8bTsLmZe5lXmZW3c51rF5g9LAP7DXv0Pz8MP8x0xcxMzlm23+f4LUmOAeuN27v9UX4AKo6noRuQKYDNwmIo+X14stIgcA++Ey7F7qC9C96/0uIjcB7wF3icgLvl74ypaLUdxetzEmcQqLS3h2yhLGTFxMQcAE+z3bN+HBU/rTrlk6t702hR/WllAcFES/ftFQAHILili2cQdLN+xgyYYc70PZ3ecWFId83hKF5ZtyWb4pl2+DjjXJSKGHN6y/RxvX09+9TWN2b9WQbbmFFtSbmuI13Ai5v+CC/J+BFUAJbmpcf1z2/E+ARril8U4ABojIIOvFN8bUB1//tp7r3p4b9vuAT3GJ8tua7TxzzqDSnvPsvMLSkYFLNuzwvmvksGxj+O8XAKu37mT11p18v3hjmf1pKUl0bdWQ7q0b0620x78RTdJTXfZ+ddn7rx3Z03rvq1mNCO5FpBMwCDfk7r3g46o6RURWAx2B/YEfyrnkEO9+tqpuCXH8K+++My6Ynx5juUqpgtdtjEmAX1dv45Zxv/Db2u2l+1KThWsO24PLh/coHTZ/bt907j93f8ZMXMy4WZm7BPkN01Lo26EZfTs0K3N9VWXd9jwX7G/IYUlA8B+ptz87ryhsazy4aMlC+oqxaQtV6m+4z9PJwFWq+nvgQRHZE3gS2Av3WViE6+3/C3ADcGc11tUYY6qVqvL8d0u5b8JCKtqtGJw5v0lGKv06Nadfp+a7XHv99vzSjoTAoD9zy86Qc/oBCopK+HN9Dn+uzwlbh/yiEs5
+YTon9O9A26YZtG+aQTvvvmmDlGqb55+1PY//TN/JXoPy6kVDQ40I7nGt8AALVDXcN9WZuCB3IOUHuY29+41hjmfjAuo0XHDtC9IrWy7QviLyANAC2Oyd85mqFoQ4N96v2xhTjfKLinl84mKenrKkzAdg/07N+O8p/endvskuZdo2yeDe0Xtz7ciejJm4mDkrQrUjliUi7NasAbs1a8BBPVuXObazoNh9GHsfzKU9/hty2FFO63647wiHPTSZ3Vs1pKs35273Vo3o1qoRHZpnkJJcv7L527SFanEP0BzYV1V3+aaoqgtFZDSwBLhHVa8WkUuAw4DjseDeGFNHFRSV8I+PfuXtmf4ZSKnJUu5nUWGxVvj7RftmGbRvlsGBPcp+vygoKmHl5twyAb9vmP/GnPDD/H0U+HN9Dg999ecuxzJSk2jXNIN2TTJo2zS9NPBv1yyDdk3Sad/MbVdmzn+wMRMXsWhLSb2ZJlBTgvtu3v2KCOesDDo3El8mqu5hjnfCBejB16tsuUDHe7dAq0TkbFWdErQ/3q/bGFNN5qzcwi3jfmFxlj8WSU9J4sYje3HhQd3KDYJ9QX6sGqQls1eHpuzVoWmZ/apKVnY+S7JyWLJxB098u4j128v/MAZYutF9ePPHhjL7U5OFzi1c0N+1VSO6tm5I11auASDS/LvKSHRPuQX11eovwJRQgb2PqmaLyBTccPyrVXWTiMwFBldXJY0xpjptzS3gijfm8ONS/8yjIV1b8MzZg2jVOL3MuZMnT2bEiBFxff60lCR6tm1Mz7aNgXZljm3bWchyr2Nh2YYdLNm4gx8Wb2RLbtj84WXkFZawYlMuKzblRjyvaUZKaaDvbq4hoG3AKIDWjdPCfufK2p7npglQf6YJxCW4F5E9cEu4rVDVWZW4hK/HfEeEc3wf+rt2he1qEq7BaJCIDA5RpysCHgd+I65sOXA9Crfjku4twzUC7INbHm848LmIHKCqvwSUicvrFpFLgUsB2rVrx+TJkyNcLrKcnJyYytd19v5EVh/en/xi5YM/C/hqRVGZnu9eLZK4cO902pdk8v13oXN8Jer96QzcPiiJjxen8N3qIkqUqLL4+xQWqz/wD5Is0Kah0K5hEu0aCu0aJZU+btVASKrA8LucnBzGf/EtHy0p5PuAelbne7Y1r2SX5w9W13/HE6AtbgnY8iQDbQK211JzOimMMSZulm3cwUWvzCzzeXvSwI7cd/I+pKfE3psdq2YNUunfuTn9OzcHXBB9yH8nhTw3OUk4ok9btu0sYv32PNZtzys3b4DP9rwitudFHv6fJNC6cXqZBgBf4D/h17WlIyuDpyokSlWvJlDhD0UROQm4GPinqk4P2P934G7c9E1EZKyqnh3nekZFVZeIyBvAOcBHInI1bi5fE1y2+5txy82l4hL2xFTOK/t6iKpMAiaJyDjgZOA/wHFxe6H+534OeA5g8ODBGkvLXVW0/NUl9v5EVtffnx+WbOQf789n5eai0n2N0pK5bdSenDV0d5LK6blO9Psz+ijIys4LO+ff57NrD2bFJjcUb8WmHSzfmMuyTZGz7RYrrNuhrNux6wd2WnISnVs28Hr7vVsr1+vv6/HP2p7Hra9P4Ye1+a6nPOA/bKj3rLhE2VlYTG5BEXkFJeQWFrGzoNjdCovJ9e4Dt/O883cWlLDTO9+3f/vOQrKy88udyhCuPiYmq4FDRaRFmHw3iEhL4FBgTcDuNrjpb8YYU2f8uGQTl78xm207/b3gNx/VmytH9Kix69GPmbiIkjAJAZIEWjfJ4Jlz/EF1dl4h67fnk+UF++u357N+e15p8J/lbReFmfMfqEQhKzufrOx85q/eFva8wmLljZ9WMPmPLFo3TqdFw1RaNEyjecM0WjRMpXmjNFr6HjdMo0UjdzweUwMCVfVqAtG0eJ8NDAPm+3aIyN64uXJFwE9AX+AMEflAVT+I4tq+5phGEc7x9XJnV/CaV+CC8tFAcF3exfWsj2bXLwa
VLRfJPbjg/ggRSVVV319rVbxuY0ycZecVcv+Ehbw5fWWZ/Yfs0Zr7TtqHTi0aJqhm0Que8x8qyA+V1A9gR34Ry71g393vcPebciMG/gXFJV4SwF17/FOThAZpyeTkF6EaOg/AsWO+2yVoLyiyxUPqkHeBW4EvReQ6Vf0x8KCIDAXG4EbMPePtE2Bv4HeMMaaOeGfmSu788NfSoDY9JYlHThvAMfvsluCahecb+h5u+lphse4yJL5JRipNMlK9If+hlZQom3MLSoP+9dvzWbctj6zsso835oRKaxbeqi07WbUlfDLiYBmpSbT0NQI0Si1tDNh1n9vfolEaTdJDJwwsnSZQhasJRBPcDwR+DlweDhfwK3Cxqr4mIt2B34BL2DUwjmS5d797hHM6B50bkaruAE70lrc7GtgNF5B/qaqTRMSXnG5+PMqVY6F3nwa0xg0lDHwtcXvdxpj4mvRHFnd+MJ81AWvKNs1I4e/H7cVfB3Wqsa3o5alIkB+sUXrobP4AOflFLN+4gxWbygb+yzbmRky8U1iiFOYVhT0OsGDN9ojHTa13L3Akbv789yKSictFo7jPxy640YFzvXPBfScpAj6q9toaY0ycFZco//1iIc9OXVq6r02TdF44d3Dp0PeaKlKvvU9lhsQnJQmtG6fTunF6yO8dPgVFJWzI8YJ9ryFgyYYdvDl9BRXo+C9XXmEJa7bllfkeWJ6UJKF5Q39DQPOGblTA/NVbKfKWSy4qqZppAtEE961wmdsDDcf1Pr8FoKpLReR7oE+U9Zjr3fcVkQZhMscPCTq3QrwegOBegCbAANwXg5ATRCpbLoxWAY8DJ41U2es2xsRma24B93z6Gx/MWV1m/xF7tePe0XvTrmndSMhSmez9oTROT2Hvjs3Yu2P4wH/5ph2lw/2/+HUtOfkVm3MXigg0SE12tzR33zAt8HEKGUH7GqS57dL9qWXPb5CazI6CIl75YTkfzV1dbmOHiQ9V3SEiw3GB+0W4YL5LwCm5wEvAnV4DPKo6B3/jtzHG1Fo78ou4/p15fP3b+tJ9e+3WlBfOG0yH5g0SWLOKmbNya9yy91dGWkoSHZs3oGPAe/X3D+eTnCSUhKhXSrIwcs+2nDKoM1tyC9iaW8CW3EK25haweYf/se++Mt8DikqUjTkFEUcVFJXsOqIhHqIJ7tPx5tUDiEgaLtCdoqqB3S7rgIOiqYSqZorIHGBf4K/Aa4HHvQ/9Tt61f9z1ClG7EmgAjFXV9eWdHIdyp3r3f6hq6fD6BLxuY0wFfPHrWv4+fkGZHueWjdL45wl9Oa7fbrW2tz6SeGXvDyVU4H/L0b0ZM3Ex783KpLhEI86r++Tqg2mQlkQDLwBvmJZMekpSlf0cHji5Hzce2avCIxpM7LxM+deLyO24pWY7eofWALMiLBdrjDG11tptO7nolVn8ttY/Qu3wPu147PQBNEqvHflCE7GiTSTlTRMoKlam/LGBf43eu9ygWlXZUVDMlh0FbAkI+LfsKGBzUCPAltwCtuwoZEtuQYUTBlZFkr9ofmvWAnsFbA/DBfzTgs5rDFRmDOV9wHvAAyLyg6ouBhCRtsBT3jn3q2rpREsv4d3VwAxVPTfwYiLSG8gKTM7jzdG7EPgXbqj9jcGVqEw5EekCHAy8r6r5QeXO9l4bwCPxeN3GmKqxITuf//v4Vz6fv67M/hP6d+D/jt9rl6VnTOWFmhZQWFwSMjv9Pp3CD8erzvpZkF/1vCD++0TXwxhjqtovq7Zy8auzyArIWXPpsO7cevSecV1atr6J5zQBEaFxegqN01Po3LLi+ZXyi4rZmltYGvAv37SDu8b/uktHRqh8BLGKJrifApwtIrcAX+ACXfUeB9obWBVtRVR1nIg8jUtoN19EvsFlph+JS6IzHngiqFhroDeuZzvYGcAdIjIbyMQtoTMYN9RvPTBKVdfGqVxL4E3gGa8nfg0uKV9f/OvTP6Gqz8bpdRtj4khVGT9vNf/85De2BqzR2rZ
JOv8+cR+O2KtdhNImFoFB9G2vTeGHtSU1KoiO17QFY4wxxufz+Wu54d155HlLw6QkCfeO3pvT9+tSTklTnkRPEwBIT0mmXdPk0imcn/2yhnCDDePdex9NcP9vXJb4+7ybAN+oauk8fBHpBXTHy2YbLVW90puzfxVuPn8yLhndS8DTUfZef4traBgE9AeKgaXAy8AjqhpurYTKlMsEHsTNj+8J7Ack4Rod3gGeU9Vvw1U0zq/bGBOFtdt2cueHv/Ltwqwy+08b3Jk7ju1DswapCapZ/dK2SQbn9k3n/nP3r5E95VU5bcH4iUgfoBeucTvkVyFVfS3U/gpc+0xcQ3o//J+zL1OJz1lvab6bcd+LugJ5uES7z4dZGrdS5UQkCdgfOAY4DJfTqDFuFOFs3PeL8WGe627g/yK8jHxVrRvJQ4ypJVSVpyYv4cEv/yjd1zQjhWfOHsSBPVsnsGZ1R6hpAolcirgyqwnEosLBvar+KSIHATcAbYEZuIA20EjgZ+DTylZIVd/CS9BXgXPvBu4Oc2wqMLUSzx91OVXdBNwS7XMFXaPCr9sYEztV5e2Zmfzns9/JzvenDenYvAH3n7wPh+zRJoG1q7+sp7x+EpEDgeeInJBXcCMGow7uReRJXN6cPGAi/hFyTwAjReSUigb43spA3+Iy+a8HvgKaAUOBQ0RkJHCBatlxoZUs1x3/9MfNuO9eW7z9o4BRIvIKcGHw8wX4GZgXYn9hiH3GmCqSX1TM7R/ML5Oot2urhrx0/hC6twm/JJyp3apqNYFwosrUoKq/4uaehzv+NPB0rJUyxpiqtHJTLrd98As/LNlUZv/5B3bl5qN615okNnWZ9ZTXHyKyJy7QbQj8ALTHTWl7GzcabiCup308EG7UXaTrn4wL7NcBw1R1kbe/HW7lmxOBa4DHKnjJsbgAfRxwnm+JYG/UwQTgPFxA/nwcyimuQeBB4GtVLc3S5CXd/Qw4H9cp8XKY+o73OkOMMQmyeUcBl70+i5nL/Q3WQ7u15JmzB9GiUVoCa2aqWnVPE7BvsMaYeqO4RHn1h+U8+OUf7Cz0ZzLt1roRD5zcj/26tUxg7Yypt27DBfaXqerzIvIy0E1Vz4LS4PdV3HD9Aypx/du9+1t9gT2Aqq4XkSuAycBtIvJ4eb33InIAburdNuBSX4DuXe93EbkJlyT3LhF5wdebXtlyqroEN8JgF6o6RUTux+VAOpvwwb0xJoEWZ2Vz4SuzWLm59M+eUwd34t7R+5CWkpTAmpnqUN2rCUQd3ItIT+Ay3AdsG+AjVb3FOzYUN0/9XVXdGsd6GmNMTKYv3cRVb80ps+ZoksAlw7rzt8N7kZGanMDaGVOvjQAWqWpwTzdQGvweBywG7iKKaXAi0gmXQ6cAFzwHX3uKiKzGLb23P27kQCRDvPvZgavqBPjKu++MC+anx1iuPHO9+04VPN8YU42+W7SBK9+cQ3aem/4nAreP2pNLDuleJ5fVNYkXVXAvIhcBTwK+8SOKy1jv0xA3LL8Qa0E2xtQAa7bkcuVbc5iXWXY0b+92TfjvKf3o37l5YipmjPFpjxte7lMMICLpvuVlVTVLRKbghtBHk+NmoHe/wFtmL5SZuOB+IOUH976JsRvDHM/GNSSk4RoVfEF6ZcuVZw/vPtTqPz77isgDQAvcvP3pwGeqWhChjDEmRq//tIK7P15Asbf8WYPUZB49fQBH9W2f4JqZuqzCwb2XTO9ZIAe4Eze/K/jDZwpuyNkJWHBvjEmgrO153P3Jb0yYv5bgmU7XH74HV47oacPhjKkZcoK2t3v3uwHLA/bvxAXh0fAtR7siwjkrg86NxLesRvcwxzvh7wAJvF5ly4UlIg2Ba73N9yOcerx3C7RKRM5W1SkVeS5jTMUVlyj3fvYbL09bXrqvfdMMXjhvMHt3bJa4ipl6IZqe+1twPfWjVPVHYJfhJKpaIiJziZzt1hhjqkzW9jzum/A7H/+8trS1PNj1h/eq5loZYyJYBQQ
u7rzQuz8Ur6NARFJxWeU3RHltX4/5jgjn+BoXmlTgepNw34UGichgVZ0VdPyKgMdN41AukqdwDQG/4VYaCLYEl29gArAM13iwD255vOHA5yJygKr+EuriInIpcClAu3btmDx5cgWrtaucnJyYytcH9h5FVlven51FytM/5/PLBn9en65Nk7huX2HjorlMXhShcAxqy/uTKPXp/YkmuD8AmOEL7CNYBwyufJWMMSZ6a7bkcvuHv/Ldog2EiemNMTXTNOACEWmqqttxQ/SLgUdEJAMX/F+C691+O3HVdAnuROQN4BzgIxG5GpeQrwku2/3NuKmJqUBJrOXCEZG7vHLbgFN90xeC6vp6iKKTgEkiMg44GfgPcFyY1/ocXqPB4MGDNZY1ohO5xnRtYe9RZLXh/cncnMvFr87ij4DA/ui+7XnktAE0SKvavD614f1JpPr0/kQT3DfDfcCWp3GU1zXGmEpbtSWXd2dm8vSUJeUuNWKMqZE+AI7EJdb7WFVXi8h9uOR5T3jnCLAVuCPKa/t65RtFOMfXu59dwWtegQvKR+PqHuhdXA/5aNz89niUK0NEbgDuwb22Uaq6oIL1DnQPLrg/QkRSVdXWvDcmBrNXbOGy12eVSdp75Yge3HRkb5KSLHGeqT7RBOFZVGweWG9gdeWqY4wx5SssLuHbhVmMnbGSKX9uQC2mN6bWUtWJ+BPD+fb9n4j8ApwCtMQN1X9UVSPNnQ9luXe/e4RzOgedG5Gq7gBO9Ja3OxqXG2Az8KWqThIRX1K++fEoF0hErgH+h8s/cFwFRlOG45v6kIZLjBwpIZ8xJoKP5q3m5nG/UFDkBt2kJgv3ndSPUwbZIham+kUT3E8DTgkzVwwAETkCtw7tC/GonDHGBMrcnMs7MzN5d1YmWdm7jEKldeM0jt1nN7buLOTLX9dRrGq9+cbUUqr6PpETxVWEb6m4viLSIEzG/CFB51aIF1iXCa5FpAkwACjCDYGPWzkRuQoYA+QBJ8SYDK9VwOPghIbGmAiytudx9di5PHHGQN6asZJHv/FPpG/RMJVnzxnMft1aJrCGpj6LJrh/BPgr8IGIXAx8E3hQRIYBL+E+mB6PWw2NMfVaYXEJE39fz1szMvluUehe+kP2aM2Z+3VhZJ92pRnws7LzGDNxMeNmZVqQb0w9paqZIjIH2Bf3Hea1wOMiMhw3l38dQQF3JV0JNADGqur6eJUTkctxUxTygdGq+k3wOVE61bv/Q1UrOh3BGAOMmbiImcs3c+qzP7J8U27p/h5tGvHS+UPYvVWkWUDGVK0KB/eqOl1EbgEexGVe3Y7L/DpaRI7FDesS4AZVDTukzBhjKmLlplzenrmSd2etYmPOrr30bZqkc+rgTpw+pAudWzbc5XjbJhncO3pvrh3Z04J8Y+q3+4D3gAdE5AdVXQwgIm1xGecB7lfV0kR2XsK7q3GJhM8NvJiI9AayVHVLwD4BLgT+hRtqf2NwJWIod4lXz3zgRFX9srwXLCJdgIOB9wOT7XnPd7b3noDruDHGVFDW9jzem70KVcoE9gf3bM2TZ+1LswapCaydMVEmvlPV/4nIb8Dd+IexNffu5wN3qerHcaudMaZeKSgq4Zvf1zN2xkq+W7Rxl+MiMGyPNpyxXxdG9mlLanL569QHB/lzVmwpt4wxpnqJSE/gVtzydx2A9DCnqqpG+91lnIg8jUtoN19EvsFlph+JW3ZuPP7EfT6tcTmE1oW45BnAHSIyG8gEknGrBHUB1uOS3IWawx51OREZADyL6zxZBpwmIqeFuPZGVb0pYLsl8CbwjDdyYQ0umV9f/PmTnlDVZ0NcyxgTYFNOPn+uz+HP9dm88dMK8ovKLmhx5tAu/POEvhX6TmJMVYs6q72qTgAmiEgr3AdEMpCpqmviXTljTP2wfOMO3p6ZybjZmWUyzfq0bZLOaUM6c+rgziF76SvCF+QbY2oWERkMfIvLaF9eWulKpZ1W1StF5HvgKtwa78m4pHIvAU8
H9tpXwLfA3sAgoD9u2b6lwMvAI6q6LY7lmuN/zXt6t1BWAIHBfSZupOUQoCewH5CEa6x4B3hOVb8t53UaU69szS0oDeIXrc/mj/XZLFqfw6Ydu34v8UlJEq4fuYcF9qbGqPSSdaq6CdgUx7oYY+qRgqISvvptHWNnrGTa4l3/lYjAiF6ul/6wPduSYh+cxtRV/8UtR/cO8ACwyMssH1eq+hbwVgXPvRs3SjHUsanA1Eo8f9TlVHUylWjQ8L6j3RJtOWPqg+15hSxanx0QyOfwx/psNoRI1FseERjz7WLrPDA1hq1Hb4ypVss27uDtGSsZN3tVyNbw9k0zOHVIZ04d3IlOLSrXS2+MqVWGAr+r6hmJrogxpubJ2p7Hf6bvZK9BebRtklHhcjvyi1iU5QL4P9dl82dWDovWZ7N2W15Uz98gNZmurRvyx7psSoLS9hQWK+NmZXLtyJ5R1c2YqlLh4F5E/lHBUwuAjcBsVY1qWRljTN2RtT2PMRMXMWflVj686kC+WrCet6av5Melu/bSJwkc2rstZ+zXhRG921gvvTH1y07g50RXwhhTM42ZuIhFW0oYMzF0D/nOgmIW+4L4LC+QX5/D6q2hVr4MLy0liZ5tGtOrXWN6tW9Cr7ZN6NWuCZ1aNOAfH/3K4qwcSkIk5S1WDVs3Y6pbND33d+Oy45dHfOeJyC/ABao6L+qaGWNqJV9Q/97sVRSrUlSsHHDft2wO0Uu/W7MMTh3cmdOGdKZD8wYJqK0xpgaYAXRPdCWMMTVPaXZ64L1ZmRzdtz0bc/JdIL8+h0VZ2azcnBtymdxw0pKT6N6mEXu0a0Kvtl4g364JXVo2JDlp11kwvjqEW23Heu9NTRJNcH8PLqPr+cAO4GtcApcSoCtwBC4Zzqu4te4PxiWM+UZE9lXVlXGrtTGmxsnanserC/L54ZtJFJWUUByQniowsE8SOGxP10s/vJf10htj+DfwrYicpKofJLoyxpia49+f/0ah94Uiv6iEs1+cXuGyKUlCt9aN6NWuCXu0a0yvdi6I79qqYVTfPcZMXERJOa0H1ntvaopogvsXgdnAWOBaL1lLKRFpATwOHItb2mW1t305LoPrtfGosDGmZvH11L87axWFxSVhh/d0aJbBaUO6cOqQTuzWzHrpjTGOqk4TkdOB50XkROBLYBWu8yDU+VEnszPG1D4/LNnIR/NCrSpZVpJA11aNvOC9seuRb9eEbq0bkZYSewfCnJVbw/ba+xQWqy21a2qEaIL7e4E84HxVLQw+qKpbROQCYAlwr6qeKyI3AacAR8altsaYGiNw+H1RcQmRPvdePn8Iw3q1CTnczRhjgDQgFzjTu4WjWDJgY+q8X1dv4/yXZoQ81rF5A0YP7OB65Ns2oXubRmSkJldZXT6/7pAqu7Yx8RbNB+SRwORQgb2PqhaKyA+4Ifqoaq6I/AwcGFs1jTE1zdVj5zJz2eYKJeI4dM+2VV4fY0ztJCInA2/i1mHfBCwHchJZJ2NM4sxduYWzX5xOQZheg005+Zx3YFeb325MCNEE981x69CWp5F3rs+GKJ7DGFMLqCpH792eOSu2UBS8LowxxkTnDlwy3iuB51Q15HB8Y0zdN2PZZi54eQY7CorDnmPz240JL5qJKMuAQ0WkS7gTvGOHeef67IZriTfG1AFbcwu4+q253PPJb2UC+5QkIdlG3RtjorcnME1Vn7HA3pj6a9rijZz3UuTAHvzZ6bOyo1uv3pj6IJrg/lWgITBJRM4QkdLJLSKS7CXDmQRkeOciIim4jPm/xq/KxphEmbZ4I0c/+h2fzfcnuOnWuhHjrzqIH24/jOGdUshISSLVonxjTMVtwyXQM8bUU5P+yOKCV2ayszByYO/j6703xpQVzbD8/wEjgKOAN4BXRWQtLrlNByAZN6zuS+9cgL7AAuCtONXXGJMA+UXFPPTlHzz/3bIy+8/YrzN/P3YvGqW7fyXn9k3n/nP3Z8zExYyblUmxarkZZo0x9d5
XwEEiIqrRrFZtjKkLvlywjqvfmlP6fSElScqd8mfZ6Y0JrcLBvaoWicixuCXtrsWtbd854JQVuKXvHlPVYq/Mz4ClmDSmFvtjXTbXvT2XheuyS/e1bJTG/Sftw5F92+9yftsmGdw7em+uHdmTMRMX24evMaY8d+KW2n1IRG5V1aJEV8gYUz0++XkN178zj2IvmO/UogFvXbw/XVo1LHPe5MmTGTFiRAJqaEztEtVyMt5cuEeBR0WkE9DRO7RGVTPjXDdjTAKVlCiv/LCc+79YSEGRfxrs8F5tePCv/crNUusL8o0xphwXAROA64HRIjKJ8Ovcq6r+qxrrZoypIuNmr+KWcT/j66Tv2qohb12yPx2aN0hsxYypxSq9VqyqrsLmyBlTJ2Vtz+PG937mu0UbS/elpyRxxzF9OPeA3RGxOfXGmLi5GzfFT4Bu3i2Y77gCFtwbU8u9NX0ld3w4v3S7Z9vGvHXxUNo2teXtjIlFpYN7Y0zd9MWv67j9g1/YkltYuq/Pbk157PQB9GrXJIE1M8bUUffggnZjTD3w8rRl/POT30q392zfhDcuHkrrxukJrJUxdUPUwb2IZACHAr2ApriW9GA2bM6YWmZHfhH3fPIb78zyz7ARgUsP6c4NR/YiPSU5QmljjKkcVb070XUwxlSPZ6Ys4f4JC0u3+3VqxmsX7kfzhmkJrJUxdUdUwb2InAw8A7SMdBo2bM6YWmXuyi1c/848VmzKLd23W7MM/ndqfw7s0TqBNTPGGGNMbafe0nWPfPNn6b59uzTnlQv3o2lGagJrZkzdUuHgXkSGAm/jEtyMBfYG9gHuB3oCRwDNgBexufjG1ApFxSU8OWkJY75dVJqpFuC4frvx79H70KyhfeAaY4wxpvJUlf9++QdPT15Sum9ot5a8eP4QGqfbDGFj4imav6ibgCRgtKp+JiIvA/uo6p0AItIaeBk4Btg37jU1xsTVik07+Ns785izcmvpvibpKdwzui+jB3S0pHnGmCohIud6Dz9U1eyA7QpR1deqoFrGmCqgqtzz6W+8PG156b5D9mjNc+cMpkGaTfczJt6iCe4PBH5V1c9CHVTVjSJyJrAM+CdweRzqZ4yJM1Xlvdmr+OfHC9hRUFy6f0jXFjx86gA6t2wYobQxxsTsFdz0vZ+A7IDtirLg3phaoKRE+ftHv/LW9JWl+w7v05YnztyXjFQL7I2pCtEE962BaQHbRQAi0kBVdwJ4LfBTgVHxq6IxJl627Cjgjg/nM+HXdaX7UpKEvx3Ri8uH9yA5yXrrjTFV7jVcML8taNsYU0cUlyi3vv8L42b7Z+qO2rs9j50+kLSUpATWzJi6LZrgfgsQuEbFVu++E7AoYL8CbWOrljEm3r5ftJEb35vH+u35pfu6t27Eo6cPoF+n5omrmDGmXlHV8yNtG2Nqt8LiEm5492c++XlN6b7RAzrw0F/7k5Jsgb0xVSma4D4T6BKw/SsuM/5xwCMAItIIOBhYHa8KGmNik1dYzINf/sGL3y8rs//MoV34+7F9aJhmyWyMMcYYE7uCohKuGTuHLxesL9132uDO/OekfWx0oDHVIJrms8lAXxFp421/CuQC94nIAyJyjXdOa+DrylZIRM4Uke9EZJuI5IjILBG5SkSibuoTkZYicp+I/C4iO0Vki4hMFZFz4llORJJE5EARuVdEfvDOLxSR9SLyuYiMjvBcd4uIRrjlRfu6jfFZuG47o5+cViawb9kojefPHcx/TtzHAntjjDHGxEVeYTGXvzG7TGB/zv67c58F9sZUm2i+2b8HDAAGAl+p6iYRuRF4CpdJH1xPfiZwV2UqIyJPAlcCecBEoBAYCTwBjBSRU1S1pILX6g58C+wOrAe+wi3VNxQ4RERGAheoqsahXHf8+Qg2AzNw0xi64/IPjBKRV4ALg58vwM/AvBD7Cyvyeo0JVFKivPzDch74YiEFRf4/mUN7t+GBU/rRtklGAmtnjDHGmLokt6CIS1+bzfeLN5b
uu/jgbtx5bB9bfceYalTh4F5VZ+DWsg/c96yIzAZOBloCC4GXVXVrtBURkZNxgf06YJiqLvL2twMmAScC1wCPVfCSY3EB+jjgPFXN9a7XB5gAnIcLyJ+PQznFNQg8CHytqqUpyEVkOPAZcD4wFbdcYCjjVfXuCr42Y8Javz2Pm977me8W+T9g01OS+PuxfTh7/93tQ9YYY4wxcZOTX8SFL89kxvLNpfuuPrQnNx7Zy75zGFPNYs5qoaqzVPV2Vb1MVR+pTGDvud27v9UX2HvXXw9c4W3eVpHh+SJyALAfLhPvpb4A3bve7/hHGtwlAf91KltOVZeo6khV/SIwsPeOTQHu9zbPLq/uxsRiwvy1HPXo1DKBfd8OTfns2oM554Cu9iFrjDHGmLjZtrOQc16cXiawv/GIXtx0VG/7zmFMAtSICbci0gkYBBTghv+XoapTRGQ10BHYH/ihnEsO8e5nq+qWEMe/8u4744L56TGWK89c775TBc83Jio5+UXc9v4vfPrL2tJ9InDZsB7ccEQvW3bGGGOMMXG1ZUcB57w0nV9Xby/dd+cxfbhkWPcE1sqY+q1Swb2IJAOtgLATd1V1ZRSXHOjdL1DVnWHOmYkL7gdSfnDf2LvfGOZ4Nq4hIQ3XqOAL0itbrjx7ePdrI5yzr4g8ALTAzdufDnymqgUVfA5TT33z23quf2cuOfn+QSMdmmXw8GkD2L97qwTWzBhjjDF10YbsfM5+YTp/rM8u3XfPX/py7gFdE1cpY0x0wb2IHAz8H265u7QIp2qU1+7m3a+IcI6vsaBbhHN8srz7cE2HnfDXP/B6lS0Xlog0BK71Nt+PcOrx3i3QKhE52xvab0wZa7bkctkbs5kf0GLuM+H6YTRrkJqAWhljjDGmLlu3LY8zX/iJpRt2AG6k4P0n7cNpQ7qUU9IYU9UqHICLyJG45e98ZTYBOXGqh6/HfEeEc3zP1aQC15uEa2AYJCKDVXVW0PErAh43jUO5SJ7CNQT8BjwX4vgSXL6BCcAyXOPBPrhGlOHA5yJygKr+Eu4JRORS4FKAdu3aMXny5ApWbVc5OTkxla/rasL7szWvhPcXFfD96mLCLb0wd/q0MEeqVk14f2oye38is/fHGGNqtlVbcjnrhems2OTSUiUJ/O/U/pw40GaeGlMTRNO7/i/v/IeA+8LMSa8RVHWJiLwBnAN8JCJXA5NxDQPnATfjlphLBUpiLReOiNzlldsGnKqq+SHq+nqIopOASSIyDrcSwX+A4yK83ufwGg4GDx6sI0aMKK9qYU2ePJlYytd1iXx/srbnMWbiIt6bvYrC4pKwgT2QsDra709k9v5EZu9P/SAiMXXvRTntzxgTJys27eDM56ezequbQZuSJIw5YyDH7LNbgmtmjPGJJrjfB5do7pYqqIevV75RhHN8vfvZEc4JdAUuKB8NfBB07F1cD/lo3Pz2eJQrQ0RuAO7BvbZRqrqggvUOdA8uuD9CRFJV1da8r8euHjuXmcs3o5GiemOMqfmWQ8T2yUiinfZnjImDxVk5nPXCT6zf7vqp0pKTeOqsfTl8r3YJrpkxJlA0H5DbgUXlnlU5y7373SOc0zno3IhUdQdwore83dHAbriA/EtVnSQivqR88+NRLpCIXAP8D9gJHKeqP1akziEs9O7TgNZETshn6rgnzhzImImLGTtjJcUlFuEbY2qtlVQ+uDfGVLOF67Zz9gvT2ZjjcjynpyTx3LmDGd6rTYJrZowJFk1wPxXYu4rq4Vsqrq+INAiTMX9I0LkV4gXWZYJrEWkCDACKcEPg41ZORK4CxgB5wAkxJsMLTHUer/wGppZq2ySD20btyfi5q0oz4ycniQX6xphaRVW7VvdzisiZuJF5/YBkXOP5y8DTqlruNLuga7XETdMbDXTFfd7PB54PM9UupnJe2aOBG4DBuJWKlgJjgYdCTfkLKDcUuA04CJcrKBP4EPi3qm4r77Ua8+vqbZz94nS25rrBow3TknnhvME
c2KN1gmtmjAklmsWv/wl0FZHr410JVc0E5uB6qP8afFxEhuMy1a8jKOCupCuBBsB7qro+XuVE5HLgCSAfGK2q38RYz1O9+z9UtaLTEUwd9tb0FaWBfacWDTh9SGcyUpJITZYE18wYY2omEXkSeBMXGH8HfA30wn1ejxORCn8XEpHuuO8rt+GWrv0K+BnXAfGaiLwiIrv8Q65sOa/sLbiku4d51/gMaAvcC0z2VuUJVe4MYBquMeFP4CPc96ybgVki0rair9vUL1nb8zj12R+ZuHA9Zzz/U2lg3zg9hdcu3M8Ce2NqsAr33KvqAi9j/lgROQX4AlhFmMRyqvpalHW5D3gPeEBEflDVxQDeh89T3jn3B7awewnvrgZmqOq5gRcTkd5AVmDiP++D80JccsDNwI3BlYih3CVePfOBE1X1y/JesJdU6GDg/cCWd+/5zvbeE4BHyruWqfvyCot5/rtlpdtXHdqTM/brwnWH78GYiYsZNyuTYlUKi60n3xhjAETkZFzD/DpgmKou8va3w43AOxG4Bnisgpcci5tCOA44T1Vzvev1wQXg5+EC6ufjUU5EBgP3A7nAYao63dvfGBfkDwP+DfwtqFwn4EVAcJ0NH3n7U4A3gNOAZ73Xb0wZYyYuYuayzcxZsZki71t3swapvHbhfvTv3DyhdTPGRBZtUppDcEPFuwAHlHNuVMG9qo4Tkadxw+bmi8g3uMz0I3FDycbjWtkDtQZ64z60g50B3CEis3HD0JJxrfZdgPW4JHeh5rBHXU5EBuA+JAW3nN1pInJaiGtvVNWbArZb4noTnhGROcAaXDK/vrjl8wCeUNVnQ1zL1DPvz1nFhmzXBtSuaTon7dsRcMP17x29N9eO7MmYiYuZs6LGLmRhjDHV7Xbv/lZfYA+gqutF5Arciji3icjj5Q3P93Lx7IdbAedSX4DuXe93EbkJ10lxl4i8oOrSn1a2nOc23HeLB3yBvVcuR0QuwOVCulJE/qmqWwPKXY8bafiyL7D3yhV5y+eOAkaLyF6q+luk123ql6ztebwzKxOF0sC+ZaM03rhoKHt1qOgq0MaYRIlmnfvLgAe8zZ+BxcR5HriqXiki3wNX4dZ4982Le4no58V9i8sRMAjoDxTj5qi9DDwSYa5ZZco1x334Auzp3UJZAQQG95nAg7hheT1xH/5JuMaKd4DnVPXbcl6nqQeKikt4ZsqS0u1LDulOekpymXN8Qb4xxtQ2IpIGXAecghsyHy6KUFWt0HcXr/d6EFCAC56DLzRFRFYDHYH9gR+Czwniy/0zO8xywF95951xn+e+YLxS5bz3ZJR37M0Q9V8qIj/i5tMfA7wVcHh0hHLbReQT4CzvPAvuTanbP5hfZgRgRmoS71y6P3u0a5LAWhljKiqanvvrcD3pf1HVL6qoPqjqW5T9gIp07t3A3WGOTcUlAYz2+aMup6qT8Qf30ZTbBFTF0oKmjvn0l7VkbnZ5Jps3TOWM/WJaJtoYY2oMEcnADZHfj/I/S6P5rB3o3S8Ik6gXYCYuuB9I+cG9b0nejWGOZ+MaEtJwjQq+4L6y5XoDDYHNqrokTNmZuOB+IN53JxFpCvQIOB6u3Fn43yNjWLMll4kLs8rsU4VmDVMTVCNjTLSiSajXFZhalYG9MWZXJSXK05P93+suOLAbjdJtmWdjTJ1xAzAUl8unF25anwLpuGlq9+Eyy/9bVaP53uKb3rYiwjkrg86NxBf1dA9zvBMuQA++XmXL+R6vJLxQ9e/q3W9V1e1RlDP13M3jftllX4kqYyYuTkBtjDGVEU2EsAHYVFUVMcaENnFhFn+sd4slNEpL5rwDd09wjYwxJq5OAbYDZ3hDxhVAVQuB34E7ReQ74DMRWaCqb1fwur4e8x0RzvFNL6zImONJuEaHQSIyWFVnBR2/IuBx4LSCyparbP3j8rq9ufmXArRr147JkydHuFxkOTk5MZWvDxL9Hm3OK2Hakl0HuBQWK+/MWMHgBlk0T4+mbS2
+Ev3+1HT2/kRWn96faIL7j4CTRCRNVQuqqkLGGD9V5clJ/hbzs/bfneYN0yKUMMaYWmcP4IeAXmZfIrpkVS0GUNUvRGQmboWcigb3caWqS0TkDeAc4CNvxZ7JuAD5PNwSc4VAKgErCVW2XKKp6nPAcwCDBw/WESNGVPpakydPJpby9UGi36PzXpoOhJm9IsKsnW2596jE5fVJ9PtT09n7E1l9en+iaYL7B25e2Gsi0rKK6mOMCfDj0k3My9wKQFpyEhcfbCMojTF1ThJlRwb6IozmQectwSW8rShf73SjCOf4ermzK3jNK3Cr93QAPsAtj7sCuMfb/sw7b3McylW2/lXxuk0dtn7bTqb+GS4lhOu9Hzcrk6zsvGqslTGmMqLpuX8YNzzur8DRIjKL8Ovcq6peFIf6GVOvPTXJP9f+lMGdaNs0I4G1McaYKrEGF/T6rPLu++GGtPt0xevVr6Dl3n2kuUydg86NSFV3ACd6y9sdDeyGC8i/VNVJIuJLyjc/DuV8dYqUQTVU/X05BpqLSNMw8+6jet2mbrvjw1/L/cMq9ube26o8xtRs0QT35+M+VAU3J+ywCOcqYMG9MTH4OXMr3y92LelJApcP61FOCWOMqZV+xS1F5zMV913jbhGZparZInIGcADwYxTXnevd9xWRBmEy5g8JOrdCVPXH4LqISBNgAFBE2UaJypZbiBvF0FJEeoTJmL9fcP1VdZuILMFlzB8CTKxIOVN//bi0/JRahcXKnBWhVnI0xtQk0QT3F1RZLYwxu3hqsn+u/fH9O9ClVcME1sYYY6rMBOAvIjJCVSer6jRv/fZDgE0iko0boq/AQxW9qKpmisgcYF/cqMPXAo+LyHBcpvp1RNdoEM6VQANgrKquj7WcqhaIyATgJNyydfcEFhKR7rgGjwL8w/p9PsKtQnAWQcG9t1Te8d7mh1HU09RBc1duIbegGIDkJGHSjSPs+4YxtViFg3tVfbUqK2KM8Vu0PpsvF/i/G14xwnrtjTF11lvAAsoOET8ReBEYBbQAtuCWwos2GL0PeA94QER+UNXFACLSFnjKO+d+VS2dYuglvLsamKGq5wZeTER6A1mquiVgnwAXAv/CDbW/MbgSlS0H3I97L24VkS9UdYZXtjHwEi5fwVOqujWo3KO4ef7nich4Vf3YK5cCPIsbgTleVX8L9aaZ+uOpgKV2T7COBGNqPVss25ga6Okp/g/bw/u0Y8/2TSOcbYwxtZeq5gDTgvZlAceLSEOgGbA+MACP4trjRORpXKA7X0S+wWWmH4kX4AJPBBVrDfTG9egHOwO4Q0RmA5lAMjAYNy9+PTBKVdfGq5yqzhSR24AHgB9E5FtgKzAcaAtMB+4MUS5TRC4CXgfGi8j3uNwG++NyECwGLgtRT1OP/LEum69/s44EY+oSC+6NqWEyN+fy0bw1pdtXHmoftsaY+klVc4HcGK9xpRfcXoULipNx89lfAp6OstHgW1zG/kFAf6AYWAq8DDyiqtviXA5V/a+I/ILr2R8CZHhlxwAPqWp+mHJjRWQpcDtwEDAU17DwIG4URNjnNPXD0wHT/47cqx292jVJYG2MMfEQNrgXkX94D59Q1c0B2xWhqvqv2KpmTP30/HdLKS5xeWsP6N6Kfbu0SHCNjDGmdlPVt3DD/yty7t3A3WGOTcUl/Iv2+StVLqD8F8AXlSg3HRhd2ec1ddfKTbl8/HNgR0LPBNbGGBMvkXru78Ylr3kbNxfMty0RyviOK24OmTEmChuy83lnZmbptvXaG2PqGhHxzWP/0MuEf27EAkFU9bXyzzLGRPLM1CV4/Qgc1LMVAzo3T2h9jDHxESm4vwcXpG8M2jbGVJGXpi0jv8iNEO3XqRkH92yd4BoZY0zcvYL7PvETkB2wXVEW3BsTg6zteYybtap0+6oR1mtvTF0RNrj3hqWF3TbGxNe2nYW8/uOK0u0rR/TEJVM2xpg65TVcML8taNsYUw1e+H4ZBcWuI2FA5+Yc0KN
VgmtkjIkXS6hnTA3x+o/LyckvAqBn28YcuVe7BNfIGGPiT1XPj7RtjKk6W3MLeOMnf0fCVYdaR4IxdUlSoitgjIGdBcW8NG156fYVw3uQlGQftsaYuk9E+onI3omuhzH1wSs/LCe3oBiA3u2aMHLPtgmukTEmniy4N6YGeHvmSjbvKACgY/MGnDCgQ4JrZOJuayb7zr4ZtmaWf64x9cs84MlEV8KYui4nv4iXAzoSrjzUOhKMqWssuDcmwQqKSnhu6tLS7cuGdyc12f4065wJt9AkexFMuDXRNTGmptmKW3/dGFOFxk5fybadhQB0admQY/fZLcE1MsbEm0UQxiTY+HmrWbstD4DWjdM4dXDnBNfIxN3K6bB0EoLCkm/dtjHGZx5g634aU4Xyi4p5/ruyHQkp1pFgTJ1jf9XGJFBxifLM5CWl2xce3I2M1OQE1sjEXUkJfHwNFO5020U74ZNr3X5jDMAYYKiIHJ3oihhTV70/ezVZ2fkAtG2Szsn7dkpwjYwxVcGy5RuTQF8uWMfSjTsAaJKRwtn7757gGpm4++Ud2Laq7L6tmTD/Xeh/emLqZEzNMgd4AvhIRF4CPgRWADtDnayqK6uxbsbUekXFJTwzxd+RcMkh3a0jwZg6yoJ7YxJEVXly0uLS7XMP2J2mGakJrJGJu/wc+OI2KNxRdn/hDre/z/GQ1igxdTOm5ljm3QtwqXcLR7HvLsZE5bP5a1m5OReAZg1SOXNolwTXyBhTVewD0pgEmfLnBhas2Q5ARmoSFxzULcE1MnE39SEoyg99rDDPHT/8/6q3TsbUPJm4oN0YE2clJcpTk/y99hcc1JVG6fb135i6Kuq/bhHpCBwKdAAywpymqvqvWCpmTF33VMBc+9OHdKF14/QE1sZUidkvuTn2oRTthJnPW3Bv6j1V7ZroOhhTV01cmMUf67MBaJiWzPkHdk1shYwxVarCwb2ICPAocCX+RHzBi2Oqt08BC+6NCWPW8s3MWLYZgJQk4ZJh3RNcIxN3JSXQti+s/CH8OfnZ8P4lMOwmaNO7+upmjDGmzgue/nfW0C40b5iWwBoZY6paND33NwPXACXAF8BCYHtVVMqYui6w1/7EgR3p2LxBAmtj4i5vO4y/InJg7zP/XZj/HvQ9EYbdDO32qvr6GWOMqfN+XLqJeZlbAUhLTuLiQ6wjwZi6Lprg/gKgEBipqt9XUX2MqfMWrNnGtwuzABCBy0fY8s51ysZF8PaZsPFP/z5JAg1Y+i4lA1p0hQ0LvR0KCz5wtz7HuyB/t/7VWWtjagwRaQY0ZdfRgYBlyzemogLn2p8yuBPtmoabTWuMqSuiWee+G/CdBfbGxObpgF77UXu3p0ebxgmsjYmrhZ/D84eVDez3uxxa9Sx7XouucMWPcNE3sMeRZY/9/gk8OwzeOh1Wz67yKhtTE4hISxF5UkTWAZuB5bgs+sG3pQmrpDG1yLzMrXy/eCMASQKXD7OOBGPqg2iC+61AVhXVw5h6YdnGHXw+f23p9pUjekY429QaJSUw+X54+wzI92YrpWTAic/BMQ/ACU9Aqjf1IqUBnPA4JCVB5yFw1ntw6WTofWzZa/45wTUUvHEyZM6o1pdjTHUSkRbAdOByoCVufXsB1vlO8e5X4jLrG2PK8VTAXPsT+negS6uGCayNMaa6RBPcfwsMqaqKGFMfPDtlCSXegk/DerVh747NElshE7u8bW4Y/uT7/PuadYELv4T+p7ntLkOh+2EoAj0Og877lb1Gh4Fwxltw+few118oMxp58Tfw4hHw6gmwfFqVvxxjEuBWoAfwMtAMGIdbdacj0AS4DNeb/72q2pqhxpTjz/XZfPXb+tLtK6wjwVSlZw6GT2+A7HXln2uqXDTB/V1AGxG5q6oqY0xdtm5bHu/PWVW6fZXNta/9Nvzhetf/nODf122Y64nvMKDsuaMeILvJHjDqgfDXa78PnPoaXPkj7H0KZYL8ZVPglWPg5WNh6WRQWxbc1BnHAxu
Aq1R1JwFr3qtqrqo+D4wCzhCRKxNUR2NqjcDpf4f3aUfv9k0SWBtT562bD3Nfh8f6W5BfA0STUO8gXKv63SJyDDABN0SuJNTJqvpa7NUzpu54/rulFBa776yDdm/Bft1aJrhGJia/fwofXgYFOf59B1wNh/8TkkP8a23emTmDHmRE887lX7ttHzjlRRh+K3z3P5dNX4vdsRXfw2vfQ+ehMOwW6DnSZWY0pvbqCkxW1XxvWwFEJFnV/eKr6iwR+R64CHgqIbU0phbI3JzLxz+vKd2+8lDrSDDVoLjA3c99Hea9CQPOguG3QJP2ia1XPRRNcP8K/nXshwL7RTwbLLg3xrN5RwFvTfcneL7q0B5IrAHZ1kx473z46ytQkYDRxEdJCUz+D0x90L8vpQH85QnY55T4PlebXnDSs+4D8vuH4ee3oaTIHcucDm+eDB32dcd7HW1Bvqmtiim7tO4O7741sD5g/xrguOqqlDG10bNTl1Dszf87sEcr9u3SIsE1MnVaSVAfry/In/2KBfkJEk1w/xoBQ+WMMRX3yg/L2Vnoel73bN+EQ3u3jf2iE26BNXNgwq1uvrapeju3wgeXwKKv/Pua7w6nv+mG1FeVVj3gL0+6nvrvH4G5b0BJoTu2Zg6MPR3a93MfoL2Pdcn6jKk91gCBLZTLvftBwOcB+/sA+RhjQsranse7swKm/x1qc+3rpGcOhk77VX3QrAo7NsL2VbBtNWxfE/B4tbvPXhOmbDEUFcOc1yzIr2YVDu5V9fwqrIcxdVZOfhGvTFtWun3loT1j77VfOR2WTnJrpy/51m13GRpjTU1E63+Dd86CzQErcXU/FE55CRpW0xSLFrvD8Y/CsJtg2mMw+1Uo9mKddb/AO2dD271g2M0uMV9Scvhr2cgPU3PMAY4MGIY/ETdK8H4RWQasAq4E+uOS+xpjQnjx+2UUFLme1P6dm3Ngj1YJrpGpEuvmu5w/gUFztFRh5xbYtip00L7dC+aLY2xPLSl0t9kvw4bf4YIJ5ZcxMYmm594YUwlvTV/B9jw3lLprq4Ycu89usV2wpAQ+vgYKd7rtop3wybVu3XTrsa0aC8bD+CuhcId/30HXw8h/RA6gq0qzTnDMg3DIjTBtDMx6yf0eAGT9BuMugNa9XJDf96TQOQBs5IepOSYApwNHA5+p6jwR+QSXaO/XgPMUuCcB9TOmxtuaW8AbP60o3b5yRBTT/6qrJ9jET9Ac9z3ajIBBe7qfn6pbyWf7Gi9YXxUUtHuPfd8bqkPz3eH4MdX3fPWYBffGVKG8wmKe/87fa3/Z8B4kJ8XYa//LO+4fdaCtmTD/Xeh/emzXNmWVFMO3/3JD4X1SG7oh8nuflLh6+TRpD0f/Bw7+G/z4OMx4wd8AsfFPN4Vg8v2uEaDfqZCc6o7ZyI/y2ciG6jQW1yO/LWDfmcD9wClAS2AhcI+qTq3+6hlT8736wwp2FLjpf3u0bcwRfdpVvHConmAL8msHL8jvsPZLeLgPpDVxPeWFufG5fnozaNYRmnb035c+7gRNO8B/QnRaJaV6OYK8Gd1blsEbJ8MpL0OnQfGpmwmpUsG9iPQBegFNKbNWk59lyzcGxs1exYZsN6SpXdN0Ttq3Y2wXzM+BL24r24MMbvuL26DP8ZDWKLbnME7uZnj/Ylgy0b+vRTc3v75d38TVK5TGbeCIe+DA6+Cnp2D6s1CQ7Y5tXgIfXQlTHoBDboB+p9vIj4qwkQ3VRlWLgNVB+3YA13g3Y0wEO/KLePmHwOl/PUiKtiPBsp3XaoK6Hvv8beWf7JPW2AvWO/iD9dJA3gvc06NcRjE5DSTJ/f4ceC1Me8Ql1wPYugJeOhIOu8sds+8cVSKq4F5EDgSewyW1CXsarpmmUsG9iJwJXAH0A5JxrfUvA0+rashl9yJcqyVwMzAat9ROHjAfeF5VX493Oa/s0cANwGAgA1iK65V4KGCZn1DlhgK34ZYcbApkAh8C/1bVKP5STU1RVFzCs1P9a81
eckh30lNiHMI99SEoCvNrlLcNXj8RDr0DOg2xID8W63518+u3LPfv63kEnPw8NKjBmYcbtYKRd8GBV8NPz8D0p93vBbgP1U+ug6/v3rVxyEZ+lGUjG4wxtcjYGSvZmuuSrHZq0YDj+3Wo/MV8Qf6cV2DeGzDgbAvyawpVWPljxc9PyYDO+5UN2gN74DOaxW+VncCgfvit0MQbOXL8Y9B9BHx8LeRvd7353/wfLJsCJz4LjeOQYNqUUeHgXkT2BL4CGgI/AO2BbsDbQE9gIC4YH0/ZoXUVJiJP4pLm5OES6hQCI4EngJEickpFA3wR6Y4b5rc7bimdr4BmuGX8DhGRkcAFqqrxKOeVvQV4ALesz2RgCzAcuBc4TkRGquou42RE5Azgddz7Nw3Xg7E/roHhRBE5SFWzKvK6Tc3x6S9rydzsekebN0zljP26xH7R2S+FnyOlJW55tNf+Akkpbom0rgfB7ge7wCTa1tf66tf34aOryw5pO+Qm12iSiPn1ldGgBRx6OxxwJcx4Dn580iXOAcjbsuv5NvLDz3Ja1Cgicjgukd4K4EPfuvfGGCe/qJjnv/Mner18eA9SkuPwv6qk2N1mvQRzXoVex8Co+11gaKpXcRH8/jH88LgbURZJUqr7rhIcZFel9vt4ORvCPF/fE6HDQBh3Eaye5fYt+RaePghOfAZ6jqz6OiZSNee0iOav/zZcYH+Zqh4MfAegqmep6lDch+9s3HD9a6OtiIicjAvs1wH9VPU4VT0R2AP4HTiR6IbnjcUF6OOA7qr6F1UdAeyL+5JwHnBxvMqJyGDc/MBc4CBVPVxV/wp0B6bigvV/hyjXCXgRN+JhtKoerKqnAT2Ad3ANJ89G8bpNDVBSojw1eXHp9gUHdqNRehxSXAy6EKQCAWZJEaya4eaKv3ky3N8FnhsBX94Jf0xwS7qZsoqL4Ku7YNyF/sA+rTGc+rrrDa8tgX2gjGYuqd718+Hwf0JKg/DnFua5kSH1XaScFqZKiMglIvKbiBwctP954Evgv7jPw29EJDURdTSmpvpgzmrWb3cj+to0SeeUQZ3i/AzqvlMs/Bge7ecCtHljIXt9nJ/H7CI/G358CsYMdIlyIwT2JZLieur3PReu+wWOe7h6AnuAy78v//ladIULv3A5gnwzundkwRsnwdf/gOLC6qhpYqyb76a7PNYfPr0BstdV6dNFE9yPABap6vOhDqrq78BxQBfgrkrU5Xbv/lZVXRRw3fW4YfoAt4lIuXUWkQOA/XAjCC4N7C336nmTt3mXBKQSrWw5X91wv60PqOr0gHI5wAVACXCliDQPKnc90AB4VVU/CihXBFwKbAdGi8he5b1uU3NMXJjFn+tzAGiUlsx5B+4enwv3PNytHRpKWiPY9zxoE2LWjJbAmrnw4xNuTfQHurqWxAm3wm8fw45Nsddtayb7zr7ZBUK1Te5m1wjyQ0Am15Y94OKJsNcJiatXvKQ3gYOv9yfUC6VoJ8x8odqqVCOVl9OiYEfociZWJ+FGA5Z+dnqfxxcBOcCbwDJgGC7RnjEGN/3vmSn+6X8XH9yNjNQqbIjWIvh1HIy/HP7XC545BL75JyyfVreDs+q2bbXrbHi4L3x5O2xb6T+WnF723OQ0SMlg7W5HVH9QH63kVDj8bjjnA2gUMBx/2mPw0lGweVnYorVecQEU5VVLkB9NcN+eskvSFAOISOlvmTd0fAqul73CvN7rQUAB8F7wcVWdghuq3h7XA16eId79bFUNMQaVr7z7zrhgPqZyIpIGjPI23wxR/6XAj0AacEzQ4dERym0HPgk6z9RwqsqTk/y99mftvzvNG6bFfuGSEjdPKZTURnDsw3DCGLjqJ7h5KZz2Bgy9wg2X2iXvpbqWxOnPwLvnwIPd4cn94bMb3bD0yrTIT7iFJtmLXINBbbL2F3huOCyd7N/X62i45Ftou2fCqlUlBl9UTu/9TjdXv7io+upUk0TKaWEjG6rSXsCvqhoYHZyOy99zhqq
ei5sal4trLDfGAJ/NX8uKTa4fqmlGCmftH6eOBHDDu1MyoPuhbknVxiECxnW/wPcPwyvHwAPd4O2zYNbLtbORvyZY+wt8cCk81s91NgQmx2vQ0g17/5sXinlBPQPPget+YVGvy2tuUB+sx2FwxTToETAcf/VseHaY+w5al/mC/FkvuZEwVRDkRzNOOCdoe7t3vxuwPGD/TiDaCTkDvfsFqhpu0cWZ3nUH4ub8R9LYu98Y5ng2riEhDdeo4OstqGy53rgpC5tVdUmYsjNxyfIGAm8BiEhT3PB73/Fw5c7C/x6ZGu7HpZuYl7kVgLTkJC4+uFt8LvzL2/65SsGad4Z9TvVvN2rl5k/3Od5t79wCK3+CFdNcC/van3cdAbDhd3fz9d626gm7HwRdD4bdD3SZU8PxEpAJWrsSkP3ynptfHZjHYPht7gO0Ls6vHnYTzH45fN6GkkL44laXKfm4R+vfcjUznw//3hTthFkvwuFhGthMLFqz6+f6MGCLqn4OoKqbROQ7YJ/qrpwxNZGq8vRk/1fO8w/qRuN4TP8Dlxht33PLzqFWhfW/wuJvYPFEl9itJKAhuCAbFn7qbgBt9nSjDXuOhC4HQmpGfOpW16i69/SHMbAsxEqfLXvAAVdB/zMgraHbF3KO++/VVuW4aNwWzhrnRpRO/Kf7Xcrf7qZGLpkEox6o4zmAFIrzq2R1imj+C6zCDbn3WejdH4rLZo83F24osCHKeviinxURzvGNSalIpORLPtc9zPFOuAA9+HqVLed7vJLwQtW/q3e/1eulr2g5U4M9Ncn/YXvK4E60bRqHD7T8bPjmbv92v9Ph949cT2tKAzjh8cjBaIMW0HuUu/mut3I6rPjeBftr5pT9kAbYtNjd5rzqtpvv7gX6B7lEfc13d1lWa2MCsuIiN8frpyf9+9KawEnPwZ7Bg2vqkPTGcPT9boRG4NDzlAy3lu0Ob8TGul/ghZEw+AIY+Y+avUJAPGxfC5PuhYLgNuwAKQ3cyAdTFZKA0lGAItIQ2Bv4POi8TbiGAGPqvW8XZrFwnVvytGFaMhcc2DV+Fz/zXdjjiLL7RFxQ2X4fN286b7sLRn3B/ragr8AbFrrbj0+4/5/dDvGC/cOhZff4ZWmvrQrzXC6XH59071OwLgfCgde4kYTB36Uu/7566ljVkpLgoGvd98r3L/SvUDT3dZcg+pSXof3eCa1izHLKCYl9q1PMftl1sF0wIeanjCa4nwZcICJNvUD0M9zQ/EdEJAMX/F+CC4DfjrIevh7zSBMafd+6KpLyexJuON8gERmsqsHdnVcEPG4ah3KVrX9cXreIXIqbn0+7du2YPHlyhMtFlpOTE1P5uq6892fptmK+X5wHuIHwA9I3xOX97L7kVbrkuMArP60lM5qOpk/TZbTaNINNzfrx65JcWBLt86RAygjoOYKkbvk03b6Q5lt/pfnWBTTd/idJGjR/busKmLfCtTACeemt2NZsb0okhbablxE4y69403L+fPefrG9/aGVfcszS8zbQd8F/WdD3FvIz2pTuTy3Yxl6/PUiLrfNL9+1o2IkFfW8nd11DWDe5yupUI/6+tB1DUlvSsHBH6bqluWltmLXv/+i86iN2X/EeySUF7sislyj4+X2W9LiA9e1GVPmXsep+f5KLcumc+SGdMz8iuSTsSqUAFJLMjzKUkkT//OqmVcCAgO0j8K8eE6g5bhWaSqkhS+228Modj+tISMElEp4K/E9V5wWd3xWXb6AihqtqadefiNwNRBpqkq+q1p1aC6kqTwRM/ztzvy60aBTD9L8GLfwrqnQ/dNfAPpSMptDnOHdThY2LvED/G1j+veuV9CnaCYu+cjdwydV8gX7XQ1zDc6Bqzi5erXI3w8wXYcazsCMo8JNk2OsvbjnbjvVo5FynQXDZd/Dp31xOB4CNf8Lzh8FR/4YhF9e+xqANf7qGrZ/LCYmDlxCMg2iC+w+AI3GJ9T5W1dUich8ued4T3jkCbAXuiEvtKkl
Vl4jIG8A5wEcicjVuabomuGz3N+OW2UvFJbqLqVyiqepzwHMAgwcP1hEjRlT6WpMnTyaW8nVdee/P2Ndn4b7jwQkDOnDqMXGYTbFpCUz9pHQz/dj7OaT/KBi8N7x3Pq3/+iIjmneO/Xk4yv+wMM/Nf1oxzX1IZ87YZbhyRv4mMrKmhLxSckkefZY+T5+DjoXWvaBB8zjUL0pjz4CcxRyw5UM44y23b808eOdq2BYwH7D3sTQ68Rn2y2ga8jLxVGP+vnq8CK//BQp3IikNaHTGSwzvvB9wFGy+GSbcUvolLK1wG30WPkqfvNlw7P+gTe8qq1a1vT/FRTD3NZh0n8vWG6htX9i8tOzve2ojUo/7H8P6H131daufvgSu8JbD/RK3pKwCnwadN4DII+TCqiFL7XbBrTTUBTf9b5JXnwHA2cDpInK6qgZOOs0BXo1Qnb1w+YKycSsWhfIzMC/EfsuAVkv9tHQzc1duBbzpf4eEG3BaAdtWu2SiPsNuCn9uOCLQppe7HXAlFOTCih/8wf6mRWXP37LcTQGc+YKb27/7Af5gv+1eLifQhj/KDleu7TYtcb30897adfpXWmM3DWLo5dAijnkTapOMpnDyC9DjUPj8ZrdqUXE+fH6Ty4l0wuPQsGWiaxmZqvvO/OMT8OcXkc8NDurjmC+hwsG9qk7ELUsXuO//ROQX4BSgJa4V/FFVjTS8PhTff5VIkyt8zXrZFbzmFbigfDSuYSLQu7jh9aOBzXEoV9n6V8XrNgmyaH02Xy7wJ6K7YkSPCGdH4cs73FxocC3Z/by59c07wyUT4/McwVIz3ND7rge5D9WiAlg7z/3TWjHNzd+PNIQZXGbxF73W//Sm0LwLNOvs7pt3DnjcBRq2im+rrJcDAC3x5wDYsgw+uc4lMgFA3Nr1h9xUc6cPVJUuQ6H7YfDnBJfYpnNAXtGW3dyQzN8/cRnit692+5d/59akPfAat7yeb+5fbaLqPnC//j/Y+EfZY+33gSP+Bd2Gw1P7lz0enNPCxNu/gZNxn7+X4zoK3lTV33wniMhAXN6dXZLulidoqd1hvhV5RKQdLsD2LbX7WAUvGbhk7nm+lXVEpA8wAdcZMA0IXl3oflxg/znw14ByScA/cL3sz4rIx77kgqq6ETg/wmvzTV14W1XDjQIcr6p3V/C1mVogcKndkwd1pH2zGAZg/PC4/ztG5/3dEOlYpTWEPQ53N3DBvG/4/tIpZaeFlRS64f3Lprrpck12c/t9w5W9Ocl7tBkBg/asXT35qu770o9PwMLPcG2WAZp0gP0vdysdJaITpKYRgYFnu++64y5wOR7A5XFYM88F/7sfkNAqhlRcCAvGw4+Pu5xWkVRhUO8Tc+YNr4U51tSGy737SM1Vvq7J5RHOKeV9yJ3oLadzNC7x32bgS1WdJCK+5D3z41DOV6fAnAQVqb+vEaR5wHSHipQzNdDTAcvRHN6nHXu2j0NP8KJvAlr/xCUYScTQpJQ0FwB23g8OucH1fK77GV4+JiBYjiB/u/sn7ftHHSy1oRfsBwX9vgaBxu0qHoCHygEw9jT/kENw88tPfh56HRX6GvXBqAcgZ727DybilgDscRhMvg9+etolYCwpdJmRfx0Hox6E3rWoJ3vNXLe00PLvyu5v2hEOuwv6neb/HTvh8dKRDRXKaWFioqprveD9EqAdMAMIHtq+N/ARlfu+EXapXRG5AjdC7zYReby83vvylswVkZtwDRB3icgLQb33vnlK9waVKxGRfwG3AK1wHSm/UQ4R6Yh/yNWL5Z1v6oZfVm3lu0Uu73OSwGXDYuhIyNkAs1/xbw+7qWq+Y7To6oZWD7nYdRZk/uQP9oO/F2SvLbvtBfm7rf0aHt3H/a8+7O81O8gvLoLfP3ZB/eoQA2ra7wMHXAN9T3Tfr0xZbXq5pYi/vgtmPOf2bV/lVmUYcYf7HppUhUs+VlTedpeX6qdnXP3KEJfn6oCrXb2rIaj3iVNazZjN9e77ikiDMBn
zhwSdWyGq+iNuGbpSItIENwyuCNdqH2u5hbhVAlqKSI8wGfN9XWOl9VfVbSKyBJcxfwhuqGC55UzNk7k5l4/mrSndvvLQOPTaFxW4nlOfgWdBx31jv248JKe4+WBDr3DL6YXKMC5JkNHMDfEPl4HcpzDX9ZQG96aWPl+ay9bvawBovnvA4y6u9TvZ+3f2yzuwLeifbGBg32ZPOP0taBWnkRW1VUVGfqQ3dvPd+p8Bn93gEtwAbF3pGkz2PM4l6IvLtJAqsmUFfPsvmB/U4ZvWBA75G+x/JaQGLQ8YaWSDiTtv5ZgdqvqvcOd489jDzmWPcO1yl9oVkdW4UQH7U/5qPNEumTs94FjkxA7+br1wK/YEOx+XjHCBqk4v51xTRwQm7T2uXwe6to4ho/hPT/k/n3fr74bFV7WUNOg2zN2OuAe2r3Ej7BZ/4+7ztoUslkQxFBe7nvy5r0NGcxckt+wGTTtBs46usbZpR/c4npnWK5oDID8b5r7h3tetIWYQ9TzCjX7rNqz2zSGvbqkZcMyD0H0EjL8S8ra60ZiT7oVlU1wC5KYdElO3rZnuu+/sV90qEYFSMtx3pgOugtbegPeQqxtUnaiDexFphpsbdgDQBpioqv/1jvXG9b5/F2FJu12oaqaIzAH2Bf4KvBb0nMNxifrWERRwV9KVQANgrKpGs6B3yHKqWiAiE4CTcMvW3RNYyJufdwDuy8VnQdf8CLjBKzcxqFxTXNIdgA+jqKepZs9NXUpxiftedkD3VuzbJQ7ZxWc855+nlt4URtbAJbgiLa2W0Qz+tsD1yuducgn5tma6D7xtmWUf54dbLMJTXODmQW9eGvq4JHsf6h1c5n/fcL5gvY9xHwjpFcnLaUq13xsu+MLNf/z6H7DTm5W08FP3ZWzEbS5ITk5NbD0D7dwK3/0Ppj9bNrFTUgoMvtB9yDaKkHg90sgGE29bccu+VsX6mTVlqV2AL4DLgL+LSOCwfMHlL2qIy2kUlAgirPO9+/J67fcVkQeAFriRiNOBz1Q1zD9KU1MtWp/NFwv8a2LHNP1v51b/0rcAh9yYmICzaQc3FHvg2a7He/VseOnI8svlbXUjsYJHY/lkNHedAr7vBs06BjUCdNi1YTecUDkAAoP87WtcsDfrlbJr04PrnOh3muvBbbtnxZ7P+O15LFwxDd6/BFZ6/5590wRHP129IwjXzIUfnoAFH+66nHTD1rDfpTDkol2/W1Tz6gZRBfcicjTwJi5jrS/R8uqAU3oB44EzgXeirMt9uFb1B0TkB1Vd7D1nW+Ap75z7A4fMeQnvrgZmqOq5QXXtDWQFtqx7H6AXAv/CfcDdGOI1Vqocbi7dicCtIvKFqs7wyjYGXsK1rj+lqluDyj2Km2d4noiMV9WPvXIpwLO4rPzjA+cempolKzuPd2b5E7RddWjP2C+akwVTAoKK4be4NUFrmnBLq6U2cvt9LeeNWrtbuOyvO7cGBP0rXeC/baX/sS+YDEeL3fnBS/EESkpxyeAssK+cpCTY9xzXQPLN/7meE3CjLr7+h8sIe+zDiZ8PV1TgvrBO/W/ZERvgRhoc/k9oXYG/0arMaWGCZQOLyj2rcmrKUrsAf8c1IBwDrBCRn3C9+f1xHSNv4DoRyuV1evTENSSUN6LhePwdBT6rRORsVQ2dFdXUSGWn/7Wlz24xTP+b8by/Yb11b9gz+FckAZJT3MipeMjb6m7hpgOCy/fTtKPXCNAh4LHX+99kN0jxVukMygHAgLOgz/Hus+/XcbsuJ9yghTcV4ZJq6a2t05p1gvM+cZ/rUx90Pfg7N7sRhEOvgCP+6f85xVtJiUsw/OMToRuSWvdyvfT9Tqt4Y1EVq3BwLyJ74xLMpeCC7ansGsB/AeQCfwlxLCJVHSciT+MC3fki8g3+bLZNcY0GTwQVaw30xvXoBzsDuENEZgOZuGVvBuPmxa8HRqnq2niVU9WZInIbLsvvDyLyLa43YjjQFtdSfmeIcpkichHuw3m8iHw
PrMEND9wdWIxr6Tc11EvfL6egyLU59evUjIN6tor9ohPv8X/ottoD9qvBvwL9ToPvH4ktAVmD5u62W7/Qx/NzAnr7VwQ1AmS6HtbylBTBrJfg8LsrXi+zq0at4C9PuF6WT/8GWV67Y9Zv8PLRMOBsN9SyURz+DqKhCr+Nh2/u9q+V69NxMBx5b+IbHkw4v+OC4qpQU5baRVU3ishhwJO4pHvHBRz+A5iiqhVNnnuhd/+xl3QvlCW4fAMTcMvppQH74BL3DQc+F5EDVPWXUIVtmd3qVd57tCG3hPFz/YNP9m+2vdLvaXLRTvb/6TF8Y61+bz2K9VOnRixTnUaE2FcsKSBJrGt/GKs7HEOSFpGev5H0/I1k5G0sfZyev4n0/E0kaVGIqwTJ3eRu60L+CQBQkNq8tLXOVcQF+TrrRWTWroNmcht0YFWnE1jX/jBKktJh9u+4f3FVq178jcmBNOv/L/b67WHSCza5fdOfJnvBl/y2103sbNgxbNFo35+k4gLarZ9Mp1Uf0Sg3eD49bGm+D5mdR7O55b6QnQTTas7MqGh67u8A0oETA3qXywTwqlooInNxrdBRU9UrveD2KtwHj28d2peIfh3ab3EJeAZ59SkGluLWtH1EVUNP6ql8OVT1v97qATfi5uVleGXHAA+pasj5dqo6VkSW4j6ED8INTcwEHgT+Hek5TWJt21nIGz/5O4SuHNETiXVY2+o5bs6Wz9H31+yEK0lJVZ+ALL0xtO3jbqEU5rl59pP+7bK8+zL/BkppAIMvil+d6rsu+8NlU12yvcn3+0duzHsD/vjM9ZAPPKd6EtGt/Am++jusmll2f4uubjpL3xNtfmPN9jwuS/wgVQ23nFuNEMuSuSKyJ/Cxd+45wDe4fD2DcJ/3z4vIgap6IRF4U/ZO8TZfilDXUD36k4BJIjIOt0LBfyjbyBBY3pbZrUblvUd3jf+VEm8xqv27t+Ti0TE0Vv7wBBR57UjNu9Dnr3fRJzmakKCKTQ547CUiW9f2UDqe8Rgdm7QjfAjnKSlxa8hvX+WGzG9b7R5vW+1WgNm+xt2Ch1aHkFa4NeT+XT5RGreDw+6i4YAz6ZWUTK9yrxxf9edvbAQccQZ8dJXLiwM0yVnK0Lk3u+V6B5wRslSF358dG93ov9nPQ25Qu6kkw94nwQFX06LDAOIwAbdKRPOXPAKY6wvsI1iNC44rRVXfAt6q4Ll3A3eHOTYVN7og2uevVLmA8l/gRjBEW246bok9U4u8/uNycvJd63DPto05cq8Yh16pwoRbKc2r1Oto/1IyNZmXgEz/+BxJRAKy1Aw31PqEx72EPFtDn1OZ9XtNeMmpcNC17sNuwq1uDj644fCfXOuGLh77sJuzXxU2LXFTBH7/pOz+jOZuTv2Qi6puqJ6JG1V9UUT6A197c8M/BFaEaxCPUo1Yatebavc+bij9QV7SXp9vReQIXIb8C0TkdVUNmezXczpufv4q4MsK1jnYPbjg/ggRSfUtvWdqprhO/yvMc8vf+Rx0vT8hbU0SlF180ezf6VjR4e1JSW4ofJN24acDlhS7UX+lAf9qfyOAr0EgZ50bAl4ROzbAz2+56WumajVqBWeMdTl1vr7LjaYo3AHjL3fLIB/7v+inYG5cBD8+CT+P3XUVqPSmMOg8GHq5myJQw0Xz19yKigW9abikc8bUaTsLinlp2vLS7SuG9yApKcbewV/ehVUz3OOkVDjqP7FdrzqNeoDsNYtomsgEZBXNAWDiq1knOP1N+PNL+Pwmf5bgzOnw7DDY/wqXdC9e+Q52bHQ5KWa9VHaeY3IaDL3MJYZqUFPb1E0wEQnsPvuPdws3CkpVNZrvLsu9+0QvtTsU2AtYGhTY+6652UvMez5wOGFW8vH4evZfiXJEY6CF3n0abopjqGmKpoYInv53cM8IyUDLM+9NF7QCNG7vgueaJmR28TgPbU9K9ubZd8C/CEaQ4kLIXgePRmigDl7izFQPEdj/cjfdbtyFsGmx2//
LOy7hXa+jXbb9SKsbqMKKaW4kizcKoIxmnV1Av++5kBGH5a2rSTQfkFuo2Jy4Hri56cbUaW/PXMnmHW7uVcfmDThhQIxLcuTnuF5InwOurF3LtTXvzJxBDzIi0cuixSMHgKmcXkdB10Pgu4dg2hg3PUKLXSKaXz+AUfdDnxNCD5Hfmsm+s2+GAR+EX1qvcKebBvD9I7uusLD3KTDyH9AiUgxnaqhoWkWjbUGtKUvtdvHuI02z2+rdtwx3gojshWsoUNx0wcoKTIqRE/Ysk3DbcuM4/a+4EKY96t8+8Bo3qq2mqebs4mElp4b/PKrGdctNBLv1h0unwIRbXMMVuJ783z+GPz53gXlwo0txkcvR8+MTLgP+Ltcc4P429vpLzVoFqIKiCe5nAEeJyB6qGjKrrYgMAfoBY+NROWNqqoKiEp6b6l+W7bLh3UlNjnFu8Xf/g2yv86RxOxh2c2zXq6+qIweACS+toQuy+53mRlD4sstmr4F3z4U9joRR/3VrEweacAtNshe54f1nBM3MKimB+e/CxH+5IZOBdj8Yjrwn/NBLU+OpapX9cdagpXbXePd7ikjzECvngEukCy75XTi+xCGTVDXM2qAV4mvt/COKJH4mAV6L5/S/+eP8I6satITBF8ShhvWMBfU1T3pjGP0UdB/hEv0WeO2VJUUw62WY9yZ7tDkU9u7oMt//9LRLxhys1yg48GrY/aBanacnmg/UJ3EJYsZ5y8WV4a3l/hKuNfnp+FTPmJpp/LzVrN3m5uS0bpzGqYNj7K3evNS1IPocfrct2RYLLwcAkgSJyAFg3LKD530CJz4Hjdr49y/6Cp7aH6Y8CEXelOqV02HpJAR1ORNWBmSdXToZnhsOH15WNrBv3QtOHwvnf2qBvSnPfd79AyJSOlm5vKV2RWShiJRpDPCO9RaRFkH7xFv5JtySuT/iAvwGwIteUjxf2SQR+TsuuC/Czc3fhYikAmd7mxHXtheRLiJypoikB+0XETkH/3vySKTrmMTKLSjipWn+tp4rR8Qw/a+kBL5/2L+9/5U2VS0ayWmQkuESxV73Cxz3sAX2NU2/U12i3zIUivLpsPZLeGIwfHlH2cA+JQMGnQ9XzYQz34auB9fqwB6i6LlX1S9F5HHgGuA3EVmAC+QPF5HpuLVbU4CHVbWGjKcxJv6KS5RnJvvXmr3w4G5kpCbHdtEv/+5fQ7XjIOh3emzXMzDqAZcsJ5E5AOo7Eeh/GvQ60vW6z/Laf4vyYNK9bm7cMQ+63vpCb7R00U6XjO+Ul9yydou+KnvNRm1gxO2w73k1MwmUqXFqwlK7qlogIucDHwEnAcNFZCYuW/4AoBsuu/71qrqE0I7DLa27lV0T+QVrCbwJPOONXFiDSwLY13sugCdU9dlyrmMSaOyMTLbkulyHHZs34Pj+MUz/W/gJbPzTPU5vCvtdEoca1hMhcwCYGinMdFbxJar2SU6H/S6Fg6+HRjHksKiBovpmpKrXicjvwD/wZ8Tv5N02Af9S1THxraIxNcsXv65j6UaXrK1JRgpn7x/jHN8l37qlw3xG/deGkMdD885wycRE18KAS2533MNuGONnf4O1P7v9mxbB66Pd8jKBNi2Gpw8suy+lgRsud9B1NqqljvJ61S8DDgDaAB+p6i3esaG45WnfDTOkPaKasNSuqn7trQpwA3AYbhWiJFyDwNvAY6r6U4Tn9SXSe0tV8yKcB/7ldIfgMvTv5z3XOuAd4DlV/baca5gEyi8q5vmA6X+XxzL9TxWmPuTfHnIxNGgeWwXrk5qSA8DET0kBrJld5wJ7iDK4B1DVZ0TkOVxLc3fcB2QmMENViyKVNaa2U1Wemry4dPu8A7rSNCOGZBvFhTDhNv92/zOh0+AYamhMDdZpEFwyya0h++29/qR4wWsNB2bAR2DgWXDonV5WY1MXeUPan8Rlbwc3MjDwW1dD3JS/QiqZSC7RS+16ZRfhRhBUpuzxUZy7CbilMs9jaobxc1ezbrtv+l8
6f41l+t/ib2DdL+5xSgM44Ko41NCYWqgerG5QqTGNXgv3HO9mTL0xf2MxC9bkApCRmsQFB3WN7YIzX/BndU9rDIf/X+TzjantkpLdcnV7/QVePaHsqgbBmu8Op78F7SMsQ2RqPRE5CHgWl7X9TlzgPD3otCm4TPMnEFuWeGNqvOIS5emA6X8XHxLD9L/gXvtB59fJ3kpjIimRFJKSU+pFIkSbsGhMFD5dWlj6+IT+HXjk6z+Zs3Irn193SPQX27ERJt3n3x52c+T1OI2pS5q096+1HE7eVgvs64dbcD31o3xrwAcv9aWqJSIyF+hT/dUzpnp9Pn8tyze5joSmGSmcNbRLOSUiWDENMr3ZHkmpbokvY+oLr6d+bdtD6XjGY3U6qPcJG9yLyLBYLuwNWzOmzpi5fDN/bnHTMgUYP28NqkphsUYuGM63/4J8b2pmyx6wf6VGahpTew26EKY/45LoBUtpAIMv2nW/qYsOwE3tK28punW4xHXG1FmqypOTAqb/HdiVJrFM/wvstR9wJjTrGEPtjKklgobfL5r9Ox3rQWAPkXvuJ0NwasEK03KubUyt88jX/uHDIm6t+0pb+zPMftW/ffR9kJIe/nxj6qJhN8Hsl0MH96kZ7ripD5oBq8o9Cxpj3y1MHTfpjywWrssGoEFqMhcc1K2cEhGsng1LJ7nHkuQygxtT14Vc3eD3hFapOkX6kJxK5YN7Y+qMrO153P3JAn5Ysrl0X0ksfxmq8LlvFCrQ8wjodVRMdTSmVkpvDEffD5/dCIU7/PtTG7n9tgZzfZGFf3m2SHoDq6u4LsYkjOu198+1P2O/LrRslBahRDmm/s//eO9ToGX3GGpnTC1Rz1c3CBvcq+qIaqyHMTVO1vY8xkxcxHuzV5EfSy99sF/fD5j/luJ67Y2pr/qdBt8/UjaxXvPOsM+piauTqW7TgFNEZLCqzgp1gogcAfQCXqjWmhlTjWYs28zsFVsASE0WLhkWQ6/9+gVll9k95IYYa2eMqQ1sMW1jwrh67FzenLEyvoF9wQ746i7/9tDLofUe8bu+MbVNUhKc8DikNnDbKQ3cdpJ9PNUjj+BSmXwgIkeKSJkfvpcD6CWgCHg8AfUzplo8GZAh/+R9O7FbswaVv9h3D/sf73kctLVclMbUB/btyZgwnjhzIMP2iPNyMd8/Atlr3ONGbWC4LUNsDF2GQvfDUAR6HAad90t0jUw1UtXpuIz5nYAJwCbcvKXRIrIemAR0BG5R1fkJq6gxVWj5tmKm/rkBgCSBy4b3qPzFNi2BBR/4tw+5McbaGWNqiwoH9yIySkS+FZFDI5xzmHfOEfGpnjGJk19YwtyVW0u3kwSSJfz55dqyHKaN8W+P/D/IaBbDBY2pQ0Y9QHaTPWDUA4muiUkAVf0fcCwwC5dgT4DmQBvgV2C0qj6aqPoZU5Wytufx0Ky80u1j9tmNbq1jyDny/SOg3qjDHodBx31jrKExpraIpuf+AtwSNDMinDMDGAKcH0OdjEm4wuISrn17LtvzigDo2LwBX/9tOMM7pZCRkkRqZaL8r/4OxfnucYeBbnkOY4zTvDNzBj3o5tubeklVJ6jqUFxAvx9uibxOqtpfVT9ObO2MqTr/+vQ3cgr921eO6Fn5i21bBT+/7d8+xFYdMaY+iSa4HwT8rKo7wp2gqjnAPGBojPUyJqEe+uqP0l77lCTh8TMH0qNtY87tm87UWw/ltCFdogvyl06B3z/xb4/6r80pNsaYEFR1k6rOUtXpqrom0fUxpiplbc/js/lrS7cP6tmKvTo0rfwFp42BEq+loMsB0PWgGGtojKlNookudgMyK3BeJtC+ctUxJvEm/ZHFs1OWlm7fdFRv9u3SonS7bZMM7h29d2mQv9du5XwIFxfBF7f5t/udZnOKjTHGIyLvelP/rMXT1Dv/+fz3MsvrNk6PtEp1OXI2wJxX/dvWa29MvRPNf5B83Dy48jQDiitXHWMSa922PG589+fS7eG92nDpIaH
XhfUF+eWa9RJk/eYepzaCw/8Zj6oaY0xdcQpwMrBORN4AXlXV3xJcJ2OqXNb2PD75uezglCl/bCArO4+2TTKiv+BPT0KRN3d/twHQc2TslTTG1CrRtJL/DhwsImEDfBFpChwM/BlrxYypbsUlynVvz2XzjgIA2jVN5+FT+5OUFEMWvdzNMOnf/u1hN0LT3WKsqTHG1CnXAnNwIwRvBuaLyHQRuUJEmie0ZsZUoTs/nE+xlt1XrMqYiYujv9jOLTDjBf/2ITeCxJIF2BhTG0UT3H8ANAFeEpH04IMikoZbh7Yx8H58qmdM9RkzcRHTl20GXGb8R08bSKvGu/yqR+fbeyFvq3vcohvsf1Vs1zPGmDpGVZ9Q1SFAX+AhYB0uOe8TwFoReceG7Zu6ZnFWDl//nrXL/sJiZdysTLKy80KUimDG81CQ7R632dOtbW+MqXei+aB8CtcjPxr4TUTuEZEzvds/cT37JwKLgcfjXlNjqtAPSzYy5ttFpdvXjtyDA3q0iu2i6+bD7Jf920f9B1IrMczOGGPqAVX9XVVvAToDo4B3cevd/xX4FMgUEVsr0dQJF786M+yxqHvv83Pgp6f82wffYEl7jamnKvyXr6q5wJHAz0A34E7gde/2d2/fz8BRkTLqG1PTbMzJ5/q356He0Lj9u7fkmsP2iO2iqjDh1rLrzPYeFds1jTGmHlDVElX9UlXPwCXovQz4ETds3zKEmVrvtR+Xs3xTbtjjUffez37ZDcsHaNEV9j459koaY2qlqJr1VHUlbkm8E4HngC+BL7zHJwGDVHV5nOtoTJUpKVFuePdnsrLd+vOtGqXx2OkDSY5lnj3Agg9hxTT3OCkFjr7f5r4ZY0z0koFUIC3RFTEmHlZtyeVfn5SfL7LCvfeFefBDwIDZg66H5Bgy7htjarWo//pVVYGPvJsxtdqzU5cy9c8Npdv/O7U/7ZrGOHS+IBe+usu/vd+l0KZ3bNc0xph6wptbPwo4HzgOF9gLbqnd1xJXM2NiU1yi3PDOzxSWaLnnFhYrc1ZsKf+i896AnPXucZPdYMCZMdbSGFObWdOeqbdmr9jMQ1/9Ubp9+fAejOjdNvYLT3sMtq9yjxu2huG3xn5NY4yp40Rkb1xAfxbQFhfQ7wTGAq8AE70OBmNqpWenLmHGcpe4NzlJePeyAxi0ewsAJk+ezIgRI6K7YHEhfP+Yf/vAayAlxkTAxphazYJ7Uy9tzS3g2rHzKPZaz/ft0pwbj+wVhwuvhGmP+rdH3gUNmsd+XWOMqaNE5BpcUD8AF9CDm2P/CvCOqm5PSMWMiaP5q7bx8Ff+laKvPrRnaWBf+Yu+B9tWuscNW8Gg82O7njGm1rPg3tQ7qsot435h9dadADTNSGHMGQNJTY5DZtmv7oIiLwFO+34w8JzYr2mMMXWbr+txNS5J7yuq+meE842pVf6/vfuOk6o6/zj+eXbpVYqAUsWCHRDQiA01JrYYK/aaGGOLKUZNopGoiZpEE0uM5WeLXVDRiF3B3gALIBoEQUDpXeruPr8/zp2dYZiZndmd3dnZ/b5fr33dufeee+8zwy5zzz3POWfN+nIueuwjyqIGhQE9N+PCA7ap2UkryuHNG+Pr3zsXmrWu2TlFpOipci+Nzv3vzOSlz+ZXrv/tuP706NCq5if+6k34bHR8/ZC/Qklpzc8rItKwPUZopX/ZPTbFiEjD8efnPmPGwjCRVKtmpfzz+AE0qWmDwtRnYHE0hW/zdjDk7BpGKSINgSr30qhMmrOcvzz3eeX6GUP78MOdutX8xOVl8MJl8fWdj4Xee9b8vCIiDVw05Z1Ig/Tq1Pk8+N7XlesjfrQTfTrXsIXdHd64Ib6++9nqAigiQI5T4YkUs5VrN3DBIxNZXx4ahnbu3o7fHbp9fk4+8T6YPzm8btoKDroqP+cVERGRorRw5TouGfVp5frBO3XjuME9an7iaS/B/EnhddNW8L3zan5OEWkQ1HIvjYK78/unJjNr8WoA2jR
vwq0n7kbzJnlIm1+9BF67Jr6+96+hffean1dEpAEysz/W4HB396vzFoxILXF3Ln3iUxZ/tx6ALm2bc+3Ru2BmVRxZ5Ynhjb/H1wedAa071+ycItJgqHIvjcJjH87mv598U7n+l6N3qXlaXMy4a2FNNBftZr3DVDQiIpLOCMCJj4wfU9U0dxaVUeVe6r0H3/+a1z5fULl+w/D+dGjdrOYnnvkWzPkgvC5tpnsOEdlItSr3ZlYKdAJapCvj7l+n2ydSl76Yt5Irn5lSuX7i7j05ov+W+Tn5/Cnw4f/F13/4Z2ia9s9CRETgTym2bQWcRpjX/iVgZrS9D3AQ0BK4P2G7SL315YJV/HnMZ5XrZ+21Fftsu3l+Tv5mQqv9gJOgXZ7uZ0SkQcipcm9mewBXAfsAzTMU9VzPLVIbVq8v44KHJ7KuLPSz365rG/54+E75Obk7PH8pxAZ33mo/2P7w/JxbRKSBcveNKvdm1guYAIwCznf3hUn7OwO3AYcDg+sqTpHqWF9WwS8f+4i1G8K9Qb+ubbnk4H75OfmcCTBjXHhtpbDXL/NzXhFpMLKugJvZXsArxCv1S4EVtRGUSL6MeGYK0xasAqBF0xL+ddJutGyWp+nppj4DM98Mr60UDrkeatqXTkSk8bkGWAec4u7rk3e6+yIzOwWYAfwZOKWO4xPJ2j9e+R+T54bb42alJfzzhAG0aJqn+47EVvtdjoWOW+XnvCLSYOTSuv4nQsX+LuAKd19QRXmRgnrqozk8Pn5O5fpVP96Zbbu2zc/JN6yBFy+Prw/5KXTZIT/nFhFpXA4CxqWq2Me4+3ozewv4ft2FJZKb92Ys5vbXp1euX3JwP3bYol1+Tj5/CnzxXHx971/n57wi0qDkMhXe7sBUdz+nNiv2ZnaSmb1pZsvNbJWZjTez880s52n7zKyjmV1rZlPNbI2ZLTWzN8zs1CqO62BmfzGzSWb2nZmtM7NZZvaAmQ1IUb6PmXmWP/smHTuiivJrc33fAjMWruIPT02uXD9ywJYcNygP08/EvHMLLI+GlWjZEfb/Xf7OLSLSuGwGZPPktQ3QvnZDEame5Ws28JvHP8GjYSH33qYzZ+2Vx5b1NxPmtd/hR9AlT1P5ikiDkkvLvQGfVlmqBszsX8B5wFrgVWADcCBwK3CgmR3rHuvgXOW5+gKvAb2B+YQBetoDewD7mNmBwJnu7knH9QLeBHoBi4CxUTwDCKmAJ5jZCe7+RMJhqwgD/aSzIzAEWEnoV5jKJ8DHKbZvyHBeSWHthnIuePgjVq8vB2Crzq255qg8TD+zbDa7TfgtbP0vePPG+PYDLoeWHWp2bhGRxmsGsL+ZbeXuX6UqYGZbAQdEZUXqnT8+PZm5y9YAsFmrpvz9uP6UlOSpq97i6TDlqfj6Pr/Jz3lFpMHJpXI/CehWW4GY2TGEiv08YF93nxZt70qoYB8FXAjclOUpHyFU7EcBp7v76uh8OwDPA6cDbxO6GSS6jlCxfw44LuG4EuCPwJXAHWb2jLtvgNAfEDgjw3uL5VE96u7fpSk22t1HZPneJIO/PDeVz76N+rs1KeHWkwbSpnkexnd8/hLarpwGj58GZeELnK67hDlmRUSkuu4F/gq8bmZ/AB5x9zIAM2sCnEDol98cuK+6FzGzk4BzgV2BUuDz6Nr/zrbhIOFcHYHfAkcSRvRfS7hPusvdH8hwXIfouB8BfQn3YfOAN4Ab3P3jFMeMINx7pLPO3dNO05LP9y2pPf3xXJ7+OGG63aN2oVv7PM6c89aN8cF7tz4QthyYv3OLSIOSS6r7TYQW7wG1FEssr/nSWMUewN3nE76UAC7LJj3fzPYkdCNYDvwsVkGPzjcVuDhavcI2bc7dP1pek3RcBWFu3TWEaQC3zeZNmVl34IfR6t3ZHCPV9/ykb/nPu7Mq1684bAd22jIPWZxfvw8zxmI4rPw2vv2Q66EkTwPliIg0Tv8ExgA9CJX3NVFXuFmE79z7CQ/dnwduTHOOjKL
MwIcIo+2/CbwMbEfIDByVS9e/KDNwInAZ0IGQGfgJIUPvP2Z2X4p7i1hm4MeE+51uhIaL/xIy9E4BPowaOtL5hPBZJP9kepiQt/ctqc1ZuprLR8e7AR47qAeH7rJF/i6wbDZ88mh8fd+L05cVkUYv6+ZMd3/MzHYEXjazPwJj8jWXvZn1AAYB64GRKa79upnNBboD3wPeqeKUQ6LlBHdfmmL/S9GyJ+EhwPsJ+9ZVce5YGv+iKsrFnEF4iDLF3d+voqzUwOwlq7nkiXjPkUN27sYp3+td8xNXVMAzF4ZB9BLteCT02avm5xcRacTcvczMjgAuAH5JmPO+Z0KRr4CbgVuq09JczJmBSXLK8KuF9y1JyiucXz/+CSvXlgHQq2MrRhyRp+l2Y965GSrC+ek1FHoPze/5RaRByeVJdTlwOdCR8MT3KzMrT/NTlmMcsfyiKe6+Jk2ZD5PKZtImWqargK8kPEiA8FAh0QvR8nIzaxXbGD2FvwJoBTyTw6CCZ0TLqlrtdzOz683sTjO7zsyOMrNmWV6j0VtfVsEFj3xU+QXbo0NLrjtm15r3swf49DFYPmfT7T13r/m5RUQED25x960JFfs9o59e7r61u99UgxTyBpkZmIW8vW9J7Y43pvPBV0sAKDH4x/ED8tMNMGbVApj4n/j6vuprLyKZ5TqgXm2UhfCUHmBWhjKxLIFshh6NVbz7ptnfA4hVnJPPdznhAcKhwCwze4/Qmt+f8KT+QcKT8CqZ2X7ANoQHCWnT5iI/in4SzTGzU9z99Wyu15j9/aUv+GT2MgCalBi3nDiQ9i2b1vzE61bBC5fBhhRDJbzxt9Dfvlnrml9HREQAcPe5wNx8nKuBZwamVQvvW5JMnrucG1/6X+X6BQdsy6DeeR5c991boSyaNGmLAaG/vYhIBrmk5dfmk91YS3u6weYgjEgP2U2XM5bwJTnIzAa7+/ik/ecmvN5oAlJ3X2RmBwD/IqTWHZ6w+wvgdXdfmUUMAGdFy2eiQfdSmU54uv48IfWwGbALIT1vP+A5M9vT3dPOVGBmPwN+BtC1a1fGjRuXZXibWrVqVY2OL4RPFpZx54T4PdMx2zZl+YxPGJeHMZW3mv4feqz7jlS96svXfcecB3/BV30zzqzYqBTj709d0ueTmT4fqQXZZgZ2j8pWVcnNNjOwGaFynVi5fwE4h5AZmJiWn21m4G5mdj2hn/+S6Nxj3H19irL5ft+SYM36cn7x6EeUVYTnMQN6bsYvDtgmvxdZvQQ+TEj63PdiyEc2oog0aHnMHao/3H26mT0InAo8bWYXAOMIDwZOJ4xUuwFoCmyU5mdm2wPPRGVPBV4hpMoNAv4G3GVmQ939LDIws3bAsdHqPRliTdWiPxYYa2ajgGOAv7DxQ4bkc9wJ3AkwePBgHzZsWKbQMho3bhw1Ob6ufbt8Db+66c3K9f37bc61pw/Jz/Qz7vDW8ZCy6yOUVqyn94KX6X2WxkmMKbbfn7qmzyczfT6Nm5m1IKSub0d48J7qP3J396tzOG1DygzMJcMv3+9bEvz5uc+YsTC0R7VqVso/jx9Ak9I8t4F9cCesj9q1Nt8B+h2W3/OLSINUXyr3sVb5TPnNsafl2baan0uooB8JPJm073HCl++RhKffQOWUO08QUun3cvd3E455zcwOAj4DzjSzB9x9bIbrn0B4Cj8HeDHLmJNdRajcH2RmTdMMsNNolZVXcNGjH7N0dfhYurZrzg3DB+SnYv/dYhh9bjwdLpUmLWHwT2p+LRGRRi4a/O12wrg+aYsRsvJyqdw3hMzA6mT41fh9N/aswHQ+XlDGgxPj2YLHb1fKzMkfMrOG5038jErLVvO9924h1rnws84Hs+CNN2p4heLWkH6HaoM+n8wa0+eTc+XezJoSWqSHEdK5IPSNGweMqmYFdGa0zDS0eWzk3JkZylSK5pM/Khr85mBgC0JF/kV3H2tmsRS0SQmH7QHsCMxIqtjHzrnEzJ4nDJL3fcKXfDq
xlv37ajAI0OfRshnQGfg2Q9lG5+bXvtxoIJubTxhIx9Z5GINw5lvwxE83nvIulaYtNCWNiEgNmdkewKOETLpHgJ0JldfrCA/bDwLaEwamTTG6ad0pRGZgPjL8qqMxZwWms2jVOi7+Z7yS/cOduvLHkwflZfDejT6jt2+CsujZS4c+7Hjs5exYWl/a4wqjofwO1RZ9Ppk1ps8np/8pzGwQYWCW3myaLvdT4JqoH9nEHOP4KFruZGYt0/QPG5JUNitRJX2jirqZtQUGAGVsXEHvFS2XZzjlsmiZtnUhmjJwD8LT/XtziTdJp4TXq9KWaoTe+XIRt7xWOfgvv/z+duzRt1OGI7JQUQ6v/xXe+CskPo/Z+vvw9TuwYXV8W9PWcPB1GkxPRKTmLibM3nOku48xs3uBXdz9DwBm1pnwXXoosFuO526ImYGJ0mX41cb7btTcnUtGfcqiVWGIgy5tm3Pt0XmalSfRhrXwzq3x9b1/BY28Yi8i2ctlKrwehPTyPsBs4Frg7Ojn2mjbVsCLZtY9zWlScvfZwETCF+JxKa69H6Ef2zySKurVdB7QEhgZTQkT80203N7MNktz7Pei5VcZzh/L1R7r7jUZ1m14tPwih0H8GryFK9dx0WMf49G4wkO37sT5+9dwIJvlc+H+H8Hr18Ur9q06wUkj4eSR0L7nxuU36wm7DN/0PCIikquhwGR3H5NqZzQg7UlAc+BPOZ57ZrTMa2agux9FiPsqwnz21wMHuPvxQNeoaKrMwK/SZQYS0u4hZAZmKznDL2ZmtMzb+27sHnr/a177PD7W4d+P65+fbMFkHz0A30XXabsl9D8x/9cQkQYrl9E/LiO0Vt8MbOvuf3D3u6OfPxCeRt9EaG2+rBqxXBstrzezypqamXUBbotWr0tMcTezC8zsczNLmAS0cl8/M+uQtM3M7CeE/npLgOQJQ98lVPBbAndHg+LFji0xs8sJlfsywhP4TUTdFk6JVjOOtGZmvczsJDNrniLOU4l/Jv/IdJ7GpKLC+fXjH7NwZejv1rlNM/55/ABKa9LP/osX4Pa9Ydbb8W199oGfvwXb/QBKSuCIW6Bpy7CvScuwXqKpgUVE8qAzoc95TBmAmbWMbYgecL8BHJLjuTfKDExTptqZge5+pbv/zN0vi7r81WpmYArpMvxq7X03Rl8uWMU1Yz6rXD9zrz7su93m+b9Q+QZ4++b4+l6/gCbN05cXEUmSS+3kYGAG8KtU/erdvYxQWZ5BSJ3LibuPAv4NdAMmmdl/zexJYBrhafdo4NakwzoD/Yh/aSY6EZhvZu+a2eNm9gTh6fT/ESr233f3jTpVR9PJnEHoA3c0MMPMno/i+JLwUKAC+KW7T0/zVg4HuhC+pJPT9ZJ1BB4CFprZODN72Mz+Sxg85z+Ehwy3uvsdVZyn0bj9jem8OS0+A9GNwwfQpV2L6p2sbB288Dt45HhYE2VPWgns/wc47Wlot2W8bK89oO8BOAZbHwA9d6/BuxARkQRLCa3yMcuiZY+kck74fs1aA8wMTJYyw68A77vBWl9WwS8f+4i1G0Lb0nZd23DpwdvXzsU+fRyWR5MYtOoEu51eO9cRkQYrl8p9d+AD91gy9KaiVvUPgC3TlcnE3c8DTiZ8Ie0H/JBQqb4AOMbdy3M43WuEgWu6ESrcPyDcMPwJ6OfuKZ9Uu/vLhGlpbgcWEwYOPIwwPsGjhL5y/8pw3dhAOA+7e4ah1oHQleFvwARga0IfvYMI/y6PAQe6+4VVnKPRGD9zCTe89L/K9fOGbV39J+eLp8PdP4D3botva7slnP4s7HcJlKSY2f6Q61nZdls45PrqXVNERFKZzcYP6ScTxvWpHCDOzFoDexMG8M1V0WYG1jDDL+f3LZv65yv/Y/LcFQA0Ky3hphMG0qJpinuEmvJyeOvG+Pr3zoNmrfJ/HRFp0HIZoWMN2aWKdYzKVou7Pww8nGXZEcCINPveIKTwVSeGaWw8nU0uxybPQZup7GLgkup
cp7FZtno9v3jkI8orwrOlwb078OuDtqveyT4dCc/+Mj5/LMB2B8OR/4ZWGX7FN+vJxEF/Y9hmPdOXERGRXI0DLjKzzd19IfAssBq41sy6EUbIP42QrVdVRtwm3H2Umf2b8L0+ycxeIYxofyBhurrRpM8MnJfilCcCvzezCYQHE6XAYMIDivnAIakyA83sDOBpQmbgfmb2IeF+aQBhzKJUmYGxDL/bzWwi4QFBW2An4vPTp8zwq+b7lgTvz1jMv1+P/3NccnA/dtiiXYYjqm/zhe/C4i/DSvP2sPvZtXIdEWnYcqncfwoMM7Pt3f3zVAXMrB+hpfu9PMQmAoQRai8e+SnfLA+JEO1bNuWmEwfSpDTHPu/rv4PnL4GPHoxvK2kKP7ga9vg55HvEWxERycZIQgV3IPCSuy82s98QWpdj840aoSJ9RXUu4O7nmdlbwPmEzMBSwmB09wD/zrH1+jXCdH2DCJl+5YQuifcC/3D3lP3q3f1lM+sP/Bo4gHC/VEJ4IPAocJO7J98/xTL8hhDGNto9OmYeIcPvTnd/rY7ed6OyfM0Gfv34J5WD9+61TSfO2murzAdl6/a9ocfuIVOwbTdwp/esUfH9u58NLdrn51oi0qjkUrm/G9iXMGXL5cCDUR/1xEHkribM7XpXvgOVxuvet2fyytR418V2LZrQfbN04wOlMW8yjDoTFsXT+umwFRx3L2w5ME+RiohIrtz9A0KXtMRtd0Qt48cQWq8/B+5192U1uE7RZQbmI8Mvl/ctcVc+PZm5y0IiavuWTfn7cf0pqcngvYnmTYKFX8DHD8GAk6HHENp8Fw210LRVSMkXEamGrCv37v6AmR1MSEe7C7jDzL4lDHCzJeFJshH6mj9UG8FK4/PpnGX85bn4CLUlBrOX5tDrwx3G3xMGzitfF9++y3Fw2I3QonbS60REpGbcfTwwvtBxSOPz9MdzGf3xN5Xr1x69C1u0z7FRoSrl68PyowfCfUrMoDOhdafUx4iIVCGXlnvc/WQze5swUMxWbDyS7QzgRne/LeXBIjmavnAVJ931HmUJSYMVaYdzTGHNMnjmQpj6THxb01Zw6N/Ck3Kl4YuIiEiCOUtXc/noyZXrxw7qwaG7bFF7F4xV8mNWL4aV80K6vohIjnKq3ANElffbzKw7YQR9gLnuXp0RbEU2sWDFWm5+dRoPf/B1bpX5RLM/hFFnxaeUAeiyU0jD37xfXuIUERGRhqO8wvnN45+wcm0ZAD07tuTKH+1Yt0FMeRI+Gx0aIWJ98kVEspRz5T4mqsyrQi95E6vUj5wwh7KKiupV7Csq4J2b4bWroaIsvn3wT+CHf4ameU6rExERkQbhzjdm8P5XS4DQDfCfxw+gbYumdRtErCV/wr2wcCqc+XzdXl9Eilq1K/ci+XbBIx/x4cwllSPT5mzVAnjqHJieMHBw8/bw41tgxx/nJUYRERFpeCbPXc6NL39RuX7BAdsyqHc2M0DnWWkzsJKo5f7Sur++iBQ1Ve6l3rj1pIHc/OqXPD5+NuvLcpydZ/rYULFfFR9Vnx5D4Ji7oUPv/AYqIiIiDcaa9eVc9OhHbCgPrQv9e27GhQdsUzsXq0h9f1NhTSgpbRKv1LftWjvXF5EGTZV7qTe6tG3BNUfuTLNS4563ZwJh+oWMDfnlZTDuL/DmjRuX3PtXsP8foLSO0+lERESkqPzlualMX/gdAK2alXLT8QNoWlqS/wu5w/O/3XiblUJpU77tsj/dT7xJlXoRqRFV7qVeqahwXvos3vq+73adeX/GEsrdK5+oV1o2G574Ccx+P76t9eZw1B2wzYF1FLGIiIgUq7GfL+CB92ZVrl/5ox3p07l17Vzstavhw/+Lr1spDDoD9ruUaROm0l0VexGpIVXupV6Z8PVS5kTz2PdrsYx7y25gybl38s8P1zJq/Ox4JX/qf+Hp82Ht8vjBfYfBUXfqqbeIiIhUadGqdfx21CeV6z/YsSvDB/esnYu9fTO8eUN8vWN
fOGMMtNsy2jC1dq4rIo2KKvdSrzw5MT4Bwz/aPUzJtx/R+Y0/cs2JD/OLA7fhtpenMOiLG+CxMfGDrBQO+APs9SsoqYU0OhERqRNmVgp0AlqkK+PuX6fbJ5Itd+fSUZ+yaFUYnb5L2+Zcd8yumFn+LzbhPnj5ivj6tj+EEx5S10ERyTtV7qXeWFdWzphPvwFgN/sf/b4bD14RRr//+n26tOrIiPkXwbpJ8YPa9wyD5vXao0BRi4hITZnZHsBVwD5A8wxFHd27SA0tWLGW4Xe8y8zFqyu3/e24/nRs3Sz/F5v8JPz3l/H13nvB8PtVsReRWpH1F6SZPQrc6u5v1WI80oiN/XwhK9aWYVRwQ4v/o6R8bdhRtgZGng5rlkNZ/IuY7Q+HI26BVgWYqkZERPLCzPYCXiFeqV8KrChcRNLQXT3ms40q9mcM7cN+222e/wtNexme/BmVA/5uMQBOfBSatsz/tUREyO3p93DgODObDNwGPOju39VOWNIYPfXRHACOLHmL7rYYSxw/b+W38delzeGHf4YhP4XaSJ8TEZG69CdCxf4u4Ap3X1DgeKQBm7N0Nc9+Er+n6Nu5FZcdsn3+LzTrHXjsVKjYENY7bwenPAkt2uX/WiIikVw6KF8E/A/YhVC5n2tmN5lZLfyPKI3NstXrGfv5QlqxliubPkCzijWpC3bcGn76Cux+tir2IiINw+7AVHc/RxV7qW0/f2DCRlPs7rBFe1o0Lc3vRb79BB4+PmQeArTvBaeOhtad8nsdEZEkWVfu3f0Wd98B+D4wGmgFXAhMMbNXzOwoM9NoZlItYyZ9y/ryCs5vMpoWVpa6kJVCv0Ngi13rNjgREalNBnxa6CCk4Zu+cBWTv9m4x8erU+ezYOXa/F1k0TR44GhYF12ndRc4bTS0756/a4iIpJFzZdzdX3P3Y4A+wNXAPOAAYBQwy8wuNzPNRSY5Gf1RGCX/lNJXaMG61IW8HD56oA6jEhGROjAJ6FboIKTh+/2TkzbZVu7Oza9+mZ8LLJsN/zkSVi8K6y3aw6lPQaet83N+EZEqVLul3d2/cfcrgd7A8cAbQHdC37lZZvawmQ3JT5jSkM1espoPZy4F4OGK71NRmmag5CYtYfBP6jAyERGpAzcB+5jZgEIHIg3XghVr+eCrJZts31DujBo/u+at96sWwgNHwoowfhBNW8HJo6DbzjU7r4hIDvKRRm+EgXBaJKw3A04A3jOzR8ysdR6uIw1UrNUe4JM+P6UkXV/6pi1g34vrKCoREakL7v4Y8GfgZTM718x6FTomaXiuevazjfraJ6px6/2aZfDgUbA4OkdpszCPfc/dq39OEZFqqHbl3sz6mNl1wFzgfmAP4D3gJKAL8CvgG8Io+3+veajSELk7T30cr9wfstvW0KzNpgWbtoaDr4Nmek4kItKQmFk5cDnQEbgV+MrMytP8pBmURSS9BSvW8vzkeWn316j1fv3qMHjevCjl30rgmLth6wOqGa2ISPXlXLk3s0PN7FlgGnAJ0BZ4ABji7kPd/VF3X+TuNwE7AbOAo/IZtDQcn85ZzoyFYUbF1s1KObjdrHhftUSb9YRdhtdxdCIiUgcshx8N3Cs5++cr/6O8Il27fVCt1vuy9fD4qTD7vfi2I26BHY+oRpQiIjWX9Tz3ZnYJcA5hID0jtNj/G7jT3VPUxsDdV5jZG8CpNQ9VGqKnElLyD955C5pP+b/4TisNg+g1aRm+LEt0Tyci0tC4u/5zl1r1xrSUt6kb2VDuTJy1NPuTVpTDk2fDl6/Et/3wWhh4SjUiFBHJj6wr98B10fIt4BbgSXcvz+K4yYTB9kQ2sqG8gv9+8k3l+rG7doKnRscL9NwdZr8fUtvUb01ERESqYftu7ZizNMw5f/7+W/PbH25fsxO6w38vgs9Gx7ftdxnseV7NzisiUkO5PC2/Bxjo7vu6+8gsK/a4+9/dff/qhScN2VvTFrH4u/UAdG3XnD3WvxufF7ZjXzj6TthyNzj
k+gJGKSIiIsVqwcq1jP1iQeX6cYN61uyE7vDS5RtPzbvHz2HYZTU7r4hIHmTdcu/uP63NQKTxSUzJ//GA7pR8clt8Z/8TYbNecParBYhMRETqmpk1BY4FhhGm1oXQBXAcMMrdNxQmMilmT02cW9nffvetOtKncw0H5n3z7/DurfH1/ieFdPx0M/2IiNShXPrcNwe6AkvdfWWaMm2BDsA8d1+fnxClIVq1royXPouPXHvcdqXw0Nh4gV2PL0BUIiJSCGY2CBgJ9CaM65Pop8A1Znacu0+s8+CkaLk7j4+fXbk+fHANW+0/uAteuya+vv3hGhNIROqVXP43ugj4ChiUocygqMwFNQlKGr4XJs9j7YYKALbv1pZt5z8PHtbpsw906F3A6EREpK6YWQ/gRcKAvbOBa4Gzo59ro21bAS+aWfc0pxHZxMSvlzE9YUaeQ3fpVv2TffIYPHdxfH2r/cKUd6W5DF8lIlK7cqncHwHMdvdx6QpE++YAP65ZWNLQPfXRnMrXRw3YEj5+JL6z/wkFiEhERArkMsIc9zcD27r7H9z97ujnD8A2wE1Ap6isSFZGTYi32h++65a0albNivjnz8Hoc+Pr3QfDCQ9D0xY1jFBEJL9yqdxvDUzNotxnhC9ikZTmLV/LO9MXA6GL2tFbLoKF0a9W01awo54NiYg0IgcDM4BfpepX7+5lwG+iMofWcWxSpFavL+O/n3xbuT58SI/qneirN2DkGWFqXoAuO8LJI6F5m5oHKSKSZ7lU7jsCS7Iot4TwdF0kpWc+mYuHsW3Ys28nNv/yyfjOHX4EzdsWJjARESmE7sAH7rFvhk25ewXwAbBlnUUlRe35SfNYta4MgL6bt2a3Xh1yP8ncCfDIiVC+Lqx32ApOfQpadcxjpCIi+ZNL5X4R2bXIbwMsq1Y00ig8OTE+Sv7R/bvApJHxnf1PLEBEIiJSQGsIDQhV6RiVFalS8kB6luto9gumwoPHwPpVYb3tFnDaaGhbg377IiK1LJfK/fvAYDMbkq5AtG8w4em6yCamfruCz+eFyRaaNynhsBaTYE2UENKuO2y1bwGjExGRAvgUGGZm26crYGb9CFPkfVpXQUnxmrX4O97/KtxblJYYRw/McRzGpTPhgaNgzdKw3rIjnDoaOvTJZ5giInmXS+X+DsL0NKPN7KDkndG2p6LV2/MQmzRAoxPmtv/BTt1o+dnj8Z27Hg8lpQWISkRECuhuoBnwmpmdZWbNYjvMrKmZnQm8CjQF7ipQjFJERk2ID9o7bLvN6dIuh4HvVs6D//wYVkb99Zu1gVNGQZe0z55EROqNrIcNdfcXzewO4BzgBTObA3wR7e4H9CBU/u9y9+fyHqkUvfIK5+mPv6lcH75jC3j6xXiBAScVICoRESkkd3/AzA4GTiRU3u8ws28BJ/SxLyHcXzzs7g8VLlIpBuUVvlHl/rhc5rZfvSS02C+dGdZLm8OJj0L3TLNAi4jUHznNCeLu55rZF8AfgJ7RT8wi4Fp3/0ce45MG5L0Zi5m3Yi0AnVo3Y+ia16EiGhi5+2DovG0BoxMRkUJx95PN7G3CqPhbERoMYmYAN7r7bQUJTorKW18u4tvl8XuNA7bvkt2B61bCQ8fCgs/CupXC8Pthq31qKVIRkfzLJS0fAHf/J9AN2BM4ATg+er2lKvaSyVMJKfk/6r8lpZ8+Gt85QAPpiYg0Zu5+m7tvTWg4+F7009Pdt8lHxd7MTjKzN81suZmtMrPxZna+meV8L2RmHc3sWjObamZrzGypmb1hZqdWcVwHM/uLmU0ys+/MbJ2ZzTKzB8xsQIryJWY21MyuMbN3outsMLP5ZvacmR2Z4VojzMwz/KzN9X0Xg8SB9I4a2J1mTbL4592wFh49KYyOD4DBUbdDv0NqJ0gRkVqSU8t9jLuXEwbYez+/4UhDtWZ9OS9Mnle5fnyf72DiR2GltBnsdHSBIhMRkfrE3ecCc6ssmAMz+xdwHrCW0H9/A3AgcCtwoJk
dG023l825+gKvAb2B+cBLQHtgD2AfMzsQODN5aj8z6wW8CfQiZDuOjeIZAJwCnGBmJ7j7EwmH9QXejl4vIQxYvDTafghwiJndB5yVYSrBT4CPU2zfkM37LSbLVq/n5SnzK9ezSskvL4MnfhLms4859G+w6/BaiFBEpHZVq3IvkquXp86PzzfbuTXbz3s2vnO7gzVnrIiI1AozO4ZQsZ8H7Ovu06LtXQkV7KOAC4GbsjzlI4SK/SjgdHdfHZ1vB+B54HRChTx58L/rCBX754DjEo4rAf4IXEkYb+AZd49VvJ3wIOFvwMtR40rsfe0HjAHOAN4A7k0T72h3H5HleytqT3/8DevLwzOa/j3a069b28wHVFTAMxfA5wn3JAdcAbufXYtRiojUnpxT0SB8gZnZj83sVDM7LdVPdQMqxrS56Jgapb/l833XR4mj5B81oBs2KWGUfA2kJyIited30fLSWMUewN3nA+dGq5dl831rZnsCuwPLgZ/FKujR+aYCF0erV9imE6vvHy2vSTquArgaWAN0ArZN2Dfd3Q909xcSK/bRvtcJDwwgtPw3eokp+Ru12t++Nzz76zASfow7vHAZfPJIfNvQX8A+v6mDSEVEakdOLfdmNhS4E9ghUzHCk+b/5BpMEafNJco5/S2f77s+WrRqHa//b2Hl+vEdp8enmGnVGbb5foEiExGRhszMegCDgPXAyOT97v66mc0FuhP6+L9TxSmHRMsJ7r40xf6XomVPwkOAxO6L66o4d+x+ZFEV5RJF/ds2GoCwUZo8dzlTvlkBQPMmJfyo/5bxnfMmwcIv4OOHYMDJsN8lMP4e+OCOeJndToeDroJNnsmIiBSPrCv3ZrY94UurFeHLrxthRNtHgW2AgUApMJrwRDsnRZ42lyin9LdaeN/1zrOffEN5RbhnGdy7A11m3B3fuetwKG1aoMhERKSBGxgtp7j7mjRlPiRU7gdSdeW+TbRMVwFfSXiQ0IzwUCGxcv8CYTrhy80s8f7CgCsI91fPuPuCKmJIFGvl/zZDmd3M7HqgA6Hf/vvAGHdfn8N16r3E6e8O2bkb7Vsm3VuUR2/3owdg4v1QURbft9NRcPg/VLEXkaKXS8r3ZYQvnnPcfW9C6zbufrK77wH0ByYA2wG/qEYsRZs2V0N5e9/11VMJc9sft0v7jfu29dco+SIiUmu2ipazMpT5OqlsJrGKd980+3sQKvapznc5YUC8Q4FZZvZfMxsF/A/4PfAgOaTXm1kr4vdb6bIJAX4EXAKcDVwKPAlMj/rsNwjrysoZ/XG8+1/GgfTK129csW/XHb5/FZSU1mKEIiJ1I5e0/GHANHdPbukGQqXZzA4HviQ8gb4k2xM3grS5lGrhfdc70xeu4pPZywBoWmr8qMn7UBYNP9B1Z9hi18IFJyIiDV2spf27DGVWRcsqRl8DQkadA4PMbLC7j0/af27C63aJO9x9kZkdAPyLkD14eMLuL4DX3X1lFjHE3EZ4gPAZoctksumEBoTnga8IDx12IWQg7gc8Z2Z7uvunqU5uZj8DfgbQtWtXxo0bl0NoG1u1alWNjq/KB/PKWLY6JFN2amGsmz2JcXPibTfDMhxbsXI+fssg5nU9gFl9jmd988IM8Fvbn1Gx0+eTmT6fzBrT55NL5b4bYVTWmHIAM2vu7usA3H2Bmb1OSCXPunJPw0qbyyX9Ld/vu955OmEgvf37daHVZ7fEd6rVXkREioi7TzezB4FTgafN7AJgHOHBwOnAbwnj5jQFNhorJ+re+ExU9lTgFUI24CDCaPh3mdlQdz+rqjjM7IroesuB4bH7sKRYH0hx6FhgbJQxcAzwFzZ+yJB4/J1EDw0GDx7sw4YNqyqstMaNG0dNjq/Kvfd8AISxfU7dexsO2H+7pADSH1viZeDQfd5LdG+2As58vtbizKS2P6Nip88nM30+mTWmzyeXyv2qpPUV0XILYGbC9jWEymgu6lva3EDiaXPvEVrz+xP68D9I6COfzo+
in0RzzOyUaGTbRPl+3/WKu/NUQprcydtVwAvR8wkrhV2OK1BkIiJSX5jZo8Ct7v5WLZw+du/SOkOZWINAtq3m5xIq6EcSUtwTPU64vziS8IAfADNrQkid3wbYy93fTTjmNTM7iNACf6aZPeDuY9Nd3Mx+DVxFeG+HuPuULONOdBWhcn+QmTVNM4ZQUfhm2RremBYq9mZw7KAcxxYsbQZWEg20d2ktRCgiUndyqdzPIQw0F/N5tNyfaG5VM2tKGI1+IblpCGlz1Ul/y8v7rq+pc9OWljN7SUjBb9UEenz2f5X7FncYwKQJU4GpeblWXWlMaT3Voc8nM30+menzabSGA8eZ2WRCqvmD7p7pezEXM6Nl7wxlYh20Z2YoUymK7ahofJ+DCY0cS4AX3X2smcWy7CYlHLYHsCMwI6liHzvnEjN7njBn/fcJ9zGbMLMLgRsIDSmHpzpXlmL3cM2AzmQekK9ee3LiHGLzHu21dWd6dGiV3YHJlfq2XWsvSBGROpJL5f5twhPldu6+gpCiXw78w8xaECr/ZxNaxR/Ne6Q5KETaXE3T32qivqbOvfTUJGKJB0cO7M7Ws9+r3NfpgAsYtnN+rlOXGlNaT3Xo88lMn09m+nwarYsIGXG7ECr315vZ/cC/3f3zjEdWLTZV3E5m1jJNF7ghSWWzElWsN6pcm1lbwtS5ZWxcQY81jmSaTWhZtEzZ6dvMzgduJkybe0SKbMBcdEp4nZyZWTQqKpzHx8dHyT9ucJat9qXNYOCpqtSLSIOTywjsTwJzicYlcfe5wLWElu9bCVPgHU744vp9jnHUVtrcaGBLQuxLCOnvV0XrsfED0qXNHe3uD7r7PHdf7u6vAQcB8wkPOWKj6mfjqmh5UJTdEFMb77teWFdWzphP4w0Bp275LSyLeh+0aA/9Di1QZCIiUp+4+y3uvgOhxXo0YWybC4EpZvaKmR1V3Rlj3H02MJHQQr1JX7BoxPgehOloq9sKnug8oCUwMpr1JiY2bcz2ZrZZmmO/Fy2/ShHnzwn3WuuAI939lRrGOTxafpHjIH71ygczl/D1kjCxUdsWTfjhTt2qPqhlJ/jlJDj8RlXsRaTByfrL0t1fdfdt3f2ZhG1XEr4sHye0cN8CDHL3TH3IU5kZLfOaNufuRwFDCZXru4DrgQPc/Xgg9j96qrS5r9KlzRHS7iHchGQrOf0tZma0zNv7ri/GfbGQ5WtCF77um7Vk+/kJ09/tdDQ0bVGgyEREpD5y99fc/RigD2Hq2XnAAcAowhg4l5tZdWpj10bL681sm9hGM+tCyBQAuC6a8ja27wIz+9zM/pN8MjPrZ2YdkraZmf0kinsJ8Jukw94lVPBbAnebWbuEY0vM7HJC5b6MpGntzOzsKM51wFHu/mJVb9jMepnZSWbWPEWcpxL/TP5R1bnqs5EJrfY/HrAlLZqmmM5uxTdAwqzHR94GbbN4CCAiUoRySctPyd2fIPP8qtloMGlzaaRLf6u1911oT01MmG92107Yx0/Hdw44qQARiYhIMXD3b4Arzexqwuw75xHGrvkTYSabJ4F/uPuHWZ5vlJn9m5DRN8nMXiF0zTuQkH04mtAqnqgz0I/wgCHZicDvzWwCMBsoBQYT7iHmEwa526gPu7uvN7MzgKeBo4H9zOxDQre/AYRBcyuAX7r79NhxZjYAuINQO/0KON7Mjk8R0yJ3vzhhvSPwEHC7mU0kPFhoC+xEfIDeW939jhTnKgor127guUnxj3l4urnt37yRylmMt9wNtvth7QcnIlIgWVfuoy+HGe5+bL6DcPfZ0fl3I2QCbPSkvBbT5h7JlDbn7stSHJs2bS6DlOlvBXjfdWL56g289nl8psAT2n0K66O33XFr6DEkzZEiIiKVDGgOtEhYbwacQKjkPg78NJvB99z9PDN7Czif8KCglJBVdw+hb39FpuOTvAbsTBiLpz9h/KEZhMGF/+HuKRsI3P1lM+sP/JqQkTCMkEE5nzBW0U3u/l7SYZs
Rb3bePvpJZRaQWLmfTRgnaAihq+Hu0bXmAY8Bd0bdDYvWmE+/Zc2GcgC279aWXbq337TQ8jkw8f74+v5/CEPqi4g0ULm03G9PGC2+tlwLjCSkzb3j7l9C1WlzwAXAB+5+WuLJzKwfsMDdlyZsM+Asqk6b25KQNndmNHggUX+/35Mibc7MegF7A08kzjUbXe8UMqe/5fy+67sxk75lfXkId9ce7en2VXyUfAacqC9WERFJy8z6AD8nfF93IlRu3yMMJvcKcDKhIjuckE13bqrzJHP3h4GHsyw7AhiRZt8bwBvZnCfFsdPIMt6o/Dg2yinP+rjFwCW5HldMHh8/u/L1sYN6YKnuLd68AcrXh9c9dodtDqyj6ERECiOXyv0s4oO75V0xp81Rg/S3ar7vem30R/GU/JO2bwJvj4vv3PWEug9IRETqPTM7lJBZ90PCd/Y64AHgFnefkFD0JjO7F/iYkLafdWVZGoYvF6xk4tfLAGhSYhw1sPumhZbOgokJExnt/3s1LohIg5dL5f4J4EIz6+zui2ojmCJOm6tR+lue33dBzV6ymg9mhgkISkuMw+1NiIXfZx/YLE2fOBERaZTM7BLgHMJAekaYmeffhO/OlPcb7r7CzN4gTFkrjczICfGB9L6/Q1c6tWm+aaE3/w4VYWBfeg2FvsPqJjgRkQLKpXL/Z8JUcC+a2Xnu/n5tBFSkaXM1Tn/L5X3XZ09/HG+132ebTrSZOjK+UwPpiYjIpq6Llm8RZt150t3LszhuMtX8npfitaG8gicmxO81hg9JMbf9khnw0UPxdbXai0gjkUvlfgyh9XsI8I6ZzSek6qca4d3dXR2bGhl356mElPwzt1oKr0fDNDRtDTscUaDIRESkHruHkHr/SS4Hufvfgb/XTkhSX73+xUIWrQrDG3Vp25x9t91800Jv/B1iz4f67ANb7VOHEYqIFE4ulfthCa8N6Bb9pOLVDUiK16S5y5m+MAxa3LpZKXutejm+c8cjoHmtDdkgIiJFyt1/WugYpHgkDqR39G49aFJasnGBxdPhk0fi6/v/vo4iExEpvFwq9/vXWhTSICS22h+6UyeafPZEfGf/EwsQkYiI1Hdm1hzoCixNnC42qUxboAMwz93X12V8Un8sWrVuo6l2jxucIiX/9evjY/303R96D62j6ERECi/ryr27v16bgUhxKyuv4L+ffFO5fubm/4Op0SyE7XuGtDgREZFNXUSYFvZAYFyaMoOAV4HfAjfWTVhS34z+aC5lFSE5dHDvDmy9eVJG4MIvYFLCWD9qtReRRqak6iIiVXvzy0UsWhUaU7q0bc4O85+N79z1eCjRr5qIiKR0BDA7mtM9pWjfHODHdRST1DPuzmMfxlPyhw9OMftOYqv9NgdBz93rKDoRkfpBNS7Ji8S57U/cqSU27aX4TqXki4hIelsDU7Mo9xlhullphD6Zs5xpC1YB0KpZKYfuusXGBeZ/BpOfjK/v/7s6jE5EpH7IOi3fzNLO056CRstvRFatK+PFKfMq109s9SFUlIWVHrtDZ92LiYhIWh2BJVmUWwJ0quVYpJ5KHEjv0F22oE3zpFvY16+jcjzn7Q6B7oPqLjgRkXqiuqPlp+OEkfQ1Wn4j8uLkeazdENLg+nVtS9evEp6c9z+hQFGJiEiRWER2LfLbAMtqNxSpj9asL+e/H8fH9dkkJX/eJPjs6fi6Wu1FpJHKx2j5JUBv4DDgGOB64IUaxiVFJHGU/LO2XYONj6YqLm0OOx9doKhERKRIvA/82MyGuPuHqQqY2RBgMDCmTiOTeuHFKfNYuS5kBPbp1IohfTpsXGDcdfHX2x8OW/Svw+hEROqPfI6Wf5+ZnUcYxXZUjaKSojF/xVrenr4IADM4zMfGd/Y7BFp2SHOkiIgIAHcARwKjzewMd385caeZHQTcG63eXsexST2QmJJ/3OCemFl85zcfwecJg/gOU6u9iDReeR1Qz91vA2YCI/J5Xqm/nv54Lh51whjaZzP
afJGQkj/gpMIEJSIiRcPdXyRU8LcAXjCzWWb2UvQzi5ANuCXwf+7+XCFjlbo3e8lq3pm+GIASg2N2S5rbPrHVfscjodvOdReciEg9k0tafrYmAQfUwnmlHnrqo3gfuHN6fg0fzA8rrbvA1hpTUUREqubu55rZF8AfgJ7RT8wi4Fp3/0dBgpOCGjlhTuXrfbfbnG7tW8R3zpkA/4v1BDUYdlndBiciUs/URuW+G9CyFs4r9czn81Yw9dsVADRvUsKeK1+M79x1OJTWxq+XiIg0RO7+TzO7hdC3vjdhcN6vgQnuXlbQ4KQgyiucUeMzzG0/7i/x1zsfA112qKPIRETqp7zWvszsBGAo8Ek+zyv1U+JAeof3a03Tac/Hd2puexERyZG7lxMG2Hu/0LFI4b0zfRHfLF8LQIdWTTlwhy7xnV+/D1++El5bCex3aQEiFBGpX3KZ5/6eDLvbANsDO0XrN9ckKKn/KiqcpxNS8n/a4WOYHr6A6baL+ryJiIhIjYwcH0/JP3Jgd5o3KY3vTGy13+U42Hy7OoxMRKR+yqXl/owsyqwErnL3+6oVjRSN92YsZt6KUJnv2LoZ/eYnjFSrVnsREakGM9sB2A5oB1iqMu7+nzoNSgpi+eoNvDBlXuX6cYMSUvJnvQMzxoXXVqpWexGRSC6V+zMz7FsPzAU+dPc1NQtJikFiSv7p/cop+SzKoLTS8ARdREQkS2Y2FLgTyNRp2gj98FW5bwSe+WQu68sqANi5ezt23LJdfOfYhFb7/idAp63rODoRkfopl3nu76/NQKR4rN1QzvOT40/Thzd7O75z24OgTZcUR4mIiGzKzLYHXgJaAe8QBubdCngU2AYYCJQCo4HlhYlS6trjCSn5Gw2k99UbMPPN8NpKYd/f1nFkIiL1V17nuZfG4eXP5rNqXRi4uG+nlnSbOTq+Uyn5IiKSm8sIFftz3H1v4E0Adz/Z3fcA+gMTCOn6vyhYlFJnpn67gklzw3OcZk1KOKL/lmGHO4y9Nl5w4MnQcasCRCgiUj9lXbk3s55mdpqZ9ctQpl9Upkd+wpP6aHRCSv55fRdgy74OKy02g36HFCYoEREpVsOAae5+V6qd7j4VOBzoBVxRh3FJgSQOpPfDnbqxWatmYWXGOPj6nfC6pCnsc3HdByciUo/l0nL/C+DeLMrdB5xfrWik3lu8ah2v/29h5frBZa/Fd+58DDRpXoCoRESkiHUDJieslwOYWeUXirsvAF4Hjqrb0KSurS+r4KmPElPyo/Yi94372u92KnToXcfRiYjUb7lU7n8ATHH3L9IViPZNAX5Y08Ckfnr2028pq3AAhvZqSZvpY+I7B5xUoKhERKSIrUpaXxEtt0javgboXvvhSCG9OnU+S1dvAGDL9i0YunXnsOPLV2HOB+F1aTPY5zcFilBEpP7KpXLfE/gyi3JfElLnpAF6MjElv9tUWB/dk3XaFroPKlBUIiJSxOaw8X3D59Fy/9gGM2sK7AEsRBq0x8fPrnx97KAelJZY1Gr/53ihQWdAe/UAFRFJlstUeC0IU95VZT3QunrhSH02Y+EqPpm9DICmpcYey1+M7xxwIljKKYlFREQyeRs408zaufsKYAwhNf8fZtaCUPk/G+hBGEFfGqh5y9du1PXv2Njc9v97Eb6ZGF6XNoe9f12A6ERE6r9cWu7nAtk0ze4GzKuylBSd0R9/U/n6qK2h6aw3ojWDXY8vTFAiIlLsniTcYwwDcPe5wLVAO+BWwhR4hxOmwft9QSKUOvHkR3OIev6xZ99O9OrUatNW+yE/gXbJPTZERARyq9yPBfqa2RnpCpjZ6cDWwGvpykhxcveNRsn/SdsPgOgbeKt9lR4nIiLV4u6vuvu27v5MwrYrgeOAx4FXgFuAQe4+q0BhSi1z941GyR8+JLqv+HwMzPs0vG7SEvb6Zd0HJyJSJHJJy78ROA2408y2Be529xkAZrYV8FPgYmBDVFYakIlfL+XrJasBaNuilG3nPRv
fqYH0REQkz9z9CeCJQschdWP8rKV8teg7ANo2b8LBO20BFRUwLmFe+91/Cm27FihCEZH6L+uWe3f/HPhZtHoZMM3M1pnZOsIgepdF5zvH3afkPVIpqCcnxlvtz9l6GSWLp4WVZm1ghx8VKCoRESl2ZjbRzEYVOg4prMc/jA+kd3j/LWnZrBSmPgPzo1kSm7ZWq72ISBVyScvH3f8DDAX+S5iSpmn0sybaNtTd78tzjFJg68sqePbTbyvXj2vyZnznjj+GZho/UUREqm17QtafNFLfrStjzKT4fcbwwT2gonzjVvs9fgatOxcgOhGR4pFLWj4A7j4eONLMSoDOhI7Xi929It/BSf0w9osFLF8T7rv6tG9Cl1kJKfn9TyxQVCIi0kDMAtoUOggpnDGTvmX1+nIAtu3ShgE9N4PJT8DCaFbEZm1g6C8KF6CISJHIqeU+kbtXuPsCd1+oin3DljiQ3i97TcfWLgsr7XtB770KE5SIiDQUTwD7mpmaZRupkQlz2w8f3BPzChh3XbzA986FVh0LEJmISHHJunJvZh3MbF8z2zJDme5Rmc3yEp0U3PI1G3h16oLK9e+vT5gIof8JUFLt50MiIiIAfwY+B140sz1q4wJmdpKZvWlmy81slZmNN7PzoyzEXM/V0cyuNbOpZrbGzJaa2RtmdmoVx3Uws7+Y2SQz+y4at2iWmT1gZgOqOPZgM3vJzJaY2Wozm2xmfzCz5lUct4eZPWVmC8xsrZlNM7O/mln7XN93bZmxcBUfzlwKQGmJceTA7jBpFMTG9mneDvY8v4ARiogUj1zS8i8CrgB2B75JU6YbYcq8K4Frahaa1AfPTfqW9eUhMWOvLSpoM3tsfGf/EwoUlYiINCBjgHJgCPCOmc0npOqvSVHW3f3AXE5uZv8CzgPWAq8S+vcfCNwKHGhmx2abgWhmfQnT/fYG5gMvAe2BPYB9zOxA4Ex396TjegFvAr2ARYR7pbXAAOAU4AQzOyGaISD5mpcA1xM+o3HAUmA/wn3W4WZ2oLuvTnHcicADQCnwNjAX+B7wW+AoM9vL3RckH1fXRk6IT393wPZd2LxVKbye0Gq/5/nQskMBIhMRKT65VO4PA7509wnpCrj7BDObDhyOKvcNwlMJKfkXbv4xLC0LKz33gE5bFyYoERFpSIYlvDZCQ0G3NGU9zfaUzOwYQsV+HrCvu0+LtnclVLCPAi4EbsrylI8QKvajgNNjlWoz2wF4HjidUJG+K+m46wgV++eA4xKOKwH+SGgUucPMnnH3ysEFzWxwdOxq4AB3fz/a3obwUGRfQubDr5Ledw/gbsLneaS7Px1tbwI8CBwP3BG9/4IpK6/giYTK/fDBPeHTx2DJjLChRfuQki8iIlnJJR2tD/C/LMp9AWxVrWikXpm9ZDUffLUEgBKDQctejO/UQHoiIpIf++fwc0CO5/5dtLw0VrEHcPf5QKzWeFk26flmtiche3E58LPE1nJ3nwpcHK1eYWaWdPj+0fKapOMqgKsJWQqdgG2TjruMUEG/Plaxj45bBZwJVADnpegO+UugJXB/rGIfHVdGmNZ4BWFw5B2ret+16c1pi1iwch0Ands0Z9g2m8Ebf40XGHphqOCLiEhWcmm5bwuszKLcSkKKmhS5Zz6J9744sfcqms6bFFZKm8NOBX3YLyIiDYS7v14b541arwcB64GRqa5rZnOB7oR09XeqOOWQaDnB3Zem2P9StOxJeAjwfsK+dVWcO5aRsCgh/mbAIdHqQ5sc4D7DzN4F9gIOBR5O2H1khuNWmNl/gZOjcp9VEVuteTxhIL1jdutO08mPwdKZYUPLDrDHzwsTmIhIkcql5X4esHMW5XYi4ctJipO78+TEeKrcGa0T7nm2Pwxablb3QYmIiGRvYLSc4u6p+u8DfJhUNpPYdH3p7nFWEh4kQHiokOiFaHm5mbWKbYxa+K8AWgHPJPWB7xdtX+Lu09Ncc5P4zawdsHXS/iqPq2uLV63jlanzK9e
PG9gFXv9bvMBeF0HztgWITESkeOXScv82YcCXQ939uVQFzOwQYBfg8XwEJ4Uzee4Kpi/8DoC2zWCbeQn/5ANOKlBUIiIiWYt1EZyVoczXSWUziVW8+6bZ3wNoluZ8lxMq0ocCs8zsPUJrfn9CH/4HCWMDJIqd42vSSxV/n2i5zN1X5HBcnRr98TdsKA8JCwN7bcY2c5+G5VFYrTrDkLMLFZqISNHKpXJ/E3AC8IiZXQz8x93XAURTsZwG/I2QWnZzdQMys5MI/eB2JYzw+jlwL/DvbEezTThXR8KosEcSvuzWApOAu9z9gQzHdYiO+xHhS7wJIXPhDeAGd/84qXwJIaXvUEJ/wB0IT/iXABOAO919dJprjSAMpJPOOndvkel91obEgfR+0Xs2Nju6p2nTFfrun+YoERGR3JjZa1WXqpTLaPmxlvbvMpRZFS2zaSIeS7jHGWRmg919fNL+xJHf2iXucPdFZnYA8C/CoHuHJ+z+Anjd3ZO7PlY3/ry8bzP7GaF/Pl27dmXcuHEZTpfZqlWrNjre3bn37XgyxW5tlrP25WuI3ex8ucWPmPNu8sfbsCV/RrIxfT6Z6fPJrDF9PllX7t39AzO7nDAq6+3ALWYWe/Lbk/C02oA/untV/dZSKuLpavoSMhsgVOg/IExV05fQX+4QM7sPOCv5egk+AT5OsX1Dim21qqy8YqP+9j+2hO6Quw6H0lyeCYmIiGQ0LIsyTrjHyGm0/Hxy9+lm9iBwKvC0mV1AmJquLaHC/lvCd3ZTwkB3lcxse+CZqOypwCuEQfQGERpG7jKzoe5+Vt28m6q5+53AnQCDBw/2YcOGVftc48aNI/H4T+csY86L4bapRdMSfrv1LFrMWBx2tu7CNsf/hW2atUpxpoYr+TOSjenzyUyfT2aN6fPJqZbm7tea2eeEluZdgW0Sdn8K/Mndn6pOIEU+XY0THiT8DXjZ3csT3td+hOlqziC0/N+bJt7R7j4iy/dWq976chGLVoWxf/q22cDmc1+N7+yvlHwREcmrdOlgJYTv8cOAYwhzvb+Qpmwqsdbp1hnKxFq5sxkwGELrfFtCRuCTSfseJzR0HEl40A9UTj/3BOGeaS93fzfhmNfM7CDCoHZnmtkD7j62hvHXxvvOq5Hj42P6/HinjrR4L2Emv31+DY2sYi8iki+5DKgHgLs/5e4DgC0Iqeh7AFu4+4DqVuwjRTtdjbtPd/cD3f2FxIp9tO91wgMDCC3/9V5iSv7F3adg5dEgv912ha4FnTVHREQaGHd/Pc3PWHe/z92PAy4gzOW+qorTJZoZLXtnKNMzqWxVsX7n7kcBQ4GrCA0E1xPmoD8e6BoVnZRw2B7AjsBXSRX72DmXEBodAL6fIv5eOcYfG2Ngs2hwvWyPqxNrN5Tz9Mfx+4yft3kTVn4bVtpuAYPOqOuQREQajGrnV0eV7vlVFsxCsU9Xk4WPomWPHI4piDVlzotT5lWu77c2odVeA+mJiEgBuPttZvYLYARhPJxsxL57dzKzlmlGzB+SVDbbeN4FNqqom1lbQhe+MkLGYUyscr48wymXRcuOCds+JzQodDSzrdOMmL97tKyM392Xm9l0woj5QwjdHKs8rq68OGUeK9aWAbBth1L6TL0jvnOf30DTlnUdkohIg5Fzy30tKfbpaqoSa+X/NkOZ3czsejO708yuM7Ojojlu69TE+WWs3RC6Cg7rvILW8yeEHSVNYJfj6jocERGRmEmEFvOsuPtsYCIhVX6TL7Co21wPQnfATVrUq+E8oCUwMmoAiYkNYrO9mW2W5tjvRcuvYhvcfT3xFv2Tkw+Ixhbak3A/MyZp99MZjmtH/AFJTTIuqyUxJf/Kbu9i30W3U+26w26n1XU4IiINSs4t92bWgpC6vh1hNNjktHYIo9lencNpi326mrSiBwS/iFafyFD0R2zaGjHHzE6JUvvrxDvflFW+vqDTh/EEyG1/AK0711UYIiIiyboRKs+5uJaQEXi9mb3j7l8
CmFkX4LaozHWJg/VGA+VdAHzg7hvVNs2sH7AgMSswevh/FqHr3hLgN0kxvEuo4G8J3G1mZ8amqIu6Gv6eULkvY9P7hOsIYw5damYvuPsH0XFtgHsIjTS3ufuypOP+SejSeLqZjXb3Z6LjmgB3EO7fRrv7Zxk+u7ybs3Q1b08P7S6tbS17zkuYuGif30CT5nUZjohIg5NT5T4a9O52Nk4b26QYIXU9l8p9sU9Xk8lthAcInxGNOptkOmG8gecJT+ybAbsQBu7bD3jOzPZ090/TXSBf09UsXVvBZ4vLAcOoYPtvnq7cN7nJrixqJFNIZNKYptKoDn0+menzyUyfj6RjZicQWu0/yeU4dx9lZv8mfO9PMrNXiM/E0w4YTZiRJ1FnoB+hRT/ZicDvzWwCMJswZe9gQur9fOAQd98oS8/d15vZGYTW9KOB/czsQ0LK/QDCPUIF8Mvk1Ht3/9DMLiP0638nmjZwGeH+oAuhW+EfUrzv2Wb2E+ABYLSZvUV4wPA9QkPFl8A56T632vLEhLnE5gz6Y9e3KV0WjZDfvicMPLWuwxERaXCyrtyb2R7Ao4QvoEeAnQmV0OsII8AeRJhq7m5gTprT1In6Ml2NmV0RXW85MNzdN+nP7+4PbHJgeDgx1sxGEUYI/gsbP2RIPkdepqu5640ZOFMBOLj1l7TZEJ6uL/U27HzUr/VEncY1lUZ16PPJTJ9PZvp8GiczuyfD7jbA9sBO0frNuZ7f3c+LKrfnEyrFpYT+7PcA/852it3Ia4T7n0GEjL5yYAZhJpx/uHvKfvXu/rKZ9Qd+DRxAmP6vhPBA4FHgJnd/L82xfzWzTwkZAUOAFtE1bwb+nureIjruETObQWhA2IswsN9swr3Mn9PFWlsqKpyRE2YD0IbVHPXdqPjOfX8LTeq8J6KISIOTS8v9xYQvoiPdfYyZ3Qvs4u5/ADCzzoQvt0OB3XKMo9inq9mEmf2aMJLuKsKT/ClZxp3oKkLl/iAza5ow9V6teHz87MrX31//Wrj9AZ4p35PTVbEXEZHacUYWZVYCV7n7fdW5gLs/DDycZdkRhIH7Uu17gzCtbXVimMbGmYO5HPsCuU0DGDvufcK9TsG9N2Mxc5aGYZXOafEKzTZEzxY2660Be0VE8iSXyv1QYLK7Jw/aAlSms59ESC3/E/DzHM49M1rmdboa4KhoWryDCVP3LQFedPexZhYbcT/VdDUz0k1XY2bPE25Evs/Go+FWMrMLgRsILf6HpzpXlj6Pls0IaYKZBuSrtgUr1nLVfz9j2oLwjKUVazm4JD6BwBPl+3J6bVxYREQEzsywbz0wF/gww4C7UgRiDQhtWc1PS8eEnAeA/S6F0qaFC0xEpAHJpXLfGXg7Yb0MIHF6GXdfaWZvAIfkGEexT1eTeO7zCalya4EjajgYXqeE17nM7ZuVBSvWcvOr0xg5YQ4byuNZiT8s+ZDWFrL8plV051NPNy6hiIhIzbj7/YWOQWrXdxuc5yeHIQzOKn2eluVREmbHvrDr8QWMTESkYcllKrylQGJu9rJomTx3uxMGeclasU9XkxDnzwkD86wjdF94pYZxDo+WX+Q4iF9WLnjkIx764GvWlVVQ4fHtx5TGMw6fLN+H1BMiiIiIiGS2YMVarnp3DevKKmjHKn7W9Pn4zv0ug9KcJ24SEZE0cqnczybesg0wmVDrqxzozcxaA3sTUuhydW20vN7Mtkk4Z8bpaszsczP7T/LJzKyfmXVI2mbR6LFVTVfTkjBdTbuEY0vM7HLSTFdjZmdHca4DjnL3F6t6w2bWy8xOMrPmSdvNzE4l/pn8o6pzVcetJw3k5D1606JJCaUloQK/BYsZWhJmxqlw46nyvWrj0iIiIgCYWU8zOy2aZi5dmX5RmeQGBannbn51GvNXhxaEnzZ5jtasDjs6bQu7HFvAyEREGp5cHpeOAy4ys83dfSHwLLAauNbMuhFGyD+NkL6fPIBdlYp5uhozG0CYN9YILfrHm1mqPLN
F7n5xwnpH4CHgdjObSHiw0JYwKvBWUZlb3f2OFOeqsS5tW3DNkTvziwO34eZXv+Tt8R/xUMmVlFj4En6rYmfmbdQzQEREJO9+QRhFfscqyt1HmBLud7UdkOTHghVreXx8mEBpM1ZyVmnCmIDDLoOS0gJFJiLSMOVSuR9JqOAOBF5y98Vm9htCa3WswmqEivQV1QmmiKer2Yx47vr20U8qs4h/VhCfkmYIYYT+3aNrzQMeA+5099eyerc1EKvkr115Fc2nV04ewBPl+9T2pUVERH4ATHH3L9IVcPcvzGwK8ENUuS8aN786jbKKcOv2syZjaGNrw47Nd4CdjipgZCIiDVPWlXt3/4Awl33itjuilvFjCK3QnwP3uvuy6gZUjNPVuPs4qtEx3d0XA5fkelyt+Pp9WswcV7m6zprzeskeNDVjQ7mnP05ERKRmehKyA6vyJaCnzvXd7XtDj91ZtNtFjJwwhwqHjqzg9NKE3opqtRcRqRU1HsXE3ccD4/MQixRKRQU8cyGUr6vc1Lx5C1761Q+4+bUZTJy1tIDBiYhIA9eCMOVdVdYDrWs5FqmpeZNg4Re0n/AAV9i+3MRR/KTJ85Wz8HzbYhu22OGIAgcpItIwaYhSgU8fg+WzN95Wto4uXz3DNUeeUJiYRESksZhL6EJXld1IPcaO1Dfl62kKHFcyjmObv06Tyknt4ZrvfsyV362nS9sWhYtPRKSBymW0fGmI1q2CFy6DDas33l62Nmxf/11h4hIRkcZiLNA3GtA2JTM7HdiaMJ6OFInmVkYL20ATC/3uF3sbPqroy82vflngyEREGiZV7hu7N/4OZetS79uwNuwXERGpPTcSZse508z+bGZ9YzvMbCsz+zNwZ1TmxgLFKHnQjjW81uSX7DhxBIu+nVXocEREGhxV7hu7CfdA2ZrU+8rWwPi76zYeERFpVNz9c+Bn0eplwDQzW2dm6wiD6F1GuF85x92nFChMyYOmVk4L28AJ9gorHzyt0OGIiDQ4qtw3doPOgiYtU+9r0hIG/6Ru4xERkUbH3f8DDAX+C6wBmkY/a6JtQ939voIFKHmxzpuwxpvxYPmBXN7kN4UOR0SkwdGAeo3dvhfDhHtTt943bRH2i4iI1LJo9p0jzawE6Aw4sNjdKwobmdRUhTWhpLQJzQecDPtdymltu6J2exGR/FPLfWPXvA0cfB00TZpdqGnrsL2ZZh0SEZG64+4V7r7A3ReqYl/kSptBkxZ8u8VBcNGncPiN0LZroaMSEWmwVLkX2PV4aN9j422b9YRdhhcmHhERaTTMrIOZ7WtmW2Yo0z0qs1kdhibVFVXqGXgqXPQp07b7uSr1IiJ1QJV7gZISOOIWaBr1vW/SMqyX6NdDRERq3UWE6fC2yFCmW1TmgjqJSKqv2y6VlXq11IuI1C3V3iTotQf0PQDHYOsDoOfuhY5IREQah8OAL919QroC0b7pwOF1FpVUz8/fUqVeRKRAVLmXuEOuZ2XbbeGQ6wsdiYiINB59gP9lUe4LYKvaDUVERKR4qXIvcZv1ZOKgv4X+9iIiInWjLbAyi3Irgfa1HIuIiEjRUuVeRERECmkesHMW5XYCFtVyLCIiIkVLlXsREREppLeBnczs0HQFzOwQYBfgrTqLSkREpMioci8iIiKFdFO0fMTMzjaz5rEdZtbczM4GHgEcuLkQAYqIiBQDVe5FRESkYNz9A+ByQt/724HlZvY/M/sfsCza1g640t3fKVigIiIi9Zwq9yIiIlJQ7n4tcAwwCWgGbBP9NI+2HePu1xQuQhERkfqvSaEDEBEREXH3p4CnzKwr0JuQhv+1u88vbGQiIiLFwdy90DFIHpnZQmBWDU7RGY1GnIk+n8z0+WSmzyezhvj59Hb3zQsdhEhN6N6iTugzykyfT2b6fDJraJ9P2nsLVe5lI2Y23t0HFzqO+kqfT2b6fDLT55OZPh+Rhkl/21XTZ5SZPp/M9Plk1pg+H6Xli4iISMGZWQtgf2A7wgB
6lqKYu/vVdRqYiIhIkVDlXkRERArKzI4hjIrfMVMxQj98Ve5FRERSUOVekt1Z6ADqOX0+menzyUyfT2b6fBohM9sDeBSoIMxnvzOwC3AdYcT8g4D2wN3AnAKFKTWjv+2q6TPKTJ9PZvp8Mms0n4/63IuIiEjBmNlI4GjgCHcfY2b3Aqe5e2m0vzNwL7AbsJtGzxcREUlN89yLiIhIIQ0FJrv7mFQ73X0RcBJhzvs/1WVgIiIixUSVexERESmkzsAXCetlAGbWMrbB3VcCbwCH1G1oIiIixUOVe8HMTjKzN81suZmtMrPxZna+mTXq3w8za2pmB5rZDdFnssLM1pvZXDMbZWbDCh1jfWNmfzEzj34uLnQ89YWZtTSzS8zsQzNbZmarzewrMxtpZnsVOr5CMrMeZnaLmX1hZmvMbK2ZTTOz282sb6HjkzqxlNAqH7MsWvZIKudAl7oISPJD9xeb0r1F9ej+YlO6t8issd5fNNr/XCUws38BDwGDgTeBlwnTEN0KjGrMX8DAfsArwK+B7oRWo6eAJcAxwFgzu6pw4dUvZjYEuIRwAy4RM9sK+BS4nvB7NBYYAywEjiRM/dUomdlAYBJwAdAKeBF4AWgJnAN8YmZDCxeh1JHZQK+E9cmEkfEPj20ws9bA3sDcug1Nqkv3F2np3iJHur/YlO4tMmvM9xcaLb8Ri6YeOg+YB+zr7tOi7V0J/0kcBVwI3FSwIAurAngCuMnd30zcYWbHE25arjCzse4+thAB1hdm1hy4H5gPfED4Ymn0ogrJy0Bf4DLg7+5enrC/E9CpQOHVB/8CNgPuAs539w0QWrYI06KdBfwb6F+oAKVOjAMuMrPN3X0h8CywGrjWzLoRRsg/jZC+/2TBopSs6f4iI91b5ED3F5vSvUVWGu39RWN9airB76LlpbEvXoBoJOJzo9XLGuvTdXd/zd2PTf7yjfY9BtwXrZ5Sp4HVT1cBOwA/B5YXOJb65HJga+Bf7n594pcvgLsvdvf/FSa0wjKzFsCe0eqVsS9egOj15dHqrmbWqq7jkzo1EngdGAjh7wL4DdAUuBj4JzCIUMm/ojAhSo50f5GG7i1ypvuLTeneIoPGfn/R6P5TlcDMehBultYTbqw24u6vE9IfuwHfq9voisZH0TK5X2ijEs1R/RvgYXf/b6HjqS/MrBlwdrR6YyFjqafKiQZOq8J3wJpajkUKyN0/cPeD3P2lhG13AHsAfwX+j/B/TP9o5Hypx3R/UWO6t4jo/mJTurfISqO+v1BafuM1MFpOcfd0v9gfEvrxDATeqZOoisu20fLbgkZRQNHT0fsJfQUvKnA49c0gQlrcXHf/ysx2I6SidiGkF77k7m8VMsBCcvcNZvYq8EPgT2aWnDZ3dVT0bndXP8tGyN3HA+MLHYfkTPcXNdPo7y1A9xcZ6N6iCo39/kKV+8Zrq2g5K0OZr5PKSiTqB3pGtPpEAUMptD8D/YAT1KK2iV2i5Vwz+zuh9SHRFWY2GjjF3b+r08jqj/MIA9ycDRxiZrGK3BCgAyEd+5LChCYi1aT7i2rSvcVGdH+Rmu4tstNo7y+Ult94tYmWmf7wV0XLtrUcS1ExsybAg0B74NXGmioWjTL6S2B01E9QNtYxWg4kfPn+E9iG8KXyY0Ja6pHAbQWIrV5w9xnAUOB5QgrqkdFPd+Az4M3EvnIiUhR0f1ENureI0/1FRrq3yEJjvr9Q5V4kd7cDBxKmb2qUA96YWUvCoD8rCE9HZVOx/1+bAg+6+6/cfbq7L3P3ZwhfMg6camZbFyrIQopu4CYTbkx+DGwe/RxJuFF5wsz+WLAARUTqTqO/twDdX2RB9xZZaMz3F6rcN16xp+atM5SJPX1fWcuxFA0zuwn4CWF6nwPdfV6BQyqUvxD6Bf7a3Rt1v8AMEv9u7kreGfUnnkCYz3u/ugqqvjCzzYDRhJa7g939GXdfFP08DRxMGOjmCjPbNv2ZRKSe0f1FjnRvsRHdX2Sme4sqNPb
7C1XuG6+Z0bJ3hjI9k8o2amZ2A/ALYCHhy3daFYc0ZEcR5uo93czGJf4Q/tMEODfa9n8Fi7KwvkrzOlWZbrUcS310GOEp+ntR+txG3P1L4H3C2DDD6jY0EamBmdFS9xdZ0L3FJnR/kZnuLarWqO8vNKBe4xWbamUnM2uZZkTbIUllGy0z+yvwa2Ax8H13/6zAIdUHJWR+Ktw3+tmsTqKpfxL/bjoRUi2TdY6Wq1Lsa+h6RctM8xYvi5YdM5QRkfpF9xdZ0r1FWrq/SE/3FlVr1PcXarlvpNx9NjARaAYcl7zfzPYjDEAxD3i3bqOrX8zsOuC3wFLgIHf/tMAhFZy793F3S/VDmLoG4LfRtgEFDLVg3H0u4ckwhH6UGzGzDsBu0WpjnO7rm2g5KJqaZiPRtkHRarrWCRGpZ3R/kR3dW6Sm+4vMdG+RlUZ9f6HKfeN2bbS83sy2iW00sy7ER9m8zt0r6jyyesLMrgEuJTzhO8jdG3Urg+Tsz9Hy92Y2OLYxmr/334RRkSfQOG9wnwdWE56w/8PMmsd2RK9vJqTuLgVeLEiEIlJdur/IQPcWUkO6t8isUd9fmLsXOgYpIDO7DTgXWAu8AmwgPAlsRxiM4lh3Ly9YgAVkZkcAT0er44EpaYp+7u7X1U1U9Z+Z3QecTniy/vcCh1NwCfPQbgDeI6Rf7g5sSZiyZv/G2sfSzE4H7gZKCU/aJ0a7BgFbAOsIcxyPLkiAIlJtur9ITfcW1af7izjdW2TWmO8v1Oe+kXP388zsLeB8Qv+mUuBz4B7g3431qXoksR/O4OgnldcBfQFLSu5+sZm9A1xAmJe2FfA1cCOh5WphIeMrJHe/38wmEeYz3gc4KNo1l/ClfKP6oIoUJ91fpKV7C6kx3Vtk1pjvL9RyLyIiIiIiIlLk1OdeREREREREpMipci8iIiIiIiJS5FS5FxERERERESlyqtyLiIiIiIiIFDlV7kVERERERESKnCr3IiIiIiIiIkVOlXsRERERERGRIqfKvYg0KGY208zczPoUOhYRERERkbqiyr2IiIiIiIhIkVPlXkRERERERKTIqXIvIiIiIiIiUuRUuRcRzKy1mV1iZh+a2QozW2NmU8xshJm1SSo7IurTPsLMtjKzB81svpmtjY75jZk1SXMdM7NTzWycmS2NjpluZv8ys55VxHexmb1rZsui+GaY2UgzOzTDcQeZ2atmttzMVpvZe2Z2RJqyW5rZrWb2ZRTXajP72sxeMLOfZftZioiIiIgUgrl7oWMQkQIysx7Ai8COwELgI2AtMATYAvgUGObuS6PyI4Argf8Ah0dl3wLaAcOAFsBo4Bh3r0i4jgEPAicBG4BxwBJgd2Cr6PXB7v5hUny9o/j6Aauiay0HegL9gfHuPiyh/EygN3AN8AfgQ2BGdPxAwIHh7j4q4ZgtgIlAN2AW8DGwDugO7AzMc/fts/1MRURERETqWsrWNRFpHKIK9+OEiv2twCXuviba1xK4EzgF+AdwRtLhpwFPAKe4+9romG2BscCRwM+B2xLKn0uo2M8HDnT3KdExpdH5LwRGmlk/d18X7SsBniJUzJ8Gzow9ZIj2tyU8HEjlEuBQd38hofzlwNXAtcCohLJnEyr2dwDnesJTTzNrDuyR5hoiIiIiIvWC0vJFGreDgT2B94CLYhV7gOj1z4EFwMlm1iHp2NXAebGKfXTMNOCKaPVXSeV/Ey2viFXso2PKgYuBrwkt7scmHHMEobV9JnBiYsU+Onalu7+a5r3dklixj/yV0Oq/jZn1StjeNVq+4EnpTO6+zt3fSHMNEREREZF6QZV7kcYt1l/9icQU+hh3/w4YT8jyGZK0+2V3X5DinA8DFYQKdHeoTP3vG21/IMV11gMPRavDEnYdHC0fSnzwkKVn01xnRrS6ZcKuD6Ll9WZ2pJm1zvFaIiIiIiIFpcq9SOPWN1r+LRokb5Mf4g8ANk869qtUJ4xS6r+NVntEy+7R8tvElv4kM5LKQmj
JB/i8qjeSwtdptq+Ili0Stj1AeCixHaEbwHIz+9jMbjazodW4toiIiIhInVKfe5HGrTRavk5Ifc9kVh6ul+sInjUZ8XOTTIS0FwlZCyeb2bWEQQL3in4uBC40s3vc/Sc1iEVEREREpFapci/SuM2OliPd/V85Htsn1UYza0YYZR9gbtJySzNrHhswL0nfpLIQb33vl2Ns1eLuk4HJUDmY36GEFv2zzOwxd3+pLuIQEREREcmV0vJFGrfno+Vx1Tj2B2bWOcX2Ewn/t0x39zkA0XJGtP2U5APMrClwcrQ6LmHXi9HyFDNrQR1y9wp3f5YwSj+EafdEREREROolVe5FGrfRwARgPzO73cw6Jhcws25mdnaKY1sB/4qmiouV3Zow1RzATUnlb4yWV5vZ9gnHlBJGse9FSP1PnKLuacKc832Ah8ysfVJsbc3swCreY5XM7DQz2y3F9k6E2QQgP90SRERERERqhdLyRRoxd68wsyOB54BzgJPM7BNCun4LwgBzOxKmw7sr6fAHgMOA6Wb2NtAW2D867r9Acpr/bYR+7CcCn5jZOGAJYZ76vsBS4LjElP0ovqOBl4CjgYPM7C3CdHY9gQGE0fzTTYeXraOB+81sLuFhwjKgE7AP0Bp4kzDQnoiIiIhIvaTKvUgj5+5zzGx34CfAcGAXYA9gMaH/+w2krtjOIEyP9xfgAKB9tO0e4J/JU+u5u5vZyYSuAGdH12gJfAP8G7jW3WeTxN2/ilrVLwSOIVS4S4F5hOnu7q3J+4/cQBhQcCgwGOgALAImAvcRpuLbkIfriIiIiIjUCnOvyWDUItLYmNkI4ErgT+4+orDRiIiIiIgIqM+9iIiIiIiISNFT5V5ERERERESkyKlyLyIiIiIiIlLk1OdeREREREREpMip5V5EGh0zu8/M3MzOKHQsUmAj2vdiRPv3GNG+V6FDEREREakJVe5FpCDMbGZUwe5T6FiKnZmdEX2W9xU6liJ0M2FKx5sKHUgm9eHvxcyGRTGMK1QMVdGDOxERacxUuRdpwPpcNmaLPpeNua3PZWM+KnQs9czvgB2ApwodiBTQiPZDgYMI34U/iNZFREREilKTQgcgIvnX57IxWwBXAGcCBjQvbET1i7t/C3xb6DikgEa0LwHuAlpGW1oCdzKi/a6MWF5RuMBEREREqkct9yINSKylHpgBnAW0oJ5V7GMp5EDvaNNXURqtJ6YdJ6aam1knM7vZzL4ys/VmNjrhfMeY2T1mNsXMlpnZWjP70sz+ZWY908SQMnXXzEZE20eYWVczu8PM5pjZuuja15lZixzfb6mZ/dzM3jGz5VH8881sopndYGabpzimtZldYmYfmtkKM1sTvb8RZtYmqexM4N5o9fSkz/K+XGJtZE4GehEefhEtewMnFSyiFLL9e0kov4OZ3R39vq41s6Vm9oqZHZHm/Fua2a3R38xaM1ttZl+b2Qtm9rOEcuOAsdHqfkkxjMvyvWxmZn+JfpdXR9ebY2bjzOx3aY7paWY3mdkX0d/BCjN7O/pcLKFcn+hzOj3adG9SjGdkE6OIiEgxU8u9SANQZC31XwL3A8cCrYEngFUJ+1clle8MfAi0B94ExgOLE/Y/BqwFPgNeIbz3AcB5wHAz28vd/5djjD2BCYTP8h2gHbA3cCmwI5CyopTG3YQKxxrgLWBR9J62Bn4NjAQWxgqbWQ/gxeg6C4F3o/c3BLgSOMrMhrn70uiQUcD3gL2A6dE1YhJfS8yI9m0IfezbJO0J20e0f4oRy7+r+8BSyvrvxcxOiMo2A6YAzwKbA/sAB5rZ1e7+x4TyWxB+z7sBs4AXgHVAd8LvVB/gzqj4C4Tfwx8C86P1mM+rehNm1gp4m/B7vYDwt/odsEW07XvAtUnH7E/oOtM++hxeIPwbfY/wQOsA4LSEz+F+wt/p1tG1vkw4XeJrERGRBklT4YkUQJ/LxhTtH97M6w6zqktVLWpx7g1s5e4zU+w/g3iL9EvAse6+MkW54cC
z7r46YVsTQkX4cuAFdz8k6Zj7CBXuM939voTtI6LjAP4PON/d10f7dgA+IFQu9nb3t7N4j72BmcBsYIi7z0/aPwD4xt0XROtGqJTsCdwKXOLua6J9IW0cTgHud/czEs5zBuGz2mh7gzKifdH+zTBieY3/ZrL4e9mV8BBsPTDc3Z9P2LcT8DzhodUB7j422v5H4E/AHcC5nnBDYGbNgT3c/Y2EbcMIrfevu/uwHOM/jVD5HgMc6e5lCftKgf3c/bWEbVsQHlC0A34C/CcWX5SR8wzhIV7y3/B9pPjbFhERaQyUli8i9d0G4JxUFXsAd388sWIfbStz9yuAb4AfmFnbHK85G/hFrGIfnXMq8EC0emCW5+kSLScmV+yjc34cq9hHDiZU7N8DLopV7KOya4CfE1o9TzazDtm+GWkU/kBosb8ksWIP4O5TCFkiABck7OoaLV9IrNhHx6xLrNjnQexaryRW7KNrlSdW7CO/BDoAN7j7/Ynxufts4Oxo9cI8xigiIlLUlJYvIvXdxFQtlYnMbDtCxXgbQst67MFlk+j1NkAuMwa8llixThBLP94yy/N8DqwEDjOz3wMPufusDOUPjZZPuPsmg7q5+3dmNj4qN4SQ0SCNnJmVEH7/ndBNI5XXo+WeCds+IHRfuT7qvv6yu9dWd4QPo+WlZraIkG2zLEP52N/CyDT7JxBS8QeYWQt3X5ufMEVERIqXKvciBZCv1PaYPpeN6Ua8z30JKfrc5/uadShtZThKv78N+CnxgdFSaZfjNb9Os31FtMxqUD13X2lmZwH3AH8G/mxmcwn96McAjyZVSvpGy7+Z2d+qOP0mA/E1aHlIbY+fq30bwr9xquyHJUCvetTnPhudiP+OL0gYZy6VxN+bB4AfEAYRfAooN7PJwBuE38138hWgu48zs78CF0fXdTP7nDAuxBPu/mLSIbG/hQ+reD8Q3v/cfMUqIiJSrFS5F2kAZl532Dzg/D6XjbmaKir5RShVC3rMRYT03G8IacfvAAvcfR2Amb1DaKnMtWKYt6nQ3H2Umb0C/BjYlzDw3bHRzwgz2ydKMwYojZavE/rqZ5IpA0AyGbF8FSPaX0R4MJQ4qN4q4KIiq9hD/PemHHgw24Oi7JCTzexa4HDC7+ZehFT3C83sHnf/Sb6CdPdLzex2wt/C3tG1zgbONrOXgMMSUvZj7yk2YGYm6/IVo4iISDFT5V6kAWnglfxUjouW57j7syn2b1OXwaQTpR/fH/1gZlsT5ljfH7ie+PRrsUr+SHf/Vx2H2dg8BFwG7EB4+OOEByYPFzKoalpEeAjWErjA3ZNnnMjI3ScDk6Eyxf9Qwudwlpk95u556/7h7l8B/4x+MLO9gUcIGQRnER+dfzbh7/fqaMwAERERqYIG1BNpgGZed9i8mdcddj4htfVu4OPCRpRSbLC6mjxk7BgtZyfvMLODqKep6+4+nZCmD9A/YVdsILTjyE0+PsvGZcTyCkKrcSwzZA1wdrS9Pkr7bxy1dr8SrR5bk4u4e0X0oOzpaFPi72fef8/c/S3gvhTX0t+CiIhIjlS5F2nAYpX8mdcdNrDQsaQQ6yO7Qw3OERvg7tyoxRGobBm/vQbnzQszG2hmx0fT2CX7UbRMTK8fTRgobD8zu93MOiYfZGbdzOzspM1VfpZmdoGZfW5m/8n+HTRwI5a/QxiUsAJ4iRHL3y1wRJlU9W98FWFmiZvM7ARL6qhuwe5m9oOEbaeZ2W7JJzKzTsQH3kv8/YzFsE003kXWzOwoM9s38e802t4S+H6Ka/2NMMbF783s/FTXM7OdzOzopM3Z/C38J/pbuCBdGRERkWKkJ9siUihPAcOAh6L+tsui7Ze6++Isz3EtYZTwc4D9zewjQmv+foRB6+YBQ/MYc656A48Cq81sIiHDoBkwkJBVsRL4Y6ywu1eY2ZHAc4T3dJKZfRId1wLYDtiRMB3eXQnXeY/wXneLRtOfQqjove3u90ZlOgP9onISdxGwRbSszzL
+vbj7+Ggu+XsIae7XmdlnhAECNyfMCd+F0A0klmZ/NHB/NMjjx9E5OwH7AK2BN6PrAuDus6K/sYHAp2Y2gdDf/Qt3r2oAyP0In/HC6BwLgfaEv8+OhAd1dyRca3b0tzAKuBX4g5lNIfzubwbsAvQk9Ml/MuE6TxP+pn5pZjsDcwhdLu5JGCCwF+FvoXMVMYuIiBQVVe5FpFBuJYzwfTJhMK/YuADXAFlV7t39XTMbQkhxH0wYqOuraP16IHkE7rr2HvA7QsVme2AQIW14NnADcEvy1HjuPsfMdgd+AgwnVGL2IHwmc6Pjnko6Zp2ZHUx433sSKl8lhP/j70XSG7H8a+B7hQ4jC1X+vbj7o2b2IfAL4CDC7x2EBzofE2ZoSJwq7wbCwI1DCX8/HQj99ycSUuUfcvcNSXEcTfjb2g84kTDw3euElvZM7iMMjLc3sDOhYr0M+JLwMOJud1+ZeIC7jzWznQgD/B1G+HdqGr2fGYQBEUcmHfOxmR1PGJV/KPEBE98iDLgpIiLSYJm7FzoGEREREREREakB9bkXERERERERKXKq3IuIiIiIiIgUOVXuRURERERERIqcKvciIiIiIiIiRU6VexEREREREZEip8q9iIiIiIiISJFT5V5ERERERESkyKlyLyIiIiIiIlLkVLkXERERERERKXKq3IuIiIiIiIgUOVXuRURERERERIqcKvciIiIiIiIiRU6VexEREREREZEip8q9iIiIiIiISJFT5V5ERERERESkyKlyLyIiIiIiIlLkVLkXERERERERKXKq3IuIiIiIiIgUuf8HzsP3Oo94y7EAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))\n", "\n", "plt.suptitle(\"Adversarial training on \" + f\"{FLAGS.dataset}\".upper())\n", "axes[0].plot(accuracy_train, lw=3, label=\"train set.\", marker='<', markersize=10)\n", "axes[0].plot(accuracy_test, lw=3, label=\"test set.\", marker='d', markersize=10)\n", "axes[0].grid()\n", "axes[0].set_ylabel('accuracy on clean images')\n", "\n", "axes[1].plot(\n", " adversarial_accuracy_train,\n", " lw=3,\n", " label=\"adversarial accuracy on train set.\", marker='^', markersize=10)\n", "axes[1].plot(\n", " adversarial_accuracy_test,\n", " lw=3,\n", " label=\"adversarial accuracy on test set.\", marker='>', markersize=10)\n", "axes[1].grid()\n", "axes[0].legend(frameon=False, ncol=2, loc='upper center', bbox_to_anchor=(0.8, -0.1))\n", "axes[0].set_xlabel('epochs')\n", "axes[1].set_xlabel('epochs')\n", "axes[1].set_ylabel('accuracy on adversarial images')\n", "plt.subplots_adjust(wspace=0.5)\n", "\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Find a test-set image that is correctly classified but whose adversarial perturbation is misclassified." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "def find_adversarial_imgs():\n", " for _ in range(iter_per_epoch_test):\n", " images, labels = next(test_ds)\n", " images = images.astype(jnp.float32) / 255\n", " logits = net.apply({\"params\": params}, images)\n", " labels_clean = jnp.argmax(logits, axis=-1)\n", "\n", " adversarial_images = pgd_attack(\n", " images, labels, params, epsilon=FLAGS.epsilon)\n", " labels_adversarial = jnp.argmax(net.apply({\"params\": params}, adversarial_images), axis=-1)\n", " idx_misclassified = jnp.where(labels_clean != labels_adversarial)[0]\n", " if len(idx_misclassified) == 0:\n", " continue\n", " else:\n", " for i in idx_misclassified:\n", " img_clean = images[i]\n", " prediction_clean = 
labels_clean[i]\n", " if prediction_clean != labels[i]:\n", " # the clean image predicts the wrong label, skip\n", " continue\n", " img_adversarial = adversarial_images[i]\n", " prediction_adversarial = labels_adversarial[i]\n", " # we found our image\n", " return img_clean, prediction_clean, img_adversarial, prediction_adversarial\n", "\n", " raise ValueError(\"No mismatch between clean and adversarial prediction found\")\n", "\n", "img_clean, prediction_clean, img_adversarial, prediction_adversarial = \\\n", " find_adversarial_imgs()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAABBsAAAFxCAYAAAAyFaBfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAABHPElEQVR4nO3debwcVZn/8e+TjX0LYQlwIYoQVkGIbCImaJBJEHdAQEFZHR0BZWTJgDgaFRBxUEQF+UUWlVVAkiFkhKgEFxIDODpsgUhYAkjYw5KQ5/fHOe2tdHo53ff0rXtvPu/Xq169PXXqdHX16eqnTp0ydxcAAAAAAEAug8quAAAAAAAAGFhINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINrTBzDYzs1PN7Ndm9qiZvWpmr5jZ383sRjP7vJmtX2O+UWbmcRrb+zUvn5mNLayDUWXXB0A6M7u88P29PmO5Z8Uy5+cqE7WZ2cy4rqd0oOwete9sBwAADCxDyq5Af2JmQyR9XdIJklatEbJ5nA6UdI6ZfdfdT+/FKgJAR5jZmpI+XHhqopkNd/dFZdUJAAAAfRc9GxKZ2WqSpko6RSHR8HdJX5Y0RtJISRtJ2kXSyZLukbSapNNKqSwA5PdRSWsUHg+TdEhJdQEAAEAfR8+GdN+TtF+8/1NJx7n761UxT0uaa2bfkfRJSRf2Yv36BXefKcnKrgeAln0q3j4s6XVJ28bnflBajdAydx9bdh3qcfezJJ1VcjUAAEAm9GxIEMdXOCo+nC7p0zUSDf/kwWWSdu987QCgs8ysS9LY+PDyOEnS7ma2dSmVAgAAQJ9GsiHNl+OtS/qcu3vKTO7+t3YWZmbrmdl/mNkfzexZM3vdzB4zs5+b2Z4N5jMz293MJpvZH8xskZktibd3mtkpZrZWg/mXG5wrDmh5kZnNj3V4ysyuNbN3tPO+YpkNBxArvHakmQ0ys+Pie3nezJ4zs9vN7P1V84wxs1+Y2QIze83MHjKzr8VTX+rVY1MzO97MfhUH+Xw9DvL5gJldbGY7JLyXoWb2JTObG+ddZGa/M7NPxteTBjszs/ea2c/iAKOvmdkLZnZXHIR0jUbzAr3kcHX/XlwRp0o7+MlmM5vZEDM7wcz+XPiu3GFmRzSZ70/xOzQ9YRmXxtgFZrbCb5uZrWZmJ5rZb8zsGTN7w8wWmtkNZjahQbnV7eJ2ZnaJmT0S243nq+L3MrMrzOxhC4MHL47f7d+b2bfM7J01lrGqmU00sx+Z2V/M7KVC/aaZ2aG13lNh/uUGfTSzfzGzm8zsCTNbamY31IutUdaOFn5/fhvX05L
Y/s4xs6+b2Qb16tFTjdrM6t8OM1sn1ue+uI6fiL8D21TN9zEzuy2+l8XxfXy6ST16vA7MbAMz+66ZzYvt+pNm9ksze1d8velAnfF78xkzuyVuC2/E+txqZoeZGb0EAQB9m7szNZgkrSlpqcKO9a97WNaoWI5LGlsnZpykZwtxtab/rDPvB5vM55IelPSWOvOfFWPmS3qPpOfrlPGapPe1uQ7GFsoZVeP1ymtHK4yRUWv5yyQdGeOPlLSkTtxtkgbXqcdzTdbTEoUeLPXex9qS/tBg/inF9VmnjFUl/axJPeZJ2qrs7wHTyj1J+lvcHu8sPHdbfO4RSdZg3jUk/bbBNn5Zve+KpH+Lzy+VtHGDZawq6YUYe3aN198e27VG37VLa7UXWr5d/KCkV6vme74Qe3KTZbikm2ss4/yE+aZKWqXO+59ZaHe+WWPeG2rF1ihnp4R6PCVp1zr1GFuIG9XGdlZzO6hR9t4KbWOt+j0X34dJ+nGD91Hvd7RH6yCWsYPCaZW15l2q8PtW93OIZWwu6e4m9ZgmaY2y2wcmJiYmJqZ6Ez0bmttD0uB4/7edXJCZ7SLpvyUNlzRXYfC1LeLjMQo7kpJ0hpkdXaOIpZJuknSMpHdJeoukEQo72idJekzS2yT9oklV1pF0vcIgmB9RGABzY0lHKCQgVpF0qYWrc3TK6ZLGS/qqpG0krS/pvZLuU9iJvMBCD4eLJf1a0j4K73Xr+JwUEjdHqba/SjozLmP7OO/bFK4kMkNhPJMfmdlOdeb/ibpPk5miMDjo+vH2pwrrqtkR38slfULSG5K+rfAZry+pK86/QNJbJf2KHg4oSzwSv218eHnhpcr9UQrfv3p+JOnd8f4VknZV+L6NUUi2fVLd40FU+4VCuzZY4btSzwcUEoCVZRTrv7mk2xXa0nkKbcKWCu3qjpLOU0hgflqhvalnPYX3PE9hsMyRkjZT+K7Kwukk34qxv5a0f1zmegpt8URJFygkk6u9oNBuHaSwXjaN5e8W6/eqpAmSvtagfpL0PkmnSvqVwjrfQKFd+16T+So81v0LCp/pVgpt0vYKf5Dvk7ShpOusQc+xXnCZQhLrKIX2ciOFz+ElSetK+r6kExXq/F8Kv4HrK7TZd8YyTjez7WqU3aN1ENvqXyms+9ckTVL4DDZQ+HzuUhjPact6b87M1lZI5u2kkNg4QeF3cLik0bHM1yT9i8L3CwCAvqnsbEdfnxT+uFeOIhzcw7JGFcoaW+P1e+Jrv5c0rE4Z34gxT0tarcXlj5S0KM6/b43XzyrUb65qHDFRSD5UYvZvYx2MLcw/qsbrxaM2K6xvhWTCMnX3PrhR0qAacb+rrMs2P6ufx/kvq/HanoU6fq/O/BcVYuY3WI/LJB1Yp4zN1H107OSyvgNMK/ek8EfVFQaFHF54fi1Ji+NrP6kz75jC9+DHdWIubfJduTm+NqdBHW+IMXfXeO0mdfcSWrfO/McW3uMmVa8V28X7Ja1Tp4xKL4yn6rXfPfgM9o9lvyxprRqvzyzU8Rdq3NOkEjuljXqsKemhOP9narzesH1PKL+yrmttB8WyX1SNHl8KyYBKzJJa7abCH/ZKL5hvdmAdnFaow0drvL6apD8XYlb4HBSSUq6QmKq5HiW9v1DGmJzbGxMTExMTU66Jng3NDS/cf75TCzGzcQpHX6TQff+NOqFfk/SKwlGS/erE1OTuT0r6n/hwfJPwU9z9lRrP36Du9bDCuccZzXL3q6qfdPcHFHbUpND74IvuvqzG/JV5d2mzB0blqO37arxWOQq7WOEIUy2nKhyNrOeEeHu1u99UK8DdH1M4QidJhzUoC+gIMxuq7stbTnP3RZXX3P0lhfZAkj5W50j3kfH2NYXLBtfy7wp/8uup9FTYpfp8/FjH4QpHeIuxldfeKumA+PBz7v58nWVcrHCVjWGSPt6gLme6+wt1Xqu0M880aL/b4u63SHpG4Wj+Xg1C31RoEz3n8gv1eFmh15vU/Dekky5
w9wdrPH+1wp9vSXpc0neqA+I2PCM+bHkQ54R1UPl9uMPdr6sx/6sKvw81xZ4RlR55k9x9fp16TFfosSPx+wAA6KNINvQdlT+1j0p6zMzWrDUpdCe+L8aOqS7EwqCFR5nZVAuDSr5aGFTL1b0j3WgE+dfVvROznPjHvrKTt3GL77EVjQaEm1e5dfd5TWKGafmE0T+Z2Z4WBnr7q5m9aGbLCutpagwbaSsOqlnZ2Z/p7i/WKjv+IflNneWurtA7QpJur/dZx8/7f2Pc281sWJ33CnTKBIVTHqTlT6GouCzerq0wnkG1vePtTHd/rtYC3P1Z1fmuRDcqdI+XwkCV1T6u8D1fpnBaRtF7FU67el3SXQ2+Z2so9CyTarSrlaoqnOZWz9x4u72FgSDXbxC7gjig4CQLg8xWBiUstt2VQQkbtd13u/sTrSy3Rj3MzD5uZtdZGARzcVU9/j2hHp1W8/chtsfPxIf/UycRLXX/PtT8DWt3HcTEVyUhVjOJXKmbwkGDWvaStHq8/9smvw/NtlkAAErVyXPuB4pFhfvrdnA5o+Pt5uresW5muRGxzWxjSbcqnIfczDoNXnvG3Zc0eH1xvF29QUxPNdphrvQYeDIhRgrdVpdjZucqDOaWYh0t/5mMirf3N5nvPoXuz9XeKmlovP/DODUzSCFpsjAhFsilcpT2OXUn4IpmKGyTG8fY6vFgRsXb+9TY/6lOTy13f9XMrlPoJXGomZ1RdeS+koC4vcYf7Uq7uoqkfzSpQ0W9Kw38o15yMdZzpoWrPnxIoRfHyWY2W9IdCuP9/LpObzGZ2bsVeonUTIxWadR2P5wwf13xD+xNCuPd9KQenZbr96HWb0NP1sEWhft1fx/cfZmZPShp5xovjy7c/2tCHaT62ywAAKWiZ0NzjxTur9CFN6N2dtxWrXp8uUKiYYnC6ObvVdjZH65wfvVa6j7y1yjR9Gbi8jt52a2UOrRVTzM7TN2Jht8oDDy3ncIOW2U9TSzMUr2uKoM11jsyVfFynefb3Umv/ryBjjGz9dR9CsIsSdua2c7FSaG9qQycu5+ZbVRVzJrxtt53QYmvV3pVvEVh8NtKHbcoPK7V8yJHu1qxuM7zRQcpHPV+WKEX2u6SvqTQO+NpM/teHPzvn8xsHYVu+cMVxmg5RWFg4k1i/Stt0oI4S6O2O6WOjZyv7j/Z/0+hZ8uWCgMkVupRGQSzzIMVuX4fav2G9WQdFAfy7c3fB34bAAB9Ej0bmvuDwk7LYDUecb2nKjsef3L3ls8jNbMt1X0qxr+5e80RqrmqgSTp+Hg7S2GgzBW62jY5ZeEVhW7jzdblmnWeL+5kTnD3Rl2zgbIconB6ghSSDgc0iJVCG3mYlj9P/mWFP0/1vgsVzV6fqXA1nc0UejLcEZ8/TOEP46vqPo++qPJde9rdqxMh2cUeYd+W9G0ze5vC6VL7KKy7jSV9XtIeZranuy+Ns31M4VSVZZLGufvfapVdnaTILf42VK6g8y13P61OXJlXoeioDOugmGDI8fuwehzjAQCAfomeDU3EwaBujQ/HxT/1nVDp/vpWM2unx8BOhfuNLm2ZcorFQFdZV9c0OKe30Xr6e7xtds7y6DrPz1f4YyE1uPwZULJ6l6NsZZ758bZZr7BtG70Yv6eVXlkfjwNXSt0D490YB6ysVmlXR8QeBL3G3R9y98vd/RiFyzNeEF8ao+UTN5X26N4GiYYudf60hdEKp5tIK+9vSE/Xwd8L9+v+PpjZIIXLYdZSPBWG3wcAQL9GsiHNufHWJF2Ymgyw2tfwrqeS0Bghad8W5qtYpXB/cJ367KEwXsDKrrKu6q2nQQqnVtQzK96OqzF4ZKWMtRUu1baCOHjkn+LDg5tVFuhtZraVQld+KRzhtUaTuk9L2snMin/EKj0QxprZunWWtb6k9yRUq3KlieGSJpjZOxROfyq+Vq1y1YFBCj0IShF7MZxVeKqYXGnYHkW9cbWBlN+QzdTZHn5l69E6iFe6qIx
PcmCD5bxX9Xs2/EbdV2fh9wEA0K+RbEjg7rcrnLsphWtbX2pmq9SLjyNZf1LhFIxUt6r7ygMX1Tj3uXoZo6rqUBxb4gM14teU9IMW6jOQVdZVvW7hp6nxkdbKueGrK1yKtJZvqMbgYwXnxdu9zeyLDeJkZoNjl2ygtxR7KFRf4aGWq9TdW6c475R4u6qkc+rMe66W/5NXk7v/Rd2j7x+u7j/gz6j+1QnuU/fAlt8ys4a9kcxswzhWRcvMbKuYqKyneJT62cL9Snu0Ta3vuYXLfZ7eTp1aNL9wv9ZvyBBJP9bAPv1yfuF+u+ug8vvwbjP7UI0yVpX0zXozx0FIL4kPv2Rm76oXG8tb28xGNooBAKAsJBvSfV7Sr+P9IyXdb2Ynm9k74g7qhvH+FxUugXaZwkBSSeLo6kconHu8laR7Yvk7mNl6sfydzexoM/uVpIeqyr9L3TutF5jZv5rZW+J8B0q6U6G7brMrKKwMro6348zsivi5rW9mu5jZxZK+rjA6fk3ufqe6zw8/wcx+YmY7mdnw+BldKulzajAyvLtfq+5uuueZ2S/NbIKZbWJm65rZFma2v5mdHcs5sWdvGUgTe25VrvDwv/FPfkPu/pik38WHh5rZ4Pj8bElXxuePMbPL4vdtePy+XSnp01o+WdpIpQfDAepONvyiMP5BLf+qMPDiCIXLX37VzHaN3/kRZra9mR1uZlcpXHq43a7rkyTNM7Nvmtl4M+uK3+UtzezT6m4zXpH0q8J81ykkaoZKmmpmB5rZxma2uZl9VmG9LtbyV0bKzt2fVPdneLqFy3BuHdfRexV+//5FDdrG/i7TOrhAYTuSpJ+Z2Wlm9ta4ve0by9hR0uMNypik8Fu9mqTbzOx8C5dq3iB+d0ZbuDTnpQpjmTRMSAAAUBp3Z0qcFHYGz5X0msI11xtNL0v6StX8owqvj62zjL0Udh6alb9U0npV845TSFbUin9T0kkKRxpd4br31cs+K742v8l6mBnjprSxDscW6jSqxuuV145sUEbd95CyHIWBu+Y0WLd3KIxA3qie6yicClGvjJ9K+mq8/1CdOg5TuOxls8/aJX2n7O2faeWYFE5pqGx3p7cw33GF+d5feH4NhStW1Nu2r2ih7dkktmXF+XdLqNs2kv6W+F3bqWre1LpNSSh7saQP1Zj3tAbzPK/QbX9+fHxWjflnKrFNbhQraXuFpEbddqjR+lCT9j2hbj0uu9F6SlxOj9ZBLOPtCj1uas3/pqRjFE6XcEmX1CljpMJvUco2e2Cn2wUmJiYmJqZ2Jno2tMDdl7j7vyv0PJiksNP2uELy4VWFoxk3KRzV7nL3r7axjDsVBpb6N4XzjZ9SuJTlqwpH/36lcJR7c3d/rmre2xXOs75W4ZrySxSuR36dwijn57dan4HIw7Xu91HowfCApDckPaeQPDhRYae24SXkPIy7sLfCZe7uVfh8nlfoQfIZdz9C3efk1hq4Tu7+hrsfL2lXST9SOFr2kkIiaZFCb5XvKZy6c3KtMoAOKJ4G8fMW5rtGoc1Zroz4fdtXIdk5V+G79YKk30s6yt0PVyJ3f0LdPcwk6QF3/1O9+MJ89yn8ATxCoQ19QuF7/7rCJSWnK/zh38rd76lXThOnKFzJYIqkuxXa7qUK3+m5Conqbdz9hhr1+6akDyv8AX1J4TdlnqSLJL3D3X9bPU8nuPtfFdqjKZKeVPg8n5J0i0KSpOEpXwNBjnXg7vcqJC0uUEh+vKHQu+Ymhd/ii9X89+FJSe9W2C6uUdi/eC2W9aTC/sd/StrZ3W9q+Y0CANALzN3LrgMwIJnZjQqDhN3s7iuc/wsAWPnEU5WeU+ghd7K7n9dkFgAA+iV6NgAdEAfkHBcfzimzLgCAPuXd6r6UKb8PAIABi2QD0AYzW9PMhjUIOUfdA3he3SAOADCAmNnwBq+trjDugxROh/hdvVgAAPo7kg1Ae3aQdJ+ZnVK4msVGZvY+M7tJ0mdj3M/d/W8l1hMA0Lu
ONbNZZvbpeDWLdePVRT4h6Y8KY0JIYRDLN0usJwAAHcWYDUAbzGwPhQHuGvmNpA/GwSQBACsBMztV0jebhJ2/Mgy4CQBYuZFsANpgZmtJOkzhShHbS9pQ0uqSnpX0Z4VR/H/m7stKqyQAoNeZ2ShJh0p6n6S3KPw+DJK0UNIsST9yd06fAAAMeCQbAAAAAABAVozZAAAAgAHDzKaYmZvZlMT4y2O8m9n1metSKffInOViRWZ2VlzX8/tT2b2hv9c/Vavf/b7Cgm3N7Agzu9DM7jKz1wvtx6gWyzvAzG42syfN7DUz+7uZ/cTMduzNMiRpSCvBAAAAwEARL1X94cJTE81suLsvKqtOAFY6W0jKMqC8mV0k6fiqpzeX9BlJh5nZse5+WafLqKBnAwAAAFZWH5W0RuHxMEmHlFQXAHhM0i/VxqWRzezL6k4S3KBw9aMNFcaY+19Jq0j6iZm9q5NlFJFs6KBC15fq6RUze8jMrjSzfcuuZwozm1mvW1LZXbPMbFRh3Y4tow65mdmuZvYDM7vfzF42sxfj/avMrDrTCKAO2uHeMZDaYTMbaWZfi91YnzezJWa2yMx+b2ZnmtkGZdcRWX0q3j4s6f+qnkM/4u5nubu5+6iy6wK06FlJH5I00t273P0jkm5rpYD423RGfHirpI+4+5/d/Rl3v1XSWElPKZzZcF6nyqhGsqEcq0vaUmG06l+b2cVmZiXXqc8pe+e5DGY2yMy+I+lPkj4raWuFIy5rxfsHSbqovBoCAwbtcIKVrR02s/cr/OH8D0ljJK2jsFO1nqQ9JH1V0n1mtk9plUQ2ZtalsPMsSZfHSZJ2N7OtS6kUgNKZ2Ufj2AlJ+wUxEf35dpfn7i+5+43uvrDdMiQdIWnNeP80r7oKhLs/K+mc+HB3M9ulQ2Ush2RD77hS4c9iZdpU0uGSFsTXj5Z0SjlVQx9zsaSTJJmkKQo7QRsrdF+q7OjOK6luQH9GO4yGzGxzSdcpJBheljRJ0o6SNpC0i6TzJS2VNFzSDfRwGBAOV/e+8BVxquxcfzKlADPbIfaQqgyiNt/MLjKzLRrMs16MdTM7LWEZ82LslXVeH2JmnzGzW8xsoZm9YWbPmNmtZnZYvT9M1j2Y3sz4eE8z+7mZLYg9eu6uit/LzK4ws4fN7FUzW2xh0Ljfm9m3zOydNZZhZra7mU02sz/EXkKV3kJ3mtkpFi4nXu+9J9exUXK0p/Xoa8xssJkdbmY3mtnjFgYTfMbM5prZj63NHmY92JZyf85vj9vaY/G9PR5jtmznfbXCzLZSuIT9v0r6fkL8JIX98wvM7D0drl4jH4i389z9z3Viri7cP7BDZSzP3Zk6NCn8YLmkKXVe31rS6zHmOUlDy65zg/cys9F76dAyz4rLnF/2+++l93tIfL/LJB1Sdn2YmAbCRDvc42WuNO2wpK8XtpeP1Yk5pRBzQtl1Zqr7WU5J+a4oDMjmku4sPHdbfO4RxUvEN5j/o5LeKGwTxWmRpHcWHh9ZNe918fn/bbKMPQtl/EuN1zeXdHedOlSmaZLWaLCeZir0plxaNd/dhdiTmyzDJd1cYxkfTJjvQUlvafJZptSxbnuVoR59pi1UGExwbpP38nyr9e/htpTzcz5Y0mt1ylgkaceefvcT1vF/Fpb5Xw3iTi3EXaombUaLdTirUPaohPiXYuwVTeIei3G/6kQZ1RM9G0rk7g8oZM4kaV2FATiwEjKzwZLOjQ8vdfdflFkfYGVBO4yCnePtYkn1Ln94eeH+Nh2tDToqHoXfNj4sfq6V+6Mk1T1dxsy2k/QzSUMlLVTofryJQq+pIxWSEFc1qMIV8XZ7M3tHg7jD4+3TkmZU1WFtheTITgrnUZ+gsF0OlzRaoXfOa5L+RdKPGixjG0kXSPqDwiBwGyn8oT09LmdrSd+Ksb+WtH98fT1Jb5E0Mc7/bI2yl0q6SdIxkt4V40dIertCT87HJL1NUrP9noZ1TJCrHqUys/Uk3a7QXr0p6Qc
KCakNFHrCvkfSZEmPt1huT7elXOv3bZJ+KumPkvZT6NnbJelEhQMD69VZflbufqaks+PDL5jZCuMTmNnJkr4ZH14h6WiP/8R7m5ltqu7THx5uEv5IvB2du4yaeiMDt7JOanJELcYUM8UfLzw/Pz53Vnx8qKT/UWgAlkn6blU5QxQuR3KLwo/eG5KeURjc4zA1z86vrzDQx0MKjclChRFI946vz6z3XpSY7VVowC5SOB/2RUmvSLpfoXE6VjFTqnDqQLPs6MxCuaMKz49tsPwdJV2icBrCq7EOdys0yiMazDeluEyFhvMKhYbzdYUGfYqkLXuwrXyg8B62L3vbZWIaKBPt8ApxtMP1l3FtXMbLkgbVidmo8D6/Wvb2zdR0e5nSIOZ7MeZ1ScMLz6+lkHByST9pMP/Nhe1l6xqvb1Mox7Viz4ZhCn/OXdJ5dZYxJLYhXt3exNcviK89qzpHPhX+mFfqMKbOenKFke+H1Snj32LMU/VievBZjVQ4Wu2S9m3wWTasY4w9S232PkioR9tlZ15fP4z1WKYweF+9uCGt1L+n21Lmz/m/q+sfY75YiNmmznKafvdbrPd3Css8u/D8SYXnfy5pcAc+67MKy6j5mRRidyrEfr5JbKVX1bO5y6g10bOhfG8W7tc6D8riOXpXSnqvQobPqgI2lzRb0k/Une0dqpBRHK+wQzbVzIqXdirOP1rSXxS+xFsqXNJkI4UuUTPN7Nh231wsf3DMCM5VuJTKNgo/5qsrdGH+gEKWcnxPltOkDicr7NAeJemtklaNddhJISv+gCUM+GVmBysM3niYwtGLYQpHMo6QdJeZ7dhmFSfE2yfc/a+F5Q2KvR4AdA7tMO2wJFXOT11D3W1ytYML9/+7zeWgZGY2VN2Xt5zm7osqr7n7SwpJPkn6mJmtVmP+jRSO8ErS9z30kFqOu98n6cJ6dXD3NyRdEx8eYma19sn3V2hDpO6eEJU6rKHwXZKkSe4+v85ypiscCZfCd6aek2OdahkSb59pENMWd39SIYkrNW9/GtWxN+tRitj74Mj48HJ3r9cDS+6+tIVyc29LteZtZf2eUKf+Uwr3VxgfpBPc/YvqHrfhy3FMii8oJCGk8Kf7k+7+Zs0Cek9x3+K1JrGvxts1q57PUcYKSDaUb7vC/SdqvH6UwtG0/ydpN4Ufne0UB+foaben+CN6s0LG8Q1JZ0raSqE71n4KO4bfU+jW1K7zFHagTWEH8WMKXaLWl7SDpOMk/VYhQyaFzPVa6u6a9KiWH9htLXX/yDdlZocqnKIwSOH6sAequ/vd5xTO015P4Y/AWxsU1cmuXWPi7d/iIDvHmdmfFT67JXGAnJ+YGd12gfxoh2mHpfDH8Ol4/zIz+7yZbW5mq5rZlmZ2prpH4T7P3f/Q5nJQvgnq/hN/eY3XL4u3aysk/Krtqe596F82WE7dP4NVy95EUq1L8Fb+0N3n7rOrXttLIVkoSb81szXrTZLuiXFjVNuz7v7HBvWcG2+3tzAQ5PoNYldgZkPN7Cgzmxr3Z161wqWIJX08hja6AkizOvZWPcq0t0IiWlr+j3dPZdmWMq3fh2sl7yQpJgWfiQ83bvldtu8Lkn4c758u6b/i/RslfaKVxM5KKXeXD6bluphUuqJMqfP6FuruYveSpFULr81XjW47NcroaRe6YvfhQ2vMu4bC0ba670WNB+TZozDvtarRLaoQ21KXr0LcqMIyxla9torCzr8rdBteu8b871D3QDTX13h9SqH8trt2NXkPlTpep3BExetMr4nBI5mYkifaYdrhFreXravWdfV0p6SDy96umZp+jpXtZUqd1ytdgBdJWqXG64MlPRljptV4/cTCNrFOg3qsW4g7sk7Mw7XqqpDQeyW+9h815vt8g+203nRfnfX0p4R1+stCOUsVxk74tkLicIUBAwvzbSzp3sT6/U+DzzKljnXbqwz1qFt2Qr1WUzgCXGtK7n6v7tNZXNJ6Ldah0brJsS3l+px/1+R
9zI9xX2nnu9/upJCo/03x/SvzKUUNPjMXp1GgFWa2lpl9QOE83kr3vP9y91pdVp6T9JU65eTo9nRkvP2ju/+sxryvKIy02q4T4u2Lko7yBtm/Rq/1wAcUjnxJ0inu/mKN5c5V95GwA63x5cxOqFPPKYX77XTtWifeHqBwFOUOhQz2agpHX45W2BZWkfTTJoNJAWiCdri2lbwdlocjah9S2KmsZRNJW5rZkDqvo4+LA+wdEB/OkrStme1cnBTGFvltjNkvnjZRVOw6/HKDxTV6raJyesRHqk7Z+LDC0WZX1SkU0To1nmtm1TrPL06Y9yBJ/66QHBksaXdJX1I4uvu0mX0v9vKqdrnC+lyicPnY9yokJ4eru5dUpd1r9L1KqWMjuerRjv9WSGbXmt7dQjnF9ftSttrl2ZZyrd/U0xFqXoKzgz6psF9eMVohwd1X/KNwf8O6Ucu/Xj2ga44yVkCyoXccUdWN6EWFwbgq3Yh+qZC9quW2Oju/Ug+7PcUf3Er34V82qP8tar+Rr3QLvMHdX2izjJ6oNAyLFS7XU0/lvMnBCuu1lk527ap8F4cpdJl+n7vPcvfX3P1Zd/+JwmjPy2LMV9tYBrAyox2mHW7KzL4o6QGF02X+Q+FqBcMVTouZpLCDNVnSf5vZ6vXKQZ92iMLvqBSSDnPrTAfFmMFaMUFYTCI0Ome56fnM6k4krKXlT9moXIViVp0kZrEOq7u7JUyjEupTk7svcfdvu/uWCqd5fUphsNeFCm3g5yX9upiIM7MtJb0vPvw3d/+iu9/m7n939+fc/WV3f1nLnyueXV+pRwbFBMNaGcvt0bY0gNZvTWZ2mMJplIMUTnmcEl/6ppn9e1n1qvKEuj/HRqciSuFKIVIYHDp3GSsg2VCeZxQ22I+6+0caHE1qdOmR4uVG/qr6WdOXFLr8SeEc4Iot1J0ZvK/eQjwMevJgg3rUZGZrqTvzdXer82eyRbx9oMkRu78W7m9RJ6bWudxFlT8C7eyAFhv6b7j769UB7v57de+ov9/M6h2hAJCGdrh39It22MyOVBjbYpCkA919srvfF3eW73f3byh0GZfCjvUZrS4DfcKnMswzv3C/0VhK2zZ4TdI/e9P8KT48TJLMbGN1Jwlr9WqQlm+Xtmy2nJzc/SF3v9zdj1EYM+WC+NIYdfcakUK37IpGlzxsd1DXVKXWw93HNvjTPrOFoh4q3N85YxV7ui31lc85OzM7RGGcoEEKPVQ+qtCbsTLeyjlmdlJJ1fsnD+c2VAY53r1enJltpjCwsiTNyV1GLSQbeseV6u5CtKbCuaYbuvsHvMFIslGjI1k97faU2g0w5fVaOtXdqxWVzG+z+qdkizvZtavYdel3DeIq3TqHqWeDxQErG9ph2uFmTou3t7n7/9QKiM9XToc52sx6uysvesDMtlIYw0SSvtXs6K3CeCqStJMtf5WT3yv0NJTC6Q71fCSxapWEwvvNbISkTyj0qHhDcSDaGn6jMCiqtPxVUnpVTCCeVXiqmGBZpXC/5pW1zGwPNT+K2lN9pR49dYe6rxJwRMZye7otDZT1uxwz+5hCUmGwwimXH3b3N9x9maRPq/u0kO9YuEJF2X4Vb98WTwer5eOF+zd1qIzlkGzoHUsrXYjc/RXPd3mUnnahS+0GmPJ6LZ3q7tVOHVp5f2XskBePaD7XIK74Wq1zIwHURjtMO1xXPJ2lckpN9aj/1e6KtyO0fC8V9H3FHgorjI9Sw1XqTir8c153f0rdlz79vJmtMLp+vHrUvybW6xcKgy4OVfizVzltY6q719wn8DD2ySXx4ZfM7F2NFmBma5vZyMT6VM+7ldW+NGdF8Wh48RzuRwr3P1Cj3DUl/aCdOrWor9SjR+JnPiU+/JSZfbBebCvjymTYlgbE+i0ysw9L+rnC+BK/lvShYq/juA/xKYU2QpL+y8w+1+sVXd5P1b1P8c3qZLiZDZf05fjwj+7+Z60oRxnLIdnQv/W029PfFQYekhp0AzS
zwQrn57UkNl6Vy4jt3Or8mcyPt1s3aXi3rzFPb7qrcL/RJaWKr5Vx7jWA5dEONzc/3vbldrjY26RZb4XivpPXjUKfEneaK+Mg/K+7/6XZPO7+mLp7Gx4av4cVX1boebCGpJlm9kkzGxmnIyTNVLgKS1Pu/oyk6fHhyZJ2jffrnUJRMUnhnOnVJN1mZueb2Z5mtoGZDTez0Wb2cTO7VNJjkhr+iWyynHlm9k0zG29mXWa2roVLwn5a3Zf4fEXdR0alsG9T+SN6gZn9q5m9xcw2NLMDFa7uspMSzvvuob5SjxxOV3gvJulaM7vAzHY3s/Xj+9nLzM5S66fN9WRbGkjrt9ID6iqFRMNMhdPqXq2OiwmHwxWu8iRJ3zOz9/RguduZ2R6VSdJmhZffUXzNagyiHNuRr8WH+ytsHzub2QgzGx/fy8YKic0v1apDjjKqkWzo33rU7Slmy/8WHzbqBri/2huHQJIqXVE/ZLVHKW5kSbyt2SUr0R3xdnWF91HPx+LtmwrdI3vbjYX7jRqqsfH2FYVBzACUi3a4uf7QDv9D3e9110aB6h7g82UljMSNPmMfhdHxpXDEMlUldhN1D4Ind/+bQg+EJZJGSrpMYUyRJxSOPq+i1tqEyjnglTo+L2lqoxk8DPg6TuGqGsMUxoW5UyHB+KxCr8mrFbp8r6WQHGnXKIWr4twq6VGFnpYPSbpUYdyGVyUd7u5PFur3psL57a8p9Ma8UCFB+5TCfs/2CsmVP/SgXk31lXrkEH8z9lW4RO8Qhcth/kGhDXtKYVv4ipb/o5pSbtvb0kBav5Lk7g9K+pZCovEAd697KmU8jegTCgm3/6fu053b8QOF377KdFThteurXptYpz7nSPphfPgRhcFun1H43u6o8Lkd5e6z6lUiRxlFJBv6sUxd6KbE293N7NAa86yh8IVrV2XQoLUlXVx1VKB6WdVHvCo7cRu00h2sys3qPqp3dhwsrXq5O0n6bHx4Y8zq9Sp3v1fdDdR/WI1Rzs1snKT3x4fXu/uS6hgAvYt2OEmfb4dje1ppg8eZ2b614uKRnbHx4Yx47i76h+IpFK0kG65RdyJquYEi3f1aSbvE8hYq7IQ/qtAm7OruxV6LzdykcJWcfy631mDR1eKf+3crJCuvict/LdblSYUjkf8paWd3b3p+dR2nKFz6b4rCEfOnFI5svqTwR+RcSdu4+w016ne7wjgZ16o7qfeEpOskjXP389usU0v6Sj1y8HB1kl0ULos+XaF9XRJv/6xw6cnxbZTb9rY0kNavJLn7mZLGe7j0dLPYpQpXuTkmDrJYKnf/rMLpLNMUvquVdun/SRrj7pf1RhnFwpg6NCl0r3RJU9qYd36c96wmcesoZBtd4eja+ZL2VDiPdLjCSOkfV8g8vyjpY1Xzr6Ywwnll/jMUugKvr9BQ3RWff6zee1EYGMglza9Tx/ML6+JOhSzZppLWUxhI6DOSbpP0war53lmY7xsKRxWGKmRyBxfiRhXixtZY/qGF1+9RGCl5A4VM/PEKjaIr/GhuWWP+KfH1mTk+swbz76JwZMDjet8vfg6bS/pC/Pxc0iJJo8revpmY+sNEO/zP12mHm3/e71HoVeEKvcdOi5/duvH2dIXBQl1hx2uXsrdvprqfZWV7mVJ2XZiYmHpv4rvf96Z2j1Kgj3D3F+IR72sUzp86Ud2XV6tluS507v6qmR2gMLr2SIWM5X8WQpYpDHD0CXVf5qRVJyucW3aCwg74dXXiLig+cPe7zOxOheutn6bukcKl0HV5bMrC3f1nZraJpLMlvV3Ln09Y8ZzC4C/zUsrsBHf/s5l9QqEr5Rh1n79Z9IzCaLjze7NuAOqjHW6uP7TD7v4bMztGoSvr6grJlW/UCH1F0qc9YWAsAABWZpxGMQB4D7vQufv9CufgfEfhHKs3FLpi/UrSvu7+ox7W7013P1HhCNmlkuYpHMF/WWGwmOsUuh/V+nM9QaF73t/iPO3W4duS3hGX/4jC+nl
Z0r0KO5Nbu3tPzrPKwkMXxLdL+p7Ckc7KerpbYcCWbT3xHCkAvYd2OKkOfb4ddvdLFc4v/rZCd+QXFXo7vKhwPfFvSdrO3a8prZIAAPQT5l76qSUAAABAFmY2RdIRkn7q7keWWxsAvYXvft9DzwYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVYzYAAAAAAICs6NkAAAAAAACyGtJK8IgRI3zUqFEdqgoAtGf+/Pn6xz/+YWXXozfQDgPoq+bMmfMPd9+g7Hr0Btri/N54442kuGHDhpVaZie88MILybELFy7MvvzRo0dnLxPl6Gv7xC0lG0aNGqXZs2d3qi4A0JYxY8aUXYVeQzuMnMo+ldKsz+wPIQMz+3vZdegttMX5LViwICmuq6ur1DI7Ydq0acmx5557bvbl33777dnLRDn62j4xp1EAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICshpRdAQAAyuTuSXFmVtqyW1l+K2WWrT/VNbdObE9Af/aXv/wlKa6rqyu5zNTYVr6Pqe3WtGnTkss899xzk2NnzpyZHJtqILZHU6dOTYqbMGFCh2uycqNnAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyGpI2RUAAKA/cPd+sXwz63BNGuvEemrlPZX9OaVqpZ5lf6ZYOSxYsCAprqurq8M1aazs70Pq8lv5jk+YMKHd6tTVifXUn9qtc889N3uZnficBjp6NgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKxINgAAAAAAgKyGlF0BAABSuHtyrJllX34nyhyIyl5PZS8f6EsWLFiQHNvV1ZUU18p3rJV2e+LEidnLXJmVvZ7KXj76Bno2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArIaUXQH0D0uXLk2Ke+ihh5LLnDRpUnLs9ddfnxQ3aFB6/uzss89OijvppJOSyxw8eHByLIDWmFm/Kje3ZcuWJcX1p3b4nHPOSYo78cQTk8tM/TzdPXuZrZbbieVj5bBgwYKkuK6uruzL7sQ23slyc+P7mCb18+xU+5pa7tSpU5PLnDBhQnIsAno2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArIaUXQHk99JLLyXFXXfddcll3nLLLUlx1157bXKZrRg0KC0vZmbJZZ566qlJcYcddlhymSNHjkyOBQYyd0+ObeV7W6ZW3lNqO3z11Vcnlzl9+vSkuE61w6mWLVuWHHvKKackxR166KHJZW600UZJca1sd6189kBv6OrqSoobiNt5f/nN6JTx48cnxc2YMaPDNcmjle1u0qRJHawJOoGeDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAIKshZVcAae6///7k2AMPPDApbt68ee1WZ6Vy/fXXJ8d+7nOf62BNgPK5e1KcmXW4JvmkvqdW2uGJEycmxT388MPJZQ5ES5YsSYq76qqrksv8whe+0G516upP2zP6ngULFiTFdXV1JZc5adKkpLjU9q0v4HuGZiZPntyRWHQOPRsAAAAAAEBWJBsAAAAAAEBWJBsAAAAAAEBWJBsAAAAAAEBWJBsAAAA
AAEBWJBsAAAAAAEBWJBsAAAAAAEBWJBsAAAAAAEBWJBsAAAAAAEBWJBsAAAAAAEBWQ8quwED00ksvJcdecsklSXGnn356cplvvPFGUpyZJZfZCccdd1xy7OzZs5Pi5syZ02516nr44YezlwkMdO5e6vIXL16cHHvRRRclxZ122mntVqeuoUOHZi+zFa20w7NmzUqKmzt3brvVqauVdjh12yv7NxDoiW984xtZ49B/zJgxo+wqJEltYzu1v7BgwYKkuK6uro4sHwE9GwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFZDyq7AQHTSSSclx06ZMqVzFcno7LPPToo7/vjjk8tcffXVk2MXL16cFLfvvvsmlzlnzpykuB//+MfJZX7ta19LimvlvQOd5u7ZyzSzUpffiXa47Pd0zjnnJMUdd9xxyWWuttpqybEvv/xyUtw+++yTXObcuXOT4lpphydPnpwUt8YaaySX2QmtbCOtbHvovzrxObeynY0bNy4pbubMmW3WBvWMHz8+OXbGjBkdrMnA0dXVlRQ3bdq05DInTJjQbnVWWvRsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWQ0puwL9xaJFi5Jjf/e73yXHuns71Wlo5MiRSXFz5sxJLnOjjTZqtzpZrLHGGklxX/3qV5PLnDhxYlLc4sWLk8u8++67k+L22muv5DKBvsTMSlv2s88+mxx7xx13JMd2oh3eZJNNkuJmz56dXOaIESParU5drbz3NddcMynu61//enKZqe3wq6++mlzm3Llzk+L23nvv5DLLlvo5tfL97ESZqK2rqyt7mZ1ot2bOnJm9TKSZMWNGcmwnvrvjx49Pjk3Vynsq07nnnpu9zAkTJiTHTps2LXuZfQk9GwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFZDyq5A2V577bWkuHe/+93JZc6bNy851syS4jbeeOPkMufMmZMUt9FGGyWX2V+MHTs2OXa33XZLirvrrruSy7zzzjuT4vbaa6/kMoFOS22HWuHuybGvv/56Uty73vWu5DIfeeSR5NjU9z9y5MjkMmfPnp0UN2LEiOQy+4tx48Ylx44ZMyYpLnV9StKsWbOS4vbee+/kMlvRyrYPFHVi2+lE+46BZ8aMGWVXIbtp06Ylxc2cOTO5zNTYVr7Lqb9ZEyZMSC6zL6FnAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyIpkAwAAAAAAyGpI2RUo29KlS5Pi7r///g7XpLELL7wwOXajjTbqYE36NjNLjh0/fnxS3F133dVudYB+wd1LXf6SJUuS4h544IHkMocOHdpudeo6//zzk2NHjBiRffn9RSvt8H777ZcUN3v27HarU1fZ231/wXrqPdOmTUuOnThxYgdrAvSeVn4zBqI777wzKW7cuHFJcWX/Z61GzwYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJDVkLIrULavfOUrpS5/2223TYrbf//9O1yT3rdw4cLk2O9+97tJcdOnT08u8y9/+UtybO4yf//73yeXueeee7ZbHSCJmSXHunv25Z9xxhlJcUOHDs2+bEnabrvtkuIOOOCAjiy/TK20w9///veT4lpph+f
MmZMcmyq1Hf7Tn/6UXOZuu+3WbnX6rNTvfSe+8yubBQsWJMVNmDChwzXBQNDKbzb6vttvvz0prr9+7vRsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWZFsAAAAAAAAWQ0puwJlu/fee0td/n777ZcUt8oqq3S4JnksXLgwOXa33XZLjn388ceT4swsucxOuPLKK5PifvaznyWXedJJJyXHnn766Ulx6667bnKZGPjcvdTl33PPPaUuf+zYsUlxQ4cO7WxFMmmlHd51112TY5955pmkuFba4U6s06uvvjop7pprrkkusxPt8Nprr51cZpnK/l1dmUybNq3sKgArnbL3gVKl1nPMmDEdrklr6NkAAAAAAACyItkAAAAAAACyItkAAAAAAACyItkAAAAAAACyItkAAAAAAACyItkAAAAAAACyItkAAAAAAACyItkAAAAAAACyItkAAAAAAACyItkAAAAAAACyGlJ2Bco2fvz4pLjbbrutI8t3946Um9tjjz2WFLfbbrsll/nUU08lxy5btiwpbtCg/pE/e/PNN5NjzzvvvOyxJ5xwQnKZ3/nOd5Jj0T+ZWXJsJ9qs/fbbLynujjvuSC5z6dKlybEDrR0eM2ZMcpmLFi1KjqUdzhvbSjv87W9/OymuE9tyK+0Dauvq6soa15+k7mdL0owZMzpYE6C2TuwDtVLm1KlTk+ImTJiQXGZf0j/2CAAAAAAAQL9BsgEAAAAAAGRFsgEAAAAAAGRFsgEAAAAAAGRFsgEAAAAAAGRFsgEAAAAAAGRFsgEAAAAAAGRFsgEAAAAAAGRFsgEAAAAAAGRFsgEAAAAAAGQ1pOwKlM3MssZ1avmdsGTJkuTYY489Ninu6aefTi6zlfc+bNiwpLiPfOQjyWUef/zxybGp7rjjjqS4H/zgB8llPvXUU+1Wp67p06cnxy5atCgpbvjw4e1WByu5VVZZJSluILbDy5YtS4496qijkuKee+655DJbee+pn9MBBxyQXOZnP/vZ5NhUf/zjH5PiWmmHFy5c2G516rr11luTY59//vmkuHXXXTe5THdPjgXaNWPGjLKrkGTs2LHJsTNnzuxYPdC3pf5mttK+dqLMvoSeDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICuSDQAAAAAAICtz9+TgMWPG+OzZsztYnd43f/78pLi3ve1tHVl+V1dXUtwdd9yRXOamm26aFPfEE08kl7n55psnx3bCZZddlhR36KGHdrgmeTz++OPJsccdd1xy7C233NJOdRraaqutkuL+7//+L/uyU40ZM0azZ8+20irQi8puh1v5zUi1YMGCpLhOtcMbb7xxUtysWbOSy9xss82S4p588snkMrfYYoukOLPOfBUuv/zypLiDDz44uczU7amV95Ra5mOPPZZc5jHHHJMcO3369OTYVFtuuWVS3EMPPZR92a0wsznuPqbUSvSSstviTn3PV1ZTp07NXubEiROzl9mKVn6vU7enTpQ5EHViXylVX9snpmcDAAAAAADIimQDAAAAAADIimQDAAAAAADIimQDAAAAAADIimQDAAAAAADIimQDAAAAAADIimQDAAAAAADIimQDAAAAAADIimQDAAAAAADIimQDAAAAAADIakjZFSjbxhtvnBQ3evTo5DLvv//+5NgFCxYkxe22227JZT7++OPJsf3FoYceWnYVstp0002TY6+88srk2J133jkpLnW7k6QHH3wwORZoR2o7vM022ySXed999yXHLly4MCnune98Z3KZTz75ZFLcoEHpOX8zS47thIMPPjh7mZ14T6lldnV1JZd51VVXJcfuuOOOSXGttMOPPvpocmxu7l7asoHeMnHixLKrkKwT38n+Uua0adOSY/vTZ5o
i9b2/8MILHa5Ja+jZAAAAAAAAsiLZAAAAAAAAsiLZAAAAAAAAsiLZAAAAAAAAsiLZAAAAAAAAsiLZAAAAAAAAsiLZAAAAAAAAsiLZAAAAAAAAsiLZAAAAAAAAshpSdgXKtuqqqybF/ehHP0ou85hjjkmOffDBB5PinnrqqeQyt91226S4iy++OLnMst14441JcR/84Ac7XJPeN2fOnOTYp59+OvvyzzjjjOxlAkVDhw5NimulHT766KOTY1Pb4UWLFiWXucMOOyTFXXLJJclllu3mm29OijvggAM6XJPed/fddyfHtvJ7nerUU0/NXmYqMytt2QBWlPqddPcO16T3TZw4sewqlGbChAlJcWeeeWaHa9IaejYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICsSDYAAAAAAICshpRdgf5i7733To7dd999k2MffvjhpLilS5cml3nfffclxb3nPe9JLrNshx9+eFLcXXfdlVzmZptt1m516nrwwQeT4m644YbkMr/+9a+3WZv6VlttteTYD33oQ9mXD7Rjjz32SI5tpR1+5JFHkuKWLFmSXGZqO7zPPvskl1m2TrTDm266abvVqavsdjj193qttdZKLvOggw5Kjs3N3Utb9spmwYIFZVcBA4iZlV2FUo0fPz4pbsaMGR2uSR799fOkZwMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMhqSNkVGIguvPDC5NitttoqKe7kk09OLnPQoLQckpkll1m2xYsXJ8XtsMMOHa5JHu6eHNuJz+nMM89Mjt1pp52yLx/otAsuuCA5dvTo0UlxX/ziF5PLHIjt8CuvvJIUt/3223e4JnmU3Q5/5StfSY7tL79t6D1Tp05Nips4cWKHawL0TTNmzMheZiu/Gwjo2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIi2QAAAAAAALIaUnYFVnZHH310Utzaa6+dXObZZ5+dFDdv3rzkMtE/bLjhhklxxx57bIdrAvQfRxxxRFLcKqusklzm+eefnxT30EMPJZfZCUuWLCl1+Z0wdOjQUpef2g6n/v4DtcyaNSspbuzYscllzpw5s73KoFeNHz8+KW7GjBkdrgnQHD0bAAAAAABAViQbAAAAAABAViQbAAAAAABAViQbAAAAAABAViQbAAAAAABAViQbAAAAAABAViQbAAAAAABAViQbAAAAAABAViQbAAAAAABAViQbAAAAAABAVkPKrsDKbs0110yK+8xnPpNc5kEHHZQUd9111yWXecsttyTFXXPNNcllIs2IESOSY88+++ykuHXXXbfN2mBlZ2alLdvdO1Juajt89NFHJ5d5yCGHJMVdffXVyWVOnz49KW5lb4eXLFmSvcz1118/OTa1HV5nnXXarQ4GqK6uruTY448/Pilu8uTJ7VanrjJ/BzDwjB8/Pjl2xowZ2ZffqX2L3KZOnZoUd8IJJ3S4Jq2hZwMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMjK3D05eMyYMT579uwOVgd91dKlS5PiXn/99Y4s/5577kmKu/POOzuy/Nze9773JceOHj06OXa11VZrpzr93pgxYzR79mwrux69oex2OPU3w2yl+Dh6rJXf4LLb4bvvvjspbtasWR1Zfm7jx49Pjt1mm22SY1Pb4YH4HTGzOe4+pux69Iay2+JJkyYlxU2ePLnDNRkYBuL3EWla+R3uhAU
LFmQtb+LEibr33nv7zAZNzwYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJDVkLIrgP5hyJC0TSU1rlV77bVX1jgA7TGzpDh3z17mQNTKex86dGjWuFbtvffeWeNa0cr2VLaVeXtG75k8eXJS3Lhx45LLvP3229utTr/Xn9qYMvWn9q2/fKZdXV1Zyxs2bFjW8nqKng0AAAAAACArkg0AAAAAACArkg0AAAAAACArkg0AAAAAACArkg0AAAAAACArkg0AAAAAACArkg0AAAAAACArkg0AAAAAACArkg0AAAAAACCrIWVXAAAw8JhZ2VXIzt2TYwfi+y8T6xNoz+233152FbKbNGlScuzkyZM7WJOVTyu/g4BEzwYAAAAAAJAZyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJAVyQYAAAAAAJDVkLIrAABAf2BmZVchO3dPjk19/50ocyBiPQHtmTx5ctlVyG7SpEnJsccff3xS3Oabb55cZmp7tGDBguQyu7q6kmPLNG3atOTYCRMmdLAmAxM9GwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFYkGwAAAAAAQFZDyq4AAADo+9y9X5TZX5hZ2VUA0A/98Ic/zF7mytwePfroo2VXYUCjZwMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMiKZAMAAAAAAMjK3D092OwZSX/vXHUAoC1buPsGZVeiN9AOA+jDaIsBoFx9qh1uKdkAAAAAAADQDKdRAAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArEg2AAAAAACArP4/8ZHlw9G4ueQAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "_, axes = plt.subplots(nrows=1, ncols=3, figsize=(6*3, 6))\n", "\n", "axes[0].set_title('Clean image \\n Prediction %s' % int(prediction_clean))\n", "axes[0].imshow(img_clean, cmap='Greys', vmax=1, vmin=0)\n", "axes[1].set_title('Adversarial image \\n Prediction %s' % int(prediction_adversarial))\n", "axes[1].imshow(img_adversarial, cmap='Greys', vmax=1, vmin=0)\n", "axes[2].set_title(r'|Adversarial - clean| $\\times$ %.0f' % (1/FLAGS.epsilon))\n", "axes[2].imshow(jnp.abs(img_clean - img_adversarial) / FLAGS.epsilon, cmap='Greys', vmax=1, vmin=0)\n", "for i in range(3):\n", " axes[i].set_xticks(())\n", " axes[i].set_yticks(())\n", "plt.show()" ] } ], "metadata": { "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 4 }