{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "executionInfo": { "elapsed": 2, "status": "ok", "timestamp": 1671547917546, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "DwcrPoWd4r4I" }, "outputs": [], "source": [ "# Copyright 2022 Google LLC.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "gmlhojcwL5oT" }, "source": [ "`jax.experimental.pjit` example using JAXopt.\n", "=========================================\n", "The purpose of this example is to illustrate how JAXopt solvers can be easily\n", "used for distributed training thanks to `jax.experimental.pjit`. In this case, we begin by\n", "implementing data parallel training of a multi-class logistic regression model\n", "on synthetic data.\n", "\n", "**NOTE: `jax.experimental.pjit` is not yet supported on Google Colab. 
Please connect to Google Cloud TPUs to execute the example.**" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "executionInfo": { "elapsed": 54, "status": "ok", "timestamp": 1671547917733, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "3Ck_x_LyL7PE" }, "outputs": [], "source": [ "#@markdown The number of optimization steps to perform:\n", "MAXITER = 100 #@param{type:\"integer\"}\n", "#@markdown The number of samples in the (synthetic) dataset:\n", "NUM_SAMPLES = 50000 #@param{type:\"integer\"}\n", "#@markdown The number of features in the (synthetic) dataset:\n", "NUM_FEATURES = 784 #@param{type:\"integer\"}\n", "#@markdown The number of classes in the (synthetic) dataset:\n", "NUM_CLASSES = 10 #@param{type:\"integer\"}\n", "#@markdown The stepsize for the optimizer (set to 0.0 to use line search):\n", "STEPSIZE = 0.0 #@param{type:\"number\"}\n", "#@markdown The line search approach (either `'zoom'` or `'backtracking'`), ignored if `STEPSIZE > 0.0`:\n", "LINESEARCH = 'zoom' #@param{type:\"string\"}" ] }, { "cell_type": "markdown", "metadata": { "id": "bZmsaMOcONjQ" }, "source": [ "# Imports and TPU setup" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "executionInfo": { "elapsed": 1, "status": "ok", "timestamp": 1671547917854, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "J9uYAd6TBEEG" }, "outputs": [], "source": [ "%%capture\n", "%pip install jaxopt" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "executionInfo": { "elapsed": 40809, "status": "ok", "timestamp": 1671547958814, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "JJuP-Wz_MBeJ" }, "outputs": [], "source": [ "import time\n", "from typing import Any, Callable, Tuple, Union\n", "\n", "# activate TPUs if available\n", "try:\n", " import jax.tools.colab_tpu\n", " 
jax.tools.colab_tpu.setup_tpu()\n", "except KeyError:\n", " print(\"TPU not found, continuing without it.\")\n", "\n", "import jax\n", "import jax.numpy as jnp\n", "\n", "import jaxopt\n", "\n", "import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "\n", "from sklearn import datasets" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "executionInfo": { "elapsed": 2591, "status": "ok", "timestamp": 1671547961531, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "Wi9vI0SAMEMX", "outputId": "197ca951-400d-48c7-bf75-24ea7bc3a0c1" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.local_devices()" ] }, { "cell_type": "markdown", "metadata": { "id": "RF4MUVeTNjah" }, "source": [ "# Type aliases" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "executionInfo": { "elapsed": 4, "status": "ok", "timestamp": 1671547961666, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "_NmZfvmkNlI-" }, "outputs": [], "source": [ "Array = Union[np.ndarray, jax.Array]" ] }, { "cell_type": "markdown", "metadata": { "id": "JkhZJHaxNmmU" }, "source": [ "# Auxiliary functions\n", "A minimal working example of how to create a `Mesh` for data parallel execution using `pjit`. 
Please note that, unlike `pmap`, `pjit` also makes it possible to seamlessly combine data-parallel and model-parallel execution." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "executionInfo": { "elapsed": 68, "status": "ok", "timestamp": 1671547961887, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "wd_3OqM0NoXW" }, "outputs": [], "source": [ "from jax.sharding import Mesh\n", "from jax.sharding import PartitionSpec\n", "from jax.experimental.pjit import pjit\n", "\n", "\n", "def setup_data_parallel_mesh():\n", " global_mesh = Mesh(np.asarray(jax.devices(), dtype=object), ['data'])\n", " jax.experimental.maps.thread_resources.env = (\n", " jax.experimental.maps.ResourceEnv(physical_mesh=global_mesh, loops=()))\n", " \n", "setup_data_parallel_mesh()" ] }, { "cell_type": "markdown", "metadata": { "id": "nKnR5yfpNpvw" }, "source": [ "# Custom-loop\n", "The following code uses data-parallelism in the training loop. The `use_pjit` keyword argument lets us deactivate this parallelism. We'll use this feature later to benchmark the impact of parallelism." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "executionInfo": { "elapsed": 74, "status": "ok", "timestamp": 1671547962097, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "EXiLgFO_R8_c" }, "outputs": [], "source": [ "def fit(\n", " data: Tuple[Array, Array],\n", " init_params: Array,\n", " stepsize: float = 0.0,\n", " linesearch: str = 'zoom',\n", " use_pjit: bool = False,\n", ") -> Tuple[np.ndarray, np.ndarray, float]:\n", " \"\"\"Fits a multi-class logistic regression model for demonstration purposes.\n", "\n", " Args:\n", " data: A tuple `(X, y)` with the training covariates and categorical labels,\n", " respectively.\n", " init_params: The initial value of the model's weights.\n", " stepsize: The stepsize to use for the solver. 
If set to `0`, linesearch will\n", " be used instead.\n", " linesearch: The linesearch algorithm to use. If `stepsize > 0`, linesearch\n", " will be disabled.\n", " use_pjit: Whether to distribute the computation across replicas or use only\n", " the first device available.\n", "\n", " Returns:\n", " The per-step errors and runtimes, as well as the JIT-compile time for the\n", " solver's `update` function.\n", " \"\"\"\n", " # When using `pjit` to distribute the computation across devices, it is not\n", " # necessary to override the `value_and_grad` of `fun` (though it is supported\n", " # if desired for other reasons, e.g. gradient clipping).\n", " solver = jaxopt.LBFGS(fun=jaxopt.objective.multiclass_logreg,\n", " stepsize=stepsize,\n", " linesearch=linesearch)\n", " # Apply the `pjit` transform to the function to be computed in a\n", " # distributed manner (the solver's `update` method in this case). Otherwise,\n", " # we JIT compile it.\n", " if use_pjit:\n", " update = pjit(\n", " solver.update,\n", " in_axis_resources=(None, None, PartitionSpec('data')),\n", " out_axis_resources=None)\n", " else:\n", " update = jax.jit(solver.update)\n", "\n", " # Initialize solver state.\n", " state = solver.init_state(init_params, data=data)\n", " params = init_params\n", " # When using `pjit` for data-parallel training, we do not need to explicitly\n", " # replicate model parameters across devices. Instead, replication is specified\n", " # via the `in_axis_resources` argument of the `pjit` transformation.\n", "\n", " # Finally, since in this demo we are *not* using mini-batches, it pays off to\n", " # transfer data to the device beforehand. Otherwise, host-to-device transfers\n", " # occur in each update. 
This is true regardless of whether we use distributed\n", " # or single-device computation.\n", " if use_pjit: # Shards data and moves it to device.\n", " data = pjit(\n", " lambda X, y: (X, y),\n", " in_axis_resources=PartitionSpec('data'),\n", " out_axis_resources=PartitionSpec('data'))(*data)\n", " else: # Just move data to device.\n", " data = jax.tree_map(jax.device_put, data)\n", "\n", " # Pre-compiles update, preventing it from affecting step times.\n", " tic = time.time()\n", " _ = update(params, state, data)\n", " compile_time = time.time() - tic\n", "\n", " outer_tic = time.time()\n", "\n", " step_times = np.zeros(MAXITER)\n", " errors = np.zeros(MAXITER)\n", " for it in range(MAXITER):\n", " tic = time.time()\n", " params, state = update(params, state, data)\n", " jax.tree_map(lambda t: t.block_until_ready(), (params, state))\n", " step_times[it] = time.time() - tic\n", " errors[it] = state.error.item()\n", "\n", " print(\n", " f'Total time elapsed with {linesearch} linesearch and pjit = {use_pjit}:',\n", " round(time.time() - outer_tic, 2), 'seconds.')\n", "\n", " return errors, step_times, compile_time" ] }, { "cell_type": "markdown", "metadata": { "id": "HZjIR0q9NsLg" }, "source": [ "# Boilerplate\n", "Creates the dataset, calls `fit` with and without `pjit`, and plots the results." 
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "executionInfo": { "elapsed": 74, "status": "ok", "timestamp": 1671547962305, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "4DTLFIPr4zfG" }, "outputs": [], "source": [ "def run():\n", " \"\"\"Boilerplate to run the demo experiment.\"\"\"\n", " data = datasets.make_classification(n_samples=NUM_SAMPLES,\n", " n_features=NUM_FEATURES,\n", " n_classes=NUM_CLASSES,\n", " n_informative=50,\n", " random_state=0)\n", " init_params = jnp.zeros([NUM_FEATURES, NUM_CLASSES])\n", "\n", " errors, step_times, compile_time = {}, {}, {}\n", "\n", " for use_pjit in (True, False):\n", " exp_name: str = f\"{'with' if use_pjit else 'without'}_pjit\"\n", " _errors, _step_times, _compile_time = fit(data=data,\n", " init_params=init_params,\n", " stepsize=STEPSIZE,\n", " linesearch=LINESEARCH,\n", " use_pjit=use_pjit)\n", " errors[exp_name] = _errors\n", " step_times[exp_name] = _step_times\n", " compile_time[exp_name] = _compile_time\n", "\n", " plt.figure(figsize=(10, 6.18))\n", " for use_pjit in (True, False):\n", " exp_name: str = f\"{'with' if use_pjit else 'without'}_pjit\"\n", " plt.plot(jnp.arange(MAXITER), errors[exp_name], label=exp_name)\n", " plt.xlabel('Iterations', fontsize=16)\n", " plt.ylabel('Gradient error', fontsize=16)\n", " plt.yscale('log')\n", " plt.legend(loc='best', fontsize=16)\n", " plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n", " fontsize=22)\n", "\n", " plt.figure(figsize=(10, 6.18))\n", " for use_pjit in (True, False):\n", " exp_name: str = f\"{'with' if use_pjit else 'without'}_pjit\"\n", " plt.plot(jnp.arange(MAXITER), step_times[exp_name], label=exp_name)\n", " plt.xlabel('Iterations', fontsize=16)\n", " plt.ylabel('Step time', fontsize=16)\n", " plt.legend(loc='best', fontsize=16)\n", " plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n", " fontsize=22)\n", "\n", " return errors, 
step_times, compile_time" ] }, { "cell_type": "markdown", "metadata": { "id": "fx1BGKFBNukq" }, "source": [ "# Main" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "height": 1000 }, "executionInfo": { "elapsed": 30301, "status": "ok", "timestamp": 1671547992729, "user": { "displayName": "Felipe Llinares", "userId": "01655756739108476525" }, "user_tz": -60 }, "id": "haEIB6LV0ZA5", "outputId": "8eb9bc9f-c08f-48f0-dcf2-76eb1995654f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num_samples: 50000\n", "num_features: 784\n", "num_features: 10\n", "maxiter: 100\n", "stepsize: 0.0\n", "linesearch (ignored if `stepsize` > 0): zoom\n", "\n", "Total time elapsed with zoom linesearch and pjit = True: 3.79 seconds.\n", "Total time elapsed with zoom linesearch and pjit = False: 16.8 seconds.\n", "Average speed-up (ignoring compile): 9.57\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAm4AAAGYCAYAAAD7i26KAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsTAAALEwEAmpwYAACXUElEQVR4nOzdd3ib1dnA4d+RvPce8Yiz93YWITtAgIQ9C5RQRoFCS+dXyii0pbtl75Gwyp5hJySBBEz23stOHCdOvPeSzvfHK2/ZlmzJkp3nvi5dtt5x3keyJD86U2mtEUIIIYQQ3s/k6QCEEEIIIYRjJHETQgghhOghJHETQgghhOghJHETQgghhOghJHETQgghhOghJHETQgghhOghJHFzkFIqUymlbbfz2zluh+2YWS22r7JtX9TBdZbYjnugxfYHmlw/WyllbqeMs5ocq5VSaQ48xPZiSlZK/VsptV0pVaaUqrLFsF4p9aRS6jIHyrilSTx3dXDsoibHViqlIto5dlCLxzqrxf4HWuzXSimLUipPKbVCKXWDUkq1OGeW7bjMjh5Xk3NaXqOtW5qdc89WSr1re06rlVIlSqkDSqkvlVL3K6VGOBpHd2nxN2rrltDO+ZOVUh8opU7aXk/7lVL/VEqFd3DdHnGes5q877VS6l/tHPdEB58PSzq4Tv3fbVWL7bNU87/d0HbKCFNKVTQ5dpEjj7GNsuy9P1vePmxxTqYD57T6LGhRRniTx7DFzn5Hr9H0tsp2rt3n2M417H7OKKXS2ii/TCm1Wxmfuf3bKbfTz49SapJS6hWl1CHb671MKXVYGf+//qqUmtLeY/IEJx7v/XbONSulblNKfaeUKlJK1dre658rpS5yIoaRyvjs1kqpHS59gHb4uPsCvdTflFKfa62tHrp+EnAW8EUb+xe56kJKqRnAJ0AokA98D5wCIoGxwO3AlcC7HRT1kxa/P+JgCAHAVcAzbexf5GA5B4E1tt/9gSHAbNvtQqXUpVpri4Nltec9oKyd/c32KaUeAX5hu7sT2ABUAqnAdOBsIAz4jQtic4emz2tLlfY2KqWuBl4FzM
B3wDFgCvBb4GKl1DSt9cmeep4L3KGUelRrne2Gsh21CPh9G/uuBAJdfL32Xkeb2tj+JXCinTLb23c1jY9hjFJqvNa66XXeBWJanBMCXGr7/WU7Ze5p53qd1fQ6fYBJGJ+5i5RSZ2utv2vnXKeeH6XUr4B/Awo4BCwDSoFEYDwwExgMdPhFvZvZ+1vViwIW2n5f2XSHUsoH+ByYB1RjvP7ygP7AfGC+UuoRrfUv27u4rZwlgG8n43ee1lpuDtyATEAD5baf17Vx3A7b/lkttq+ybV/UwXWW2I57oMX2B2zb19t+vtnG+WFABbALI0nQQFonH7M/xj8rjfGGDrBzzATgbx2UM8xWRhlQZPs9vZ3jF9mO2QTUAT+0cZwJOAoUAPvbeN7rn7clds6/GLDY9t/UZPss27ZMJ54r7exzjfGBooESYLad/UHAFcA1nn79t/M3avW8dnBesu31aQEubLLdB3jTVuYHPfW8Ljyf9e/7+s+XF9s47gna/3xo9+/R5O+2qsX2+tf8AaAQ431vbqOM723vy0048JnWQTwOxd3inEx773Unr7vOVka27ecTDpyTVv8+78xzbOc4u58z7V0HiAUybPt3uer5AcbYXuu1wFV29vsC5wE/c8XrvbtuwO9sz8VeO/tuse3LAlJb7DvH9lxoYHwH17i//jVk+7nD3Y9Lmkqd95jt54NKKT8PXH8tsBujlijCzv6rML5JLnHBtaZjfMvL0Vr/Rmtd1fIArfVGrfXdHZRzo+3nOxj/9KB5DVxbcjC+9U1W9ptu5mH8g30T4xuTU7TWHwCv2+5e7uz5LnCl7ecTWuuVLXdqrSu01m9rrV9vua8Huwvj9fmy1vqj+o1a6zqMD9IS4CKl1PAeel5XPQdUAde38Zp3tyqM91MfjFr9ZpRSg4GpwFcY788eRxldDyZifJFcZNv8I6WUv8eCcpDW+hSNte/D2msyddLlGF+E39Fav9lyp9a6Vmv9mdb6SRddr7vU/595yc6+2bafT2utjzTdobX+ksYaujabh5VSo4F7gffpuNXJZSRxc957GN/W+gG3eiiGJTQ2IbZ0A8Y3p1ddcJ04289TnS3AVo18re3uYhrfQFcrpQIcKGKx7eciO/tuaHFMZ2yw/ezbhTI6q/75dUdzm7e6yPazVTKqtS4BlrY4rqed11XHgMcxmmf/6uKyHeXu95ynNXyR1Fovx/giHIlRA98TbG3ye7yLyux1n0VKqWkYXWLqgFfsHOLol/28Nsr3xWjKLsNovu42krh1Tn3fj3uUUiEeuP6rGMnZDU03KqWGYHw7+EJrfdwF16n/FjJSKTW3k2UswPhwOQSs1lqvw2jGjQAuceD8jzCabq5TTQZk2GobLwJ2aq3XdzI2MJqWoRM1di5Q//wuUi7u7N6NBiql/qKUek4ZA1h+1NZ7QikVBgyw3W3rb1a/fVxPO8+F/obRpeBipdRkN5Tfribv0Yua1uorpUzAdRhdEz7u7rhcwfbPtv6L5BLbz/ok1JFWAG/Q9LMi10Vl1n8WXaaUSnJRmZ5W//f8rI3/h/V9xG9TSqU23aGUOgejRi4H+KyN8u/F6Od9l9baVX8Hh0ji1gm2Zq2vML6l/NoD1z+O0fF0klJqWJNd9YncEhdd6ntgC8a3/2VKqZVKqXuVUucppWIdLKP+zbNE2zoE4MQHpda6GngDo+nm7Ca7rsKodVziYByt2P4RXWi7u6Wz5XTBcxjfBscAWbbRXLcpYyRjp5vhnRhl1fSW2cnLTQPuAW7GeC+8DhxR9kcap9l+Ftlqreyp/wfSrwee5xJa60Lgn7a7f3d1+Q5agtHH9eom287GGBj1hu192RMtxOgndhBYbdv2Ksb7cG7Lf+Beqr6z/U7gsIvKfAWj5igJ2K+Uekcp9Qul1HSlVFBnC1WNsyk4devqg1FKBWP0DwZ4sY3D3gJewBgItk8ptUwp9aZSai1GUrcOmKO1bjXYTCk1DvgDRiWJvdo8t5JRpZ13N0
YfkF8rpZ6y9T3oTosxOosuAv7PVhvl0m/DWmurMqY+WYLxWGfZbgAoYxj9s8Dz2s6ITKVUPHAuRofNpqOjXsWoVZijlErTWmd2EMpibCOpMEYBgZGk1gGvOfWgjLj8MKrQ7wfSMWovn3C2nDYcVs1nF2lqq9Z6bP0drfU6W4LzFEZiep3tBlCllPoM+HsnahTbG2XVFrvNAe04DvwF47V2CONvMQyjM/DFwFtKqfNsfUXq1dfElbdTbv2HZGgPPM+VHgXuBGYppc5p8Tx2h1cxmmoXAU/btrmzmfR6pdT19nZordt6Q61s571WrLWOsLO91RdJrfUJpdTnGAnRIuBPjgbdnZRSfTBi/AdQDNzY5MuwPQ4/P1rrI7ZapiXAIIyRo/VfvmqVMbXJv7TWy5wM+wuMwRLd7QqM9/EJ2qgxsz13NyuldmE8p/Oa7C4EvsZOP07b/4+XMUbN3+LasB0jiVsnaa03KaXexuhgfg9GZ+bu9DFGknadUuoPGN+G+2B0dK9x1UW01jnA2bZOmBdgdEyegNH8ORbjQ/1SpdT5dq57PcZrbHnTzp9a61xbUnIBxgflAx3EsEEZc+NcqJSKBBIwhsV/orVub7h7s1ja+MdQCvy0i82tTbU3HciRlhu01h8ppb7ASHDnYnSaHoNRm3gJxmO+VWv9gqMBaK3dPnWILZFomUz8AFyilPoP8CvgPy2Oqf8v4uw36p5ynstorSuUUn/CeH/9TSn1VQf/pF19/RNKqS+B820DMI5j1E7v0FpvdMMl25sOpC3tTXdR0XKDUioRY5oHK62n81iMLXFTSv25O5/r9rRR+5QFzNRaZ3VwulPPj9b6e1sLzhyM/yeTMboChGB8cT/L9ty0mg+tLVprT9UY1/djfMU2oKgVW5eI/2E8tr9gVAKcwEhc7wbuAy5QSk3XWpc2OfV+YBRwm9b6qJvib5ckbl1zL8a8PrcqpR7u4I1U/wZs8ytQi/3tfnBorWuUUv8D7sAYurzItsstnYa11tuAbfX3lVJjMGpXfoTxTeUXQMuJQ9uLaTG2xE0p9aADH5RLMKYkuZrGpixnHmvTfwwWjD5EW4GPtdZFTpTTkd84UIPYjK3Z6UPbDVvTxHyMWsnBwJNKqS+0Z+f1csZfMF4PI5RSqU2S9voPv/b6hdbva/pB2VPOc7UXMJqfx2F8QWw12q8Jl36+2CwGzsd4H2diNJ26a1DCGq31IifP+bvWepUTx1+PrduHnX+4n2AMwuqH0bdphZOxNOXKv0V9gumLEdsUjIFUbyilZnfQZO3s84Ot5WSZ7VbfJ3AW8BDGl8r7lFKf2PpBeiVljHyeZrtrbzRpvf9gvL7vbpFgbgWuslUSnI0xivePtrInAP+HMb3Xs66N3HGSuHWB1vqAUuoFjNGlf8L4YGhL/Tec4A6Krf+H0N4krvUWYyRudwEzgG26+SSSbqO13gpcY+uIfgHGQIGGxE0pdQZG0xnAL5VSLUfd1L/2+mLUNC3v4JKvYfT3+QlGzWIexoetozrzj8EjtNYVwPu2vhb7MOZzOxd43pHzlVL/phNNpa6qqdNaFyqlTmJM3JlEY01jpu1nhFIqrI3+Yyktju1J57mU1rpOKXUvRsL2F6XUe+0c7o7Pl6UYk25fi9Fk1KmuCV5kke3nEKWUvdq9+s+kn9C1xM1lf4uWn1lKqakYNWlTMb4g/bZzITpGa12L0b85A2OC4SSMmleHEjel1O8Bp6e16eJndX1z+Bqt9d424qrvWgR2Ro7b/A8jcZuHLXHDqJX1wWhxatkUHWH72U81rppxk9b6gJPxd0gSt677E/Bj4FrVzlI1GBPFAgzsoLxBLY5vk625dhuN8y0t6egcN/gKI3FrOVih6cCD9A7K+AkdJG625tX6figAj7mySdgbaa2P2fpfpNP6+W3PZTg/vUkWLlqdwfahWD/yreGfkta6RCl1EGPE5kSMPiQtTbL93NzTznOTtzFqtscDN7VznDs+X+pr9e
/ESMI/1u5ZKcLtlFJnYvRrBaMzenuDEC5RSoVrrYs7ebn653aAUkq105rg8N+intY6QxlLBr4I/Fwp9YzW+mAn43SY1rrMlrxdhnOfRfMxVlxw1qJOnFP/2fNj2922BiWAMbCwft6+tv7ORbafUXb2DaOxYqKlIBofs1tmnZBRpV1kG+H5KMZz2d68S/Xf4C5QxtxmrShjOo9RGE15q+0dY8dzGN+KT+Lib8OqnZ6tTdR/ADY047UY0TNda63s3YD6NTgvVu2sR9rECxiPNZ+ePY8U0PHza/sQ6mO763AzqdY6ra3nvJ1bWhceSksLMD68Smm9DFD9ZLbXtDzJ1uekPjH/oIee51K2f/r1E1zfj/G82vMNxufGaKWU3eTN1ux1ge1uqwmf2/ASje85h/tZeqH6L5Ivtvc+wJjmJZDmo2mdtQkjGYjA6C/WlvrO/47+LeotxhgF74fRD6vLOvtZ3xGt9axOfBY5EktbzsP4klGKMeF7W/JpnAKqrQl2p9p+Nozc1Vo/0E7M9RP67myyfUunH0k7JHFzjX9gDBRYSNvTA3yA0c+qH0afpWaTz9pGDNUnXq9pB+dh01o/qbWO0VrHa9ePbF2olHpfKTVHGVNnNI1XKWMR3jtsm95qsvtyjNF2hzHWeGwr9l0YH3IBGH3l2qW1/tj2WGPc9YboZi8qpf6k7C88H4Yx0rUPxodQW3MJdTulVJAypi1p9W1SKXUejU26T9qaWpp6BGM01vVKqQuanOeD0WckDPjQ9troceep5guEp+ECWuuvML74JdC42kbLY45jNPko4DVbR/ymcQViDHToi7Gs1YcOXntLk/fc0o7P8D62L5L1K6N0NDF5/f5Oz+lmawl41Hb3GVufq6bx+CilHsRIDApov2bIXvmaxrlEr1VKDWrveAf9RSn1iDJWlWhGKRWgjAXaJ2F8Oei2FQI6of7v9qbWus1R4ba/Uf3r+TGl1ICm+5VSZ9M44LC9vqUeIU2lLqC1LlZK/R1j7iW734i11rVKqUswhkffgjESMwPjm1kiRmdKf4y5037eLYF3zIQxtcPFQIFSajNGzV4YMJzGJPUNmve/qn/zvNZOM0G9VzGagX6CMS2GN0lUSv3Qzv5NWuuWfff+rZRqr//QY036IUZhTLFwn1JqH8YM7hUY/6AnYlSzVwHXa62dna7Dnfww/lb/VUrtxmjqsWA0HdT3Z3kfo4aoGa31UaXUjRh/9w9tfY1yaOx0fQD4aU89j+ZfhlsmrV1xN8Zyd+3NqXUnRlPpGcAhpdT3GKNBwzGShGiMx3CJnYS6J/q9UmpRO/v/Z0t6r8R4Lx0Bvu2gzDeB/wITlVIjtdY7Ohnbn4HR2CYJt/VXzcL4+03C+EJWAlyutS5wtnCt9ZdKqZUYtTz30dg82JSjzw8Y/fF+AfxCKZUFbLfFF4cxOCYa4z1+l50vOV5BKRWHMdgAHEuGf4nRDWUIjX+jXIz3UP3E2m/gjf06tRcsBNsTbjQu3Gt3cXSMWqOjNC42PquN46Ixpr9Yj5G01WKMZvoaYxJT3zbOewAHF0Nuck5XF5kPwOij8G+MhDILI5GoxJi7623g/BbnDGjyHAx24BrxNC7mO9q2bZHt/idOxLrD3vNO5xaxntXkMbR3W9XkHEeO18BFTc5JwkjcXscYsXsSowN4MUZN5L+Bfp5+7dt5fvww/jF9aXtflAE1GMs1fYSRGHRUxmSMWp9TGE0WBzC++IT35PMwmr80xsSczjynS2zn/aadY95t8jp6oI1jfDE+R762xVxrez2tt70Xojt4zTu8QDbG4CCN5xaZ7+h2l+341bb7f3Ww/KW24//bYntafdkOlqMwuox8ipFA12DUnm+zvbdT2jjPoetgJIAa4zNjcJPtTj0/tnOiMZqHX8T47Dlue+2UYny2Po3t89lbbxgjsDVGU6Wj54RjDDzYgJGo1tneN18BVzt5faffQ529KdsFhRBCdJFS6lmMxGm87h3N+UIILy
OJmxBCuIhS6hDwvdb62g4PFkKITpDETQghhBCih5DBCacRpdQSJw5/QWvt7BI0QojTlJOTra7RTizjJoRoJInb6eV6J45dhfNrBwohTl/OTrYqiZsQnSBNpUIIIYQQPcRpU+MWExOj09LSPB2GEEIIIUSHNm7cmKe1brXE2GmTuKWlpbFhwwZPhyGEEEII0SHbZMityJJXQgghhBA9hCRuQgghhBA9hCRuQgghhBA9hCRuQgghhBA9hCRuQgghhBA9hCRuQgghhBA9xGkzHYgQQgjhrJKSEk6ePEltba2nQxG9gI+PDwEBAcTGxhIQENC5MlwckxBCCNErlJSUkJubS1JSEoGBgSilPB2S6MG01tTV1VFWVsaRI0eIj48nPDzc6XIkcRNCCCHsOHnyJElJSQQFBXk6FNELKKXw9fUlMjISf39/Tpw40anETfq4CSGEEHbU1tYSGBjo6TBELxQYGEh1dXWnzpXETQghhGiDNI8Kd+jK60oSNyGEEEKIHkISNxdZ++bfyHjx154OQwghhBC9mCRuLqJO7mTckZfJz832dChCCCGESyxZsgSlFJmZmQ3bHnjgAVasWNHq2EWLFpGcnNyN0RlNjg888EDD/QceeKBZM2RRUREPPPAAmzZt6ta43KnXJ25KqYVKqeeKi4vdep34c35DgKpl3yePuPU6QgghRHc5//zzycjIIDExsWHbgw8+aDdx84SMjAxuuummhvs33XQTGRkZDfeLiop48MEHJXHrSbTWS7XWt3RmyK0z+g4Zy5agqQw9+haV5aVuvZYQQgjRHWJjY5kyZQr+/v6eDsWuKVOmNKvlS05OZsqUKR6MyP16feLWnfym/4JIStj26TOeDkUIIYRoZsOGDSilWLNmTcO2xx9/HKUU9957b8O2/fv3o5Tis88+a9VUWt8M+dBDD6GUatVUCbB582amT59OUFAQgwYN4plnnPufuGrVKpRSvPfeeyxatIjIyEjCwsK45ppryM/Pb3Zse02lmZmZ9OvXD4Cbb765Id4lS5Y4FY+3kcTNhYZNPod9PoPps/slLHV1ng5HCCGEaDB+/HgiIiKaNXOuWLGCwMDAVtvMZjPTp09vVUZ9M+SiRYvIyMho1VRZUlLCj370I6699lo++ugjJk6cyG233cbKlSudjveuu+5CKcUbb7zBQw89xMcff8xll13m8PmJiYm8//77ANx9990N8Z5//vlOx+JNZOUEF1ImE6Xjb2Xwul+xecWbjDv7Wk+HJIQQwoUeXLqTXTklHo1heJ8w/rhwhNPnmUwmZsyYwcqVK7n//vuxWq1888033HbbbTz22GOUlZUREhLCypUrSU9PJzQ0tFUZ9c2QSUlJdpskS0tLeeqpp5g9ezYAM2bM4KuvvuKNN95o2OaoESNGsHjxYgDmz59PVFQU1157LV9//TVz587t8Hx/f3/GjRsHQP/+/XtNE6rUuLnYmLOuI0fFEbD+Sbv71775N7b+fR411VXdHJkQQojT3ezZs8nIyKCqqootW7ZQVFTE7373O/z9/Vm9ejVgNFXOmTOnU+UHBQU1S9D8/f0ZNGgQR44ccbqsK664otn9yy+/HJPJ1GzwwelIatxczMfXjyODFzFl7z/Zs345QyfOA8BqsbDu+TuZcuJ1APZsXc3QSWd5MlQhhBBO6kxNlzeZM2cO1dXVfP/992zevJkxY8YQHx/PmWeeycqVK0lNTSU3N9fp2rF6kZGRrbb5+/tTVeV8ZUV8fHyz+35+fkRGRnLs2LFOxdZbSI2bG4xa8DNKCKZi1aMA1FRXsenRK5hy4nXWh58DQOGebz0ZohBCiNPQqFGjiImJYcWKFaxYsaKhZm3OnDkN2/z8/Jg2bZqHI4Xc3Nxm92tqaigsLCQpKclDEXkHSdzcIDg0gp19LmNs2WoObP2Off89h/SS5WT0+xnpv3iTLFMyQcfXejpMIYQQpxmlFDNnzmTZsmWsXr26WeK2efNmPvjgAyZPnkxQUF
CbZfj5+VFZWen2WN9+++1m99955x2sVitTp051uIz6aUy6I97uIombmwxa8GvqMJH6/gUMqdrO+rEPMfX6v6JMJk5EjKdf5XYZeSqEEKLbzZkzh3Xr1lFRUdEwcnT8+PGEhYWxcuXKDptJhw8fzqeffsqyZcvYsGEDOTk5bolz586d3HDDDXz55Zc8/vjj3HbbbcycOdOhgQn14uPjiY6O5s033+Sbb75hw4YNraYU6WkkcXOTmD592RyzkDrM7J71HBMvuqNhnzntDMKoIHP3Bg9GKIQQ4nRUn5ilp6cTFhYGNI44bbq/LU888QTBwcEsXLiQiRMn8txzz7klzkcffRStNVdeeSV/+MMfWLBgAe+++26r45oucdWSyWTihRdeoLCwkHnz5jFx4kSWLl3qlni7i9JaezqGbpGenq43bOjeRKmutoaK8lLCIqKbbT+etZfExZNYO/T3TL7q7m6NSQghhGN2797NsGHDPB3GaWfVqlXMnj2bZcuWMW/evDaPKykpITw8nMcff5w77rijzeO8VUevL6XURq11esvtMqrUjXx8/VolbQCJfYdwghh8sn/wQFRCCCFEz7Zx40befPNNACZPnuzhaLqXJG4ekh02lrSSDWirFWWSFmshhBC9X10HfbvNZrND5dx8882cPHmSf/zjH0ycONEVofUYkrh5iCVlKjE7l5N9aBfJA0d6OhwhhBDC7Xx9fdvdv3jxYhYtWkRH3bg2bdrkyrB6FEncPCRh5GzY+Wdytq+QxE0IIcRpYf369e3ur18UXrRNEjcPSR0yjkJCUVnfAz/3dDhCCCGE26Wnt+prL5wknas8RJlMZAaNIrF4s6dDEUIIIUQPIYmbB1UnTSFZnyAvJ8vToQghhBCiB5DEzYOihs0CIGvLcs8GIoQQQogeQRI3D+o/aioV2p+6Q2s8HYoQQgghegBJ3DzIx9ePgwHDiS04fYc1CyGEEMJxkrh5WFn8ZNIsWRQX5nk6FCGEEEJ4OUncPCxsyAxMSpO5Wfq5CSGE8C5LlixBKUVmZmbDtgceeIAVK1a0OnbRokUkJyd3Y3SGVatW8cADD2C1Wt12jVmzZjFr1qxm11RKsWrVqoZtjzzyCO+//77bYqgniZuHDRg3kxptpmL/ak+HIoQQQjRz/vnnk5GRQWJiYsO2Bx980G7i5imrVq3iwQcfdGvi9tRTT/HUU0813B8/fjwZGRmMHz++YVt3JW4yAa+HBQSFsMd3MJGnNno6FCGEEKKZ2NhYYmNjPR2Gxw0fPrzZ/bCwMKZMmeKRWKTGzQsUxqbTv3YfVRWlng5FCCFEL7VhwwaUUqxZ0ziTweOPP45Sinvvvbdh2/79+1FK8dlnn7VqKlVKAfDQQw+hlEIpxQMPPNDsOps3b2b69OkEBQUxaNAgnnnmmVaxrFu3jnnz5hESEkJwcDBz585l3bp1zY5p2TxZLy0tjUWLFgFGs+2DDz4IGOug1sfkKKUU99xzDw899BDJyckEBgYyY8YMtmzZ0m4sLZtK09LSyMrK4vXXX2+IoT5GV5PEzQsEpE3ET1k4um+Lp0MRQgjRS40fP56IiIhmzZwrVqwgMDCw1Taz2cz06dNblZGRkQEY/dkyMjLIyMjgpptuathfUlLCj370I6699lo++ugjJk6cyG233cbKlSsbjtm2bRszZ86ksLCQJUuW8Morr1BSUsLMmTPZunWrU4/ppptu4sYbbwRgzZo1DTE545VXXuGzzz7jiSeeYMmSJeTm5jJ37lwKCgocLuODDz4gISGBc845pyGG++67z6k4HCVNpV4gMLIPABWFJzwciRBCiHZ9/ns4sd2zMSSMgnP/7vRpJpOJGTNmsHLlSu6//36sVivffPMNt912G4899hhlZWWEhISwcuVK0tPTCQ0NbVVGffNgUlKS3abC0tJSnnrqKWbPng3AjBkz+Oqrr3jjjTcatv3pT3/C39+fr7/+moiICADOOuss0tLSePDBB53qJ5acnN
wwIGLy5Mn4+Dif1lRWVvLVV18RHBzcUM6gQYN4+OGH+fOf/+xQGePGjcPf35+YmBi3N6H2yBo3pVSwUuplpdTzSqlrPB1PV4VEGZ0+a4pPejgSIYQQvdns2bPJyMigqqqKLVu2UFRUxO9+9zv8/f1ZvdoYJLdq1SrmzJnTqfKDgoIaEjQAf39/Bg0axJEjRxq2ffvttyxYsKAhaQOjz9gFF1zAN99807kH1gXnnXdeQ9IGRrPnlClTnK656y5eU+OmlHoJWACc1FqPbLJ9PvAoYAZe0Fr/HbgEeFdrvVQp9RbwuididpXwWKPGzVImiZsQQni1TtR0eZM5c+ZQXV3N999/z+bNmxkzZgzx8fGceeaZrFy5ktTUVHJzc5slX86IjIxstc3f35+qqqqG+wUFBc1GqdZLSEigsLCwU9ftivj4eLvbdu7c2e2xOMKbatyWAPObblBKmYEngXOB4cDVSqnhQDJw1HaYpRtjdIuQ0AhqtA+UnfJ0KEIIIXqxUaNGERMTw4oVK1ixYkVDzdqcOXMatvn5+TFt2jS3xRAVFcWJE627Bp04cYKoqKiG+wEBAdTU1LQ6zpm+Z47Izc21uy0pKcml13EVr0nctNbfAi3/GpOAA1rrQ1rrGuBN4EIgGyN5Ay96DJ2lTCYKVTjmKte+GIUQQoimlFLMnDmTZcuWsXr16maJ2+bNm/nggw+YPHkyQUFBbZbh5+dHZWVlp2OYOXMmn376KaWljTMplJaWsnTpUmbOnNmwrW/fvuzbt69Z8vbtt982Ow+MGj2g0zF99tlnlJeXN9zPzMzkhx9+YOrUqU6V4+/v36XnxVHenvQk0VizBkbClgS8D1yqlHoaWNrWyUqpW5RSG5RSG06d8u7arFJzBH7V+Z4OQwghRC83Z84c1q1bR0VFRcPI0fHjxxMWFsbKlSs7bCYdPnw4n376KcuWLWPDhg3k5OQ4df377ruPyspK5s6dy3vvvcf777/PvHnzqKio4P7772847qqrriI/P5+f/OQnLF++nOeff56f/vSnhIeHt4oH4D//+Q9r165lw4YNTsUTGBjI2WefzYcffshbb73F/PnzCQsL45e//KVT5QwfPpzVq1fzySefsGHDhmarTbiStydu9iZj0Vrrcq31DVrr27TWbfZv01o/p7VO11qne/sEghW+kQTVdn/bvhBCiNNLfWKWnp5OWFgY0DjitOn+tjzxxBMEBwezcOFCJk6cyHPPPefU9UePHs2qVasICwvj+uuv57rrriMkJIRvvvmGMWPGNIvzmWeeYe3atSxcuJDFixfz2muvNRvUALBgwQJuv/12nnrqKaZOncrEiROdiufHP/4x559/PnfccQfXX389sbGxfP31182abYEO54f729/+xpAhQ7jiiiuYOHFiq/ntXEVprd1ScGcopdKAT+oHJyilpgIPaK3Psd2/G0Br/Tdny05PT9fOZuHdaf3Dl5NcvJnEBw54OhQhhBDA7t27GTZsmKfDEG5UPwHvX/7yl3aPGz9+PP379+fdd9912bU7en0ppTZqrdNbbvf2Grf1wCClVD+llB9wFfCxh2NyC0tANBG62NNhCCGEEMLm0KFDLF68mG3btnlsiauWvGk6kDeAWUCMUiob+KPW+kWl1B3AlxjTgbyktfbO8bldpINjCVQ1VJQVExQS3vEJQgghhLDLYrHQXouiyWTCZOq47uqxxx7j1Vdf5ZprruH22293ZYid5jWJm9b66ja2fwZ81tlylVILgYUDBw7sbBHdwhxq9MErOnVcEjchhBCiC+bOndvuZL7XX389S5YsaTe5A3jkkUd45JFHXBxd13hN4uYuWuulwNL09PSbPR1Le/zD4wAozc+BfkM9HI0QQgjRcz377LOtpg1pKiYmphujca1en7j1FIERCQBUyrJXQgghRJcMGTLE0yG4jbcPTjht1K9XWlvcegZnIYQQQgiQxM1rRMQYNW51pVLjJoQQ3sKbpswSvUdXXle9Pn
FTSi1USj1XXOzdU20EhYRTof1RFbJ6ghBCeANfX99uWcJInH4qKysblupyVq9P3LTWS7XWt7RcIsMbFZnC8anM83QYQgghgLi4OI4dO0ZFRYXUvIku01pTW1tLQUEB2dnZREdHd6ocGZzgRcrMEfjVyELzQgjhDeqXg8rJyaG2ttbD0YjewMfHh4CAAFJTUwkICOhcGS6OSXRBpW8kwTVS4yaEEN4iLCysIYETwhv0+qbSnqTaP5owiyw0L4QQQgj7JHHzIpbAGCJ0Mdpq9XQoQgghhPBCkrh5ERUcjZ+yUFoitW5CCCGEaK3XJ249ZToQAHOosexVcV6OhyMRQgghhDfq9YlbT5oOxD88HoDyghMejkQIIYQQ3qjXJ249SVCkkbhVFsmyV0IIIYRoTRI3LxIabaxXWiPrlQohhBDCDkncvEhEjJG4WctPeTgSIYQQQngjSdy8iH9AEKU6EFUuk/AKIYQQojVJ3LxMsSkC3yr3LDT/zoajLP7usFvKFkIIIYT79folr5RSC4GFAwcO9HQoDinzicTfTeuV+n/5O6J1KUz7xC3lCyGEEMK9en2NW0+aDgSg0jeC4Noil5dbVl3HpJoMBtXtd3nZQgghhOgevT5x62lqA6IJtRa5vNx9+/aQoAqJ0MVYrdrl5QshhBDC/SRx8zLGeqUlWC0Wl5abt+d7AEJVJSXl5S4tWwghhBDdQxI3L6OCY/BRVkoKXTsliMrZ0PB7Ub7MEyeEEEL0RJK4eRmfMPesVxpTvL3hd1lSSwghhOiZJHHzMgH165UWuq5WLL+knCGWgxzzHwBAZfFJl5UthBBCiO4jiZuXCYpMAKCqyHW1Ygd3biBIVVOWOgeAmhJZmUEIIYToiSRx8zJhtmWvaktdl1yVHPwBgLgJFwBQVyYrMwghhBA9Ua9P3JRSC5VSzxUXF3s6FIdERBs1btYy1yVufic2U6JCiRw01dhQ4Z6VGYQQQgjhXr0+cetpE/D6+PpRRAimCtfUimmt6VO2k5zg4WD2pZgQzJWSuAkhhBA9Ua9P3HoiV65XmnPyFP31UarjxwFQagrHt6rQJWULIYQQontJ4uaFyl24XunRHd9hUprQAVMAqPAJJ6BWEjchhBCiJ5LEzQtV+UYSUlfkkrIqD68DIGnkmUbZfpEEWdzb32/d4QIqa1y78oMQQgghJHHzSrUBUYRZXZNcBedtIcfUB/+wWADq/KMIdVHZ9hwtqOCKZzN47Ycst11DCCGEOF1J4uaFrEGxhOtS6mprulSOxWIlrXI3p8JHNpYdGE2kLqG2zj01YuszjSbezUelOVYIIYRwNUncvJApJAaT0l1eU/RI5n7iVCHWpPTGjcHR+CkLRUWu6UPX0sYsI2HberRnTL8ihBBC9CSSuHkhn1Bj2avSLq4pemLXGgCiB09pLDskxig73z3rldYnbseKKskvq3bLNYQQQojTlSRuXiggwrZeacHxLpVjObqBGu1D0tDJDdv8wo2+buWFrl+vtKSqlr25pVyTko8ftWzLllo3IYQQwpV6feLW01ZOAAiJNBK3quKuNZVGFG4jy28AZr+Ahm1B4XFG2SWuT9y2HCmiHzk8dOpOFvl8ydbsIpdfQwghhDid9frEraetnAAQFtMHgLouLAZfU1ND/5r9FEeNbrY9JMpYUqvWDQvNb8gq5DyzMf3IOf67pMZNCCGEcLFen7j1ROFRcVi0Qpd3PrnK3L2BIFWNOWVSs+1htsTNWu76heY3ZRVysf8GAEZZd7P76Cm01i6/jhBCCHG6ksTNC5nMZopUGKYurClasC8DgITh05pt9wsKo0b7oFy0Fmq9OouV/CO7GWA5BGnT8dPV9K3cSU5xlUuvI4QQQpzOJHHzUiWmCPy6sF6pOraBIkJJSBvWYoeiyBSO2cXrle7NLWWmxUgWmf93tDJxhmkH244WufQ6QgghxOlMEjcvVe4bSUAn1yutrq6kb9E6sgKHoUyt/8TlpnD8ql07j9vGrE
Lmm9dREz8WEkai+4zjTNNOtko/NyGEEMJlJHHzUtV+nV+vdPN7/yWBPNSkm+3ur/CNINBFa6HWO3RgN2NNh/AddTEApn4zGWM6yL4jOS69jhBCCHE6k8TNS9UFRBOuna+tKi7MY8i+p9nhP5ZRMy+ze0yNXyTBLl5oPirrCwDU8AuMDf1nYsZKQM5arFYZoCCEEEK4giRuXsoaFEMY5dRUO9e5f9fbDxJJKQHnPWS3mRSgLiCScGuJK8IE4ERxFdNqviMvdChE9Tc2pkzGYvJjvGUrh/PLXXYtIYQQ4nQmiZuXMoUYKxwUO7E01Ykj+xmf8wYbws5i4Jgz2zxOB0UTrsqprHTNiM8de/YwwbSf2sELGzf6BlKVOJFppp1sk4l4hRBCCJeQxM1L+YYZqyeU5DneR+zoe/cAkHTpQ+0eZwo21istLujaygz1and8BEDM5CuabQ8YPJthpiMcOHTYJdcRQgghTneSuHmpINt6pWX5xxw6/uC2DCYUfcXmxCtI7Duk3WN9Qm0LzbsocUs5/hVHfNLwjRvcbLt5wGwATFmrXXIdIYQQ4nQniZuXShqaTrX2pXL3Vx0frDWVn95NiQpm2BV/6vDwgHAjKazs4lqoAJUFOQyv3cmRhLNa70wcS5U5mKSiDdRarF2+lhBCCHG6k8TNS4WGR7EreBL9Ty7HarG0e+z2b95nZPVm9gz6KeFRMR2WHRRhLDRfXdz1heZPrH0Hk9L4jLyo9U6zD4Wxk5jCDvaeKO38RQoOQ/7Bzp8vhBBC9BK9PnFTSi1USj1XXNzzJoK1DL+IOArYu2F5m8doq5Wg1X8mR8Uz7tLfOFRuaP1C86VdX/bKZ+8nHLD2YcjIiXb3+w2aTZoplwP7dzteqNZwbBOs+As8NRUeGwsvngVWqbUTQghxeuv1iZvWeqnW+pbw8HBPh+K0oTMup0r7UrLh7TaP2fn9JwywHCZ71B34BwQ5VG5YlNFUSlcXmi/Pp0/RBjICziQyxN/uIVEj5wFQt3+lQ0Xu/OIFyv4+BJ6fDav/A4FRMOISqMiHvL1di1cIIYTo4Xp94taThYRFsitkCgNOLcdSV2f3mNqMZykklNHn3uhwuWZfP0oJQlV2bdkr674vMWOlIOXsNo9RccMpMkUSfSqjw/K0pY6EtX8hu9KPB8138P7cVViu/wTm3GsccHRtl+IVQgghejpJ3LycHn4RMRSxZ13rQQonjuxjdNl37Em8mIDAYKfKLVbh+FR1LXEr2/k5uTqCxCGT2z5IKXIiJzKieiuV1faTz3o5274mWheyud/NbI0+j199ks25j37LypMh6KBoOLquS/EKIYQQPZ0kbl5u2MzLqdR+lG16p9W+zC8eByBt/h1Ol1vhE05ATWHnA7PU4Ze5klWWsUwfEtvuodZ+M4hTRRzas7Hd40o2vEW59mfG+dfw3m1n8NQ146mus3LDyxvYroagpcZNCCHEaa7DxE0p5aeU+oVSamR3BCSaCwoJZ3foVAbmfd2subSqspwhOR+wLfiMDudts6fSN5LArqxXmr2egLpSjsZMIzE8sN1D40efA0DpzmVtH2SpJfn4Mtb5TSIpLgalFOeNSmTZL2fy0xn9+awoFZV/AMrzOx+zEEII0cN1mLhprWuAvwNR7g9H2DXiYqIpZvfazxs2bf9yMZGU4jv1p50qstY/klBLUadDKtiylFptJnHcuR0eG5s6mIOmNJIOvd3myNCKvSsItZaQn7ag2XY/HxO/mDeIPb7DjA3Z0lwqhBDi9OVoU+luoL87AxFtGzbjMiq0P+W25lJttRKxYzFZpmRGTFvYwdn2WQKiiNCl6E5OsWHZ+xXr9RDOGjfIoeMPDbqBlLosTm5aand/wdo3KdGBpE66oNW+ID8fRkycSa02U7J/TafiFUIIIXoDRxO3+4H7lFKj3BmMsC8wOJTdYdMYnL+Sutoa9m1axaC6A5wY8mOUqZPdFINi8Fe1lJU531yqi7OJrdjPwfAziA
sLcOicEWf/hGM6hppvH269s66GqKNf8Y2ayPj+CXbP/9G0oezUaRTt/c7peIUQQojewtH/+v8HhACblVIHlFKrlVLfNrl948YYBaBGXkIkJezO+JzSb5+iTAcy4txbOl2eKcRYYaEkz/llr3I3GrVm4WPOd/icPtFhrIi8nOSSzVizmg8ysB74miBrGUf7nIuP2f5LMikikIKoccSW7qS8otLpmIUQQojewNHEzQLsAlYDR4E627b6m0xp72bDp19MuQ6g7odnGF28kp1x5xMSFtnp8vzCjJGgZYUnnD63bMdnZOsYpk46w6nzoqbfRKEOoXDZP5ttL97wFkU6mMTx89s9P23cbAKpYdW3K5yOWQghhOgNHErctNaztNaz27u5O9DTXUBQCLvDz2Rcxff4qToS5t3ZtfLCjcStysmF5nVtFUkF69gVMpVYB5tJ680d3Y+31DlEZy+HU/uMjbVVBB3+ii+tE5kxNKnd8/uPNV5mhzatwGrVTl1bCCGE6A1kHrcexDzyYgC2+4+n75CxXSorJNJY9qq6xLllr7I2LyeQKvyGtl87Zk+Ar5mCkYuo1H7UrH7E2HhgGf6WcnZHzyO6jWWzGoQnURGYSFrlDlbuPen09YUQQoiezuHETSmVqJT6t1JqvVLqoFJqnVLqn0op+73JhcsNn3EJm4PPxHfePV0uKzQ6EQBrmXOJW96mpVRrX0ZPX9DxwXacP3k0b1tmYt7+FpTkUL3lXfJ0GDG2NU07EtBvChPNB3jpu8Odur4QQgjRkzmUuCmlBgNbgJ8DZcA6oBz4BbBFKeXYnBCiS/wDghj3208ZOtGxJKc9oWGR1GozusLxxE1rTVzut+wJHENURESnrjs6OZyvIy8HbYXV/8F84Eu+sExk1rA+Dp1vSp1CAnkcOrCPPSdKOhWDEEII0VM5WuP2D6AEGGzr03a1rV/bYKDYtl/0IMpkoliFYnZiofl9u7aRqnOoG3BW56+rFNMnTuRTy2RY/wI+lkq+C5jBiD5hjhWQMgmAyX4HWLwms9NxCCGEED2Ro4nbbOA+rXVm041a6yzgAdt+0cOUmiLwq3Y8ccta+yEAg6Zd0qXrXjQuieetxsTBJ3UkEUNnopRy7OSEUeATyOVxOXyw5Ri5JVVdikUIIYToSRxN3PyA0jb2ldr2ix6mwiecgNoih47VWhOWvZLjPimE9RncpevGhvoTP3gyiy3zebzuQmYNS3T8ZLMvJE0g3bQfq1XzxIoDXYpFCCGE6EkcTdy2AHcqpZodr4xqkttt+0UPU+UXSZCDC81vP3yccZYdlKa6pnL18vRkHqz9MW8xnzMHxjh3csok/PN2cM2EWN5cf4SjBRUuiUkIIYTwdo4mbn8C5gG7lVJ/UkrdppR6ENgJnAU86K4Au0optVAp9VxxsfNLO/V2tQGRhFkde15yti7DX9USP6Fza6O2NGdoHDEhfkwZEE2wv49zJ6dMBmsdPx9ailKKx77e75KYhBBCCG/n6AS8XwALMJpF7wGeBO7FGGG6QGv9ldsi7CKt9VKt9S3h4eGeDsXr6IBownQ5lrrajg/O3ogFE+GDznTJtX3NJt68ZSr/vHS08ycnTwQgumAL107uy3ubsjl0qswlcQkhhBDerMPETSnlq5S6ENirtU4HQoEUIFRrPUlr/aW7gxTuoYJjMClNSUHHk9lGFu/kuG8q+AW57PoD40JICHdu9QUAgqMhehAcXcftswfg72Pm4eVS6yaEEKL36zBx01rXAm8Dabb7FVrrY1pr6VjUw/mE2haaL2h/2avK6jr61+6nOHJEd4TlmJTJcPQHYkzl3DAtjaVbc9h9XOZ1E0II0bs52sftEBDnzkBE9/MLM/6klUXtLzR/8NA+YlUxpqTx3RGWY0ZfDtVl8OxMbh9YRGiAD/9dts/TUQkhhBBu5Wji9k/gHqVUrDuDEd0rKMJI3KqKT7V7XN6+tQBED5rs9pgc1n8W/MRopQ/53wKe7P8Dy3adYO
vRIo+GJYQQQriTo8P55gBRwGGl1A/AcUA32a+11te7OjjhXiFRxkLztaXtJ27WY5upw0TswAndEZbjkifArd/Chz9jxt7/8lLAJJ76IpBnb57r6ciEEEIIt3A0cZsO1AKngAG2W1O61RnC64VHJQAdLzQfXriDY7596evCgQkuExgJV70OGU8ya9kfGZ29iKL/JhAR4AtoY01UnwC44mWITPN0tEIIIUSXOJS4aa3T3ByH8IDAoCDKdCCqMr/NY2rrLKTV7Odo7Ez6dmNsTlEKzrgDa9JEDrz5Z0qKKpmYFk1ksB9Y6mDf53BoFUxY5OlIhRBCiC5xZDoQP6XUJqXU2d0RkOhexaYwfNpZaD7z0D6iVQmqz9juC6qTfPpOZtidH/D3sPuYe+xmjp71HFz1P/ANhpO7PR2eEEII0WWOTAdSA/QD6twfjuhuZeZw/GoK29x/at8PAEQNmtRdIXVJeJAvL1yfjsWqufHl9ZTWWCBuqCRuQgghegVHR5UuA6TGrReq9IkgoK7tZa8s2Zuo1Wb6DJnYjVF1Tf/YEJ6+ZjyHTpVz5xubscZK4iaEEKJ3cDRxexy4Win1b6XUmUqpAUqp/k1v7gxSuE+1XyRhlrZr3EILdnDUty9mv8BujKrrzhgYw4MXjmDV3lMsz4uG8pNQ3v4gDCGEEMLbOZq4fYOxzNWvbL/vA/a3uIkeqCJqGHE6n4LDW1rts1qspFXvoyDMi1ZMcMI1k/tyw7Q0Xj1kGw0rtW5CCCF6OEenA7nBrVEIj+k7+wZqXnqcYyueI+rGp5rty8naR7IqQ/eAgQltuWP2QM79LsW4c3I39Jvu2YCEEEKILnB0OpCX3R2I8Iz+fdP4LmAqo7I/Rtf+F+XbuOh77t4MkoHIgT1jYII90SH+WIPjqSCUoFNS4yaEEKJnc7SpFACllEkpNVIpNVMpFeyuoET3qh19HWG6lMzv3m6+/egmarSZlKHpHorMNYYmhnFYpUpTqRBCiB7P4cRNKfUz4ASwDVgBDLFt/1Ap9XP3hCe6w4Q5F3NMx1K3oXnFakjBDjJ9+uEf4IUrJjhhcHwo22oS0Sd3gZZFPoQQQvRcDiVuSqmbgUeBD4ErANVk92rgUpdHJrpNaKA/2+MuYFDZBipPHgRAW62kVu0lP2yYh6PruqEJoeyyJKOqiqH0uKfDEUIIITrN0Rq3XwH/0VrfAnzQYt8ebLVvoudKmPkTLFqRuexZAPKz9xFGOdaEsZ4NzAUGJ4Syz1o/QGGXZ4MRQgghusDRxK0f8GUb+8qBCJdEIzxmzIgRrPMZT9zB98BSx4k9GQBEDJzs4ci6bnB8CPt0knHn5B7PBiOEEEJ0gaOJWx6Q1sa+IcAxl0QjPEYpReHQq4i25pG7+VOqj2ykWvvQd9gET4fWZUF+PoRGJVBijpQBCkIIIXo0RxO3pcD9LVZI0EqpGOCXGH3fRA83fu7V5OkwSr57geC87Rwy9yMkqGcPTKg3OD6UA6RKU6kQQogezdHE7V6gGtgBLAc08BiwG7AAf3JLdKJbJUSFsi58Pv0K15BatYe80J4/MKHe0IRQttYkok/tAavV0+EIIYQQneJQ4qa1zgfSgb8BvsBBjMl7nwCmaq3bXqVc9CjBU27ABytBVFHXCwYm1BucEMoeawqqtgKKsjwdjhBCCNEpDs/jprUu1Vr/WWt9ptZ6sNZ6qtb6Qa11iTsDFN1r6qQpbGQ4AGH9J3o4GtcZmhDKfqttgMIpGaAghBCiZ3Jq5QTR+/n5mNg96Kessowhbeh4T4fjMv1igjlskilBhBBC9GyOLjIvTiOXXH4tu3IuIDo8xNOhuIyv2UR8bBz5pXFEy8hSIYQQPZTUuIlWgvx8SE+L8nQYLjc4PpS9OkWmBBFCCNFjSeImThtDEkLZVtMHnbcPLLWeDkcIIYRwWo9M3JRS/ZVSLyql3vV0LKLnGBIfyj5rEspSAwWHPB2OEEII4TRHF5lfoZQa2sa+wUqpFY5eUC
n1klLqpFJqR4vt85VSe5VSB5RSv2+vDK31Ia31jY5eUwgwatz26voBCtJcKoQQoudxtMZtFhDWxr5QYKYT11wCzG+6QSllBp4EzgWGA1crpYYrpUYppT5pcYtz4lpCNEiKCOS4byoaJYmbEEKIHsmZUaW6je0DgDKHC9H6W6VUWovNk4ADWutDAEqpN4ELtdZ/AxY4EWMzSqlbgFsAUlNTO1uM6CVMJkVqfDQnCvqQKFOCCCGE6IHaTNyUUjcAN9juauA5pVRpi8MCgZHA112MIwk42uR+NjC5ndiigYeAcUqpu20JXita6+eA5wDS09PbSjzFaWRoQii7TyWRcHI3ytPBCCGEEE5qr8bNirEOKYBqcb9ePvA08I8uxmHvf2ibiZZtCa5bu3hNcRoaHB/Kjs19mF2wAWqrwDfA0yEJIYQQDmszcdNavwy8DKCUWgncprV211pB2UBKk/vJQI6briVOY0MSQnnDmoLSVsjfDwmjPB2SEEII4TBHF5mf7cakDWA9MEgp1U8p5QdcBXzsxuuJ01SzkaW5Oz0bjBBCCOEkhwcnKKXCgPOAVKBl+5LWWv/ZwXLewBilGqOUygb+qLV+USl1B/AlYAZe0lq75L+qUmohsHDgwIGuKE70cDEh/hQF9qVchRN8cAWMucrTIQkhhBAOcyhxU0pNA5YCEW0cogGHEjet9dVtbP8M+MyRMpyhtV4KLE1PT7/Z1WWLnmlQYgRr89KZs+9LYwUFs6+nQxJCCCEc4ug8bo8AmcBEIEBrbWpxM7srQCFcbXB8KB9WjoGqIjiS4elwhBBCCIc5mrgNA+7VWm/UWte4MyAh3G1oQijLa0aizf6w9/O2D9QaXrsMXr4A8g50X4BCCCFEGxxN3I4A/u4MRIjuMjghlAoCyIudAns+NRI0e3I2wYFlkLkanj4D1jwMlrruDVYIIYRowtHE7UHg97YBCkL0aEMTQvHzMbHWbwoUZUFbqyhsXAK+QXD7Whh8Nix/AF6YA8e3dWe4QgghRANHE7cFQDxw2LZe6Cstbi+7McYuUUotVEo9V1xc7OlQhJcI8vNhxqAYnssdbGzYa2dMTFUJbH8PRl4CsYPhytfgileg5Dg8Nwu+/jPUSa8BIYQQ3cvRxO1MjJGjJcAIYLqdm1fSWi/VWt8SHh7u6VCEF5k/MpFtxYFUxI6FPXYSt+3vQG05T5WeyR8+2E5ljQWGXwg/Wwujr4TV/4YX58Gpvd0euxBCiNOXoxPw9uvg1t/dgQrhSvOGxeFjUqwPmGr0ZSs53rhTa9i4mOro4fxzRyj/W3uES57+nqz8cgiKgoufhitfh+JseHYG/PAMWK2eezBCCCFOG47WuAnRq0QE+TF1QDRL8oYbG/Y1GV16bBOc2M7KkPPxMZn49+VjOFZYwcLH17Byz0njmGEL4LYM6DcTvvg/eO0SKJFV2oQQQriXw4mbUipYKfVzpdS7SqmVSqlBtu1XKaWGui9EIdxj/sgEVhZGURPWt3lz6cbFaN8g/pY9krnD4rhsQjKf3Dmd5MggfvLyeh5etg+rVUNoPPzoLVjwCBxdC69f7rHHIoQQ4vTgUOKmlEoBtgH/AgYBM4BQ2+7ZwG/cEp0QbnTW8HiUUuwImQaHv4HqMqgqhh3vkZ10Hlnlvlw1MRWA1Ogg3rvtDC4el8SjX+/n7ve3G4UoBek3wNSfGaNTZboQIYQQbuRojdt/gGqMpG0CoJrs+wYjkfNKMqpUtCUuNICJfaN4o2QkWGrg4Ne2QQkVvFQ1k4SwAGYMjm04PtDPzH8uH8P1U/vyzsajHCuqbCwsrA9oK5Sf9MAjEUIIcbpwNHE7C2Mx+CMYo0ubOgYkuTQqF5JRpaI980cm8H5eChb/SKO5dMMSamNH8nJWFJenJ2M2qWbHK6W4eYYxFueNtUcad4TZ3gK9oJ/bpiOFXP/SOsqrpfZQCCG8jaOJmx9Q2sa+cKDWNeEI0b3OGZmABTMHIs6Ane9D7nbWhC3Aqh
VXpKfYPSc5Mog5Q+N5c/0Raupso0lDE42fJce6KXL3KK+u4643t/DNvlPsOl7i6XCEEEK04Gjitg24tI195wIbXROOEN0rKSKQMcnhfFAxBiw1xqCEY6OYNjCalKigNs+7bmpf8spq+HyHbRqRhhq3422e0xP87fPdHCmoACAzr9zD0QghhGjJ0cTtX8CNSqnnaezPNlwp9SBwo22/ED3S/JGJvHJqIFbfIHL7LmRfkeJK26CEtkwfGEPf6CBe+yHL2BAUBWb/Hl3j9u2+U7z2wxEWnZGG2aTIyq/wdEhCCCFacHQC3veB24HLgeW2za8AdwF3aK2/cEt0QnSD+SMTqCCA9ye+wT+5noggX84eHt/uOSaT4trJfVmfWcju4yXG6NKwRCjtmTVuxZW1/N972xgQG8zvzx1KSmQgh/Olxk0IIbyNw/O4aa2fwRiEcA5wLUYTabLW+jk3xSZEt+gXE8zQhFBe3G3mk90lXDQ2iQBfc4fnXZ6ejL+PiVfra93Cknrs4IQ/f7KLk6XV/OeKsQT4mukbHWysFCGEEMKrOLVygta6XGu9XGv9P631l1rrtgYsCNGjzB+ZwO7jJdRYrFw50f6ghJYigvy4YEwfPtx8jJKqWmOAQg9sKl22K5d3N2Zz28wBjE2JACAtOoisvAq0bjmIXAghhCe1mbgppWYopUKa/N7urftCdo7M4yYcMX9kAgBjUiIYlhjm8HnXTe1LRY2FDzYdM+ZyKzlurHXaQxSU13D3+9sZmhDKz+cOatieFhNMaXUd+eU1HoxOCCFESz7t7FsFTAHW2X5v67+Rsu3ruG3JA7TWS4Gl6enpN3s6FuG9hsSHcv3Uvswd1n7ftpZGJ0cwJjmcV3/I4sdTE1GWaqgogOBoN0XqWou/O0xhRQ2v/GQSfj6N3+PSooMByMovJybE31PhCSGEaKG9xG02sKvJ70L0WkopHrxwZKfOvW5qGr95Zyt7K8MYClCa02MSt5W7T3BGciDD+zSvZUyLMRK3zLwKJvSN8kRoQggh7GgzcdNaf2PvdyFEcwtGJ/KXT3fx0SGMxK0kBxJGeTqsDp0sqWLyybf4XcCHUJgBkX0b9iVFBGI2KTJlgIIQQngVpwYnCCFaC/A1c96oRL46ans79ZABCqv2neJi83f4W8rhs98065vn52MiKSKQTJnLTQghvEqbNW5KqRVOlKO11nNdEI8QPVL/mGDeqg5BB5pQPWT1hB07tnGFKRMdNxy1/yvY9RGMuKhhf1pMsKyeIIQQXqa9GjcTxsCD+ttQYBaQBgTafs4Chtj2C3HaSooIxIKZusBY987ltvVNyD/Y5WLqLFZCM415s9UVr0LiGPj8/6CqcfR1WnQQmfnlMiWIEEJ4kTYTN631LK31bK31bOBRjIXkp2it+2utp2qt+wNTbdsf7Z5whfBOyZHGuqblAfHG4AR3OLEdPvgprH22y0VtPlrELOsPlIQPhZiBsPBRKD8JX/+p4Zi+0cGUVtVRWFHreMGWOnjzGvj0N1B0tMtxCiGEaM7RPm5/Bu7TWq9rulFrvRZ4APiLi+NyGZnHTXSH5MhAAArNbqxx++EZ42fh4S4XtX77Liao/fiNutDY0GccTPoprH8Rjq4HoF+MkYwedqa5dM9S2PMJrH8BHhsHH98JBV2PVwghhMHRxG0QcKqNfSeBga4Jx/W01ku11reEh4d7OhTRi0UE+RLkZ+Ykke5J3MpOwva3jd8LM7tcnHX3p5iUJmDURY0b59xjrP7wyV1gqaVvk7ncHJbxJET2g19shQmLYOtb8PgE+OBWY3JiIYQQXeJo4nYY+Gkb+34KZLokGiF6KKUUyZGBZFsioboEql28Gtz6F8FSA0MXQGEWWK2dLupkSRWjS1dTGNgX4oY17vAPhfP+Bbk74IenSIkMwqRwfGTpkbWQvR6m/syYWuT8fxsJ3ORbYcf78NW9nY5ZCCGEwdHE7UFgoVJqh1LqAaXUbbafO4DzMZpLhTitJU
UEcrDKNpGtK2uXaqtgw4sw6BwYOBcs1V3qR/f9jgNMNe3CMmQBqBbjioYtgCHnw8q/4VddQFJkoOMjSzMeh4AIGPujxm1hiTD/rzD4bDi+pdMxCyGEO3277xTvbsz2dBgOcShx01q/CZwDFAN3A0/afhYB52it33JXgEL0FMmRQeypCDXuuHKAwo53ofwUTL3daIaELvUbK966FF9lITr9UvsHTP811FXCoVWkRQc71lRacAh2fwLpPwG/4Nb740cao2FrZF44IYT3eSUjk0e/3ufpMBzi8AS8WuvlWutpGFOBJACBWusztdZfuy06IXqQ5MhADlTb+lK6qp+b1pDxlJH49JsJUbbErZP93OosVlJyl1PkE4tKGm//oMQx4B8Gh78lLTqYw3kOTAnyw9Ng8oFJtzTbfLy4ko+2HON/maGAhlO7OxW3EEK4U3m1hYKyGk+H4ZD21iq1S2ttxRiQIIRoIikykFwdadzp4uoJWmu+3JnL8OrNpJ7cCRc+aTRrhiUbCVInR5ZuPXSMM/RWjve9ioiWzaT1zD7QdxpkrqbvuLsoqaqjqKKWyGA/+8dXFMDm12DU5RCWyOr9p/hoSw5rD+dztKASgFQVyI/8wXJ8B+akCZ2KXQgh3KW8po7yGgtVtRYCfM2eDqddDiduSik/4FyMCXcDWuzWWus/uzIwIXqa5MggqvGjxi8Cvy72cXt+9SH++tkeXvD9FyE+4bySO5pzjpcwNCEUFZ7S6abS7PUfM0HVEjvpsvYP7Dcd9n3OkMASADLzy9tO3DYuhtoKmPoz8suquf6ldYQF+jIpLYpFZ/Rjcr8ojuSXUf6eP6f2rCct/fpOxS6EEO5yTukHXOxzlPzyOSRFBHo6nHY5lLgppfoAazBWS9A0rpTQtP1EEjdxWqt/s5f5xRHVhabSb/ad4u+f7+HHg2qZd3Qz7wZfy2PfHOWRVUfpHxPMR+EphHayxi0i8wuKTeGED5re/oH9ZgAwuGIzEEdmfjnjUiNbH1dXA2ufg/6zIWEkKzYcxarhtRsnMzKpcQqeYYlh7P6gLxzZ2qm4hRDCncbWbKKfOZP8shqvT9wc7eP2L4x53FIxkrbJQH/gIeCA7XchTmsxIX74+5jIN8d0uqk0M6+cO/+3icHxodwX+y2Y/bjs1j+y7p55/OWikeSVVbO1PLJTNW4nC4qYUL2O7Pg5YOqgKSBuBARGEZO3FqUgM6+NQQU73oWyE3DGHQAs351LQlgAI/qENTvMbFKYE0eRVH2IHdlFTscuhBDuFGCtJJoS8suqPB1KhxxN3KYD/wHqqxGsWutMrfX9wLvAY+4IToiepH4utxM6Ckqdbyotq67j5lc2YDYpFi8Ix3f7G0a/sZA4YkL8uXZKXyb3j2Z7RRRUFUFloVPl7834hBBVRcjYizs+2GSCtDMxZ62hT1iA/ZGlWsP3T0DccBgwl6paC9/uy2Pe8DiUnf5zaSMmEanKeO+bDU7FLYQQ7qS1xl9X4q/qKCl27nPVExxN3KKBHNvAhHKgaZvJCozF5r2SLHklulNSZBBHasON6Tvqqh0+z2rV/PKtLRzKK+fFBZEkfngF+IXAjN82O25sSgRbyiKMO06MLK2zWKne9j5lBJE6Yb5jJ/WbAcVHmRRZymF7k/BmroGTO2HKbaAUGYfyqay1MHdYvN3iApNHA3B09zpOlTr+3AghhDtV1VoJxhhIVVmU6+FoOuZo4pYNxNh+Pwic3WTfJMBr6xZlySvRnZIjAzlQPwlv6QmHz3v06/0s25XLP2aHMH7lj8FaB9cvbZz+w2ZsSgRZ2pYYOdFcuuSbXUyp+o7CtHNRPv6OnWTr5zbDZ5f9GrcNL0FAOIw0Bjos35VLkJ+Zqf2j7ZcXNxyAgTqL/6094nDsQgjhTuU1dYQoI42pKfH+STMcTdxWAjNtvz8L/EYp9ZVS6lOMQQnvuiM4IXqapIhADlbXr57g2ACFXTklPPr1fm4eZebSHbdBXRVc/zHEDW
117KjkcI4SZ9xxcIDCoVNl7Fr5JiGqipSZNzh0DgAxgyEknlG12yiqqKWooskcR2UnYfdSGHsN+AWhtebr3SeZMSi27aH0gREQnsLM8JO8+kMW1XUWx2MRQgg3Ka+uI9hW/2QpbWtZdu/haOJ2L/A0gNb6aeAXQBCQCPwT+LVbohOih0mODOS4ttU4OThAYcWeXPqofH6f+1tUdRn8+COIH2H32LAAXxJjYyg2RTrUVGq1an7/3nYuMa/GEppszM/mKKUgbTopxRsA3XzN0s2vgrXWWCkB2JlTwomSKuYOi2u/zPgRjPLNJq+smk+3yaLzQgjPK6usJkgZ3TdURe9J3GqBrPo7WuvHbasmjNda/0Fr7bVNpUJ0p+TIIGNwAjg8QGHD3kzeDfwr5upiuO4DSBzd7vFjkiPItMaiHWgqfX3dEQ5nHmIa2zCPvdIYdOCMftPxrzrFAJXT2FxqtcCGJUZTaswgAJbtykUpmDO048QtuPQQQ2L8WPxdZscrMgghhJtVVZQ2/G6uzPdgJI7p8FNcKeUD5NO8X5sQwo7kyEBKCaTWHOhQU2lZdR2xx5bTx3ocrnwF2lqGqomxKeEctMRhyT/U7nHHiir5+2e7uSthKworjL7K4cfRwNbP7QzTLg7XLzZ/YDkUH2mobQNjGpAJqZFEh3TQfy5+BMpax51jYPuxYjZmef8ILiFE71ZV3jh40a/a+z+TOkzctNZ1QC4gHVKE6EBsiD9+ZjMlvnEOJW5rD+VzptpGTUAMpM1w6BpjUiI4ouMwl+a0OXJVa809H2xHA5f7fgd9xkHsYGceiiGyH4QlM9t/D1n1TaUbXoKQeBi6ADDWI92ZU9LmaNJm4kcCcFb0KcICfHglI6uDE4QQwr1qmtS4BdUWeDASxzjabvIacJM7AxGiNzCZFEmRgeSpKIcStzX7TjLdtB3zwDkON2MOTQjjmEpAoaHoqN1jPtxyjFV7T/HXaWb8Tu3oXG0bGP3c+s1gIjvJyiuFoiOw70sY/2Mw+wKwfLcxCuus4R00kwJEDQCzP/55uzljQAy7jpd0Li4hhHCRuorGz6Ewa7HXD5xyNHHLBCYqpdYrpe5VSt2olPpJ05sbYxSiR0mKCOSY1bHELWfvWqJUKeZB8xwu38/HhE+0bZoQOyNLtdY8unw/Y5LDuUCtBmWGkZc6XH4r/aYTai3BJ28PbFxiJHPjG9cb/Xp3LmnRQQyIDem4LLMPxA6B3J0kRgRwvKhS+rkJITyqttKocasxBxGtSigsr/VwRO1zdJH5J20/k4AJdvZr4CWXRCRED5ccGUjmsXCoPWF05G9jeanswgoGFK8DX2DAbKeuEZ0yFIrAkn8Q86Czmu3bdKSIzPwK/n3ZSEzfvAMD50FIbCcfDZBmrGs6tmYj1k1fYxp0DkSkAMYw+u8P5HPd1L52V0uwK34kHFxBn76BlNdYKKmqIzzQt/PxCSFEF1irjRq3yuAUoopPkV9eTUJ4gIejapujNW79OrjJWqVC2CRFBHKoJtyYRLe87aHla/bnMcO8jaqYERDiQDNjEwP79adc+1N0bF+rfe9tyibQ18z5oQegNAfGXOn0Y2gmIoWK4FRu9/kYU/lJasYtati1ev8paixW5jnSv61e/AgoO0HfQKPP3PHiyq7FJ4QQXWCtMgZeWcL7EkUJBWXevbKLQ4mb1jqro5u7AxWip0iOCmycEqSd5tJ1e48wwbQf/8GON5PWG5sayREdR2XuwWbbq2otfLI1h/kjEwjc9S74h8GQ85wuvyWfgTOJUOUctcYy+W34xxd7OF5cyfLdJwkP9CU9LbLjQurZ5qhLqzM+No4XyWxCQgjP0dVGU6mKSrOtV+rdAxScnNQJlFKmFjcH20eEOD0Yc7nZEpk2EjeLVVN38Ft8qUMNnOv0NfpGB3HclIBPcfPvTCv2nKSkqo7LRkfC7o9h+IXgG+h0+S35DZwFgJ6wiEn9Y3j2m4Oc+Y+VfLw1h1lDYvE1O/FRYh
tZGl91AIAcqXETQniQqikDwD92AAAVhd69Xmmbn7ZKqQSl1KdKqR832WbGmIy36a1IKeVEO4kQvVtSRCAnGlZPsJ+47ThWzIS6TdSZAyF1itPXUEpRHZpKRPUxaNK5//1N2SSEBTC15geoKYMxnRxN2tLQ82HWH0g95+c8e1063/x2Njee2Y+YYD+uSE9xrqyQWAiOI7R4L2aTclmNm8WqeW9jNnUWq0vKE0KcHky1RuIWEGsM+vL29Urb+5p8OzAeeKfFdgW8APwJY53SHOBWt0TnAkqphUqp54qLizs+WAgXiA8LoNgUhkX5GH3M7Fi9/xQzTNux9p0Gji763oJfzAACqKEiPxuAvLJqVu09xUXjkjBteR3CUyD1jE4/jmZ8A2HW/0GAsQ5rSlQQfzhvGN/fPZdpA2OcLy9+BKbcncSH+rusxm3NgTx+/c5W1hzIc0l5QojTg6m2gkoVgCnU6Gvs7euVtpe4zQee11q3/FTVwLNa6we11g8ATwBd70TjJlrrpVrrW8LDwz0dijhNmE2KhIggiszRbda47d69g36mE/gNPsvufkdEpw4BIHP/DgA+3pJDnVVzVWoxHP7GWNnA2SWuukv8CDi1h6RwP5fVuB06ZXxrPl4sfeaEEI7zqSunWgVBkO1LaDuDyrxBe5/qQ4Dv7Wxv2adtn+1YIYRNckQQJ5X9xK2suo7I46uNOwPmdPoaqQONTv65R/YC8P7mbEYlhZO290XwDYb0GzpdttvFj4S6KkYH5btsVGn9yg4nJHETQjjB11JOtTkIgo3EzafKu9crbS9xCwDKmm7QWluARGBrk81VtmOFEDbJkYFkWyLsJm4/HDSWuaoO7tOwSHtnRCUOwIKJytwD7D1Ryo5jJVw33Aw73oMJ10OgEyM9u5ttZOlI81GOF1e5ZBLezHxjSH9uiSRuQgjH+VkqqDEHgW8gVSoQv+qeO6r0JHbmZ9Na59oSuHr9AO+uVxSimyVFBpJZE4Euab2e6Hf7jjPNtAOfQXONVQg6y8ePQp84fIqyeH9zNj4mxYLKj43BCpO9ttupIXYIKDP9rZlU11kpKK/pcpGZeUbidkISNyGEE/ytlVh8ggCo8I0ksLbIswF1oL3EbQ1wnQNl/Bj4zjXhCNE7JEcGsdY6DFVXCa9dCpVFDfvy9mYQqiqdWuaqLdWhqcTW5fDW+qOcOyiQoG2vwoiLILJvl8t2Kx9/iBlMYpUxD11X+6XVWqxkFxpNrtJUKoRwRoCupM4nGIBq/yhCLUXUevHo9PYSt8eAOUqpfyulWi2NpZTyUUr9F5gFPOqm+ITokZIiAlluncDeqf+GIxnw0nwozia7sIKBpWuxYoL+M7t8Hd/Y/qSqXIoqavlZ2HdQUwpn3OmCR9ANEkYSUWqs/JBT1LV+bscKK6mzaoL9zNJUKoRwWE2dlSBdidXXSNwsAVHGeqUVXW8FcJc2EzetdQbwO+CXQLZS6lWl1EO226tANvBz4G7bsUIIm+RIY9LbLVHnwLXvQckx9AvzeOndpcwwbacmfqxL+qBFJg0hWpWSElDJkMzXjHVF+4zrcrndIn4kfmXHCKOsy4lbff+2CWlRFFbUUlVr6eAMIYQw1lsOUVVovxAAdFAs0arEJd033KXduQK01v8B5gFbgEuBu223S23bztZa/8u9IQrR8ySGB2A2KaP5rv8sLIs+p6iyll9m/4IxpkMEDD3bJdfxjTG6oT6StBJVehzO+LlLyu0WCcYKCqN9j3a5qbR+ROmU/sZSYydLvHutQSGEdyivqSOIKrAlbqaQGGO90lLv/QzpcJInrfVKrfV8IBRIwBhVGqq1nq+1XuHuAIXoiXzMJhLCAsgurMRq1fx2dR3nlv2RmpAkTFihE8tc2RVlzPQ94fibEDsUBnV+XrhulzAagEmBOeR0MXE7nFdOkJ+ZUUnGfI0yQEEI4YjyqlpCVBXK30jc/MLi8FMWir14vdJWfdfaYhtJ6t3rQAjhRZIiA8kurO
CeD7fz/qZj/PqsqURPuwiObYCUSa65SGSa8dNaZ/Rt60lLB4fEQXAco6xH+baLTaVZ+eX0jQ4mIcyYmUgSNyGEIyrKSwAwBYQCEBBhrOBZUXAcb52i1uHETQjhnOSIQN7ffIz1mYXcMXsgd861zdnWhUl3WwkIh8AoMPvBqMtdV253SRjJoOzDLmkqHZoYSny4kbjlyshSIYQDasqN5TB9Ao3l/IIiE4ztJd47y5kkbkK4SXKUMS/QzdP78euzB7vvQrPuhvCkTq956lEJo0g4tJq86jIsVo3Z5HyNYZ3FypGCCs4ZmUCovw9BfmapcRNCOKS6wqhx87HVuJlDYgGwlEniJsRp59rJqQyIDeaCMX1Q7mzCnHyL+8p2t/hR+Oha+uocTpVWkxDu/CIsOUVV1Fk1adFBKKVICAuQudyEEA6prSgFwC/ISNzql73y5vVKvXQFaiF6vriwAC4cm+TepK2ns40sHa6yyOnkmqWHbVOBpEUb8zDFhwVIjZsQwiF1VUaNm3+wMbCpfqF5c5X3Dk6QxE0I4TnRg7Ca/RlmyuJ4UeeSraz6xC3GSNwSwqXGTQjhmLpKo8bNP8jo44ZvAJUqCL9q711oXhI3IYTnmH2wxgxluMriuDM1boVZ8NW9UF1GZl4Fgb5m4kKNPn7xYQGcLK3Cau36wvVCiN5NV5cB4B8c1rCtwjeCwNpCT4XUIUnchBAeZU4cxXDTEXIKnUjclt0H3z8OXz9IZn45fW392wASwvyptWgKvHjJGiGEd9DVRo2b8g9t2FbtF0WopRiLl375k8RNCOFRKnE00aqE8oJjjp1wcg/s+hhCE2Hdc4Tlrm3o3wY0DHCQ5lIhREdUjVHjVr9yAkBdQDTRlFDkpV/+en3ippRaqJR6rri42NOhCCHsiTcGKAQV7HLs+NX/Bt8guHEZOrIfv6p4lEGRjR9lCeHGOrGy2LwQoiOqxugji1/jlz8dHE2UKiHfS9cr7fWJm9Z6qdb6lvDwcE+HIoSwJ34EANFl+zo+Nv8g7HgPJt4IESmcmv1vUtVJFuS90HCIrJ4ghHCUqbacKvzBZG7YZg6JI4oS8r10vdJen7gJIbxcYATF/n1IrT1ETZ21/WNX/wfM/sbyXsDewDEsqTubwZmvQ9b3AMSE+GFSsnqCEKJjPnXlVJmCmm3zDYvFT1koKfLOkaWSuAkhPK40YgjDVVb7zZuFmbD1TZiwyFjnFMjMK+efdVdhCUuBj34GNRX4mE3EhvpLjZsQokM+dRVUt0jcAm3rlVYWnfBESB2SxE0I4XGW2JH0U8c5kV/U9kFrHjaaM6b9vGFTZn4FVt8gzBc9DgWHYOVDgNFceqLEO5s5hBDew99aTq05sNm2+vVKq0tOeiKkDkniJoTwOL/k0ZiVpiJ7m/0DirNh8+sw7joI69OwOTOvnLToYFT/WZD+E8h4Ek7uIT4sQJpKhRAd8rNWUusT3GybT/16paXeueyVJG5CCI8LTxsPgD6+3f4B3z0KaDjzrmab6+dwA+CMnxvHHMkwVk+QplIhRAcCrJVYWiRuBBuJmy7P80BEHZPETQjhcUFx/SklkEB7U4KUnoCNL8OYqyEitWGzxao5WlDZsNQVEX2NuZhO7iI+LIDiyloqayzd9AiEED2NxaoJ1JVYfFsmbsZ6pT6VkrgJIYR9JhNZ5n5ElraYEsRqhc9+A9Y6mP6rZrtyiiqpsVgbJ981mSBuGOTukilBhBAdKq+pI0RVNpvDDQAff9t6pd657JUkbkIIr5AbNJDk6oNGslZv5V9g91I4608Q1b/Z8Vn5FQDNVk0gbjjk7iAhzFi3VFZPEEK0paLaQhDV6CarJjTs840ksLbAA1F1TBI3IYRXKA4fRhCVUJRlbNj6pjFv2/jrYerPWh2fmW/MeJ4W02Qof/xIqCqij7kIkNUThBBtK6uqJZiqZuuU1qv2iyTEUoTVC9crlcRNCOEV6mKHA1BzbCtkZcDHd0LadDj/P2BbQL6pzL
xyAnxNxIcGNG6MN8pIqDoESFOpEKJtleUlmJTG5N+6xq02IIYoSimurPVAZO2TxE0I4RV8E0di0YrqHUvhrWsgPAWueAXMvnaPz8yvoG9UMCZTk6QuzkjcAgv3EOLvI02lQog2VVWUAGAOaF3jpoOiiVbFXrleqSRuQgivEB8dwSHdh9C974LVAj96G4Ki2jy+2VQg9YKiILQP5O4kPsxfmkqFEG2qKTcSN9/AsFb7zCExRFJKQZn3TeQtiZsQwiv0CQ9kh07DqnyMmraYgW0ea7FqjuRX0C8muPXO+OHGyFKZy00I0Y4aW42bT1DrxM03LN62Xqn3TQkiiZsQwiskhAfwj9qreHfMC9B/ZrvHniiposZipW+0ncQtbjjk7SUx1EdWTxBCtKmu0kjcAoJbN5UG2NYrrfDC9UolcRNCeIUAXzO1wYlstrZd01YvM8/OiNJ68SPBUsMw35OcLK32ylFhQgjPq6sqBcA/OLzVvqBII3GrKfK+9UolcRNCeI3EiACOF1e2e4zVqnl+9SH8zCaGxLf+plw/snQQWdRZNXnl3tdHRQjhedaqMgACglonbn5hcQDUlXnfeqWSuAkhvEZieCBH8ivarSV7+puDrNp7ivsXDic6xL/1ATGDQZlJqskEILdYEjchhB3VRo2byc6oUoKMZa/wwvVKJXETQniNGYNiOJRXzu/e24bFTvKWcTCf/3y1lwvG9OGayal2SgB8/CFmMDHl+wGZy00I0YYaI3FrteQVNKxXavLC9Up9PB2AEELUu3ZKX/LLa3hk+X6q66z894ox+JqN75cnS6v4+ZubSYsJ5q+XjELZmZS3QfxwQo6sA+BEB02vQojTk6oxls3DzpJX+PhToYLwq/K+Za8kcRNCeA2lFHfNG0yAr5m/f76HmjoLj109Dh+TiV+8sYXSqlpeu3EyIf4dfHTFDce84z3CTZVS4yaEsMtUW0YV/gSY7X+eGOuVet9C85K4CSG8zq0zBxDgY+KBpbv46asbGZIQSsahfP59+RiGJNjpj9JS/EgAJgfncqJ4kJujFUL0ROa6cqpMgQS0sb/GP4qQqiJqLdaGmn9vIImbEMIrLZrWD39fM3/4YDur9p7iivRkLpuQ7NjJtpGl4/yP8V0XatyyCyvQGlKi7Ew7IoTo0XzrKqgxBbZ9QFAMUcUHOVpQQf9YO82pHiKJmxDCa109KZUQfx++3p3LgxeMdPzE8BTwD2OYKZv3upC43fG/zQB8+LNpnS5DCOGdfC0V1Pi0/aXMNzyOqBOb2ZlfLombEEI4auGYPiwc08e5k5SCuOH0K8wkt7RziVtJVS3bsoswKUVVrYUAX3OnyhFCeCd/awW1ZjsjSm2C4/oTsLeY7JzjMDS+GyNrn/c02gohhCvFDyeh8iCl1bWUV9c5ffrGzEKsGuqsmp05xW4IsIfTGt7/KRz+1tORCNEpAdZKLO3UuAUOPBOT0qijP3RjVB2TxE0I0TvFDcffUkYiBZ0aWbr2cAFmkzHlyOYjRS4OrhcoPAzb3oRNr3g6EiGcprUmUFdg8W27xk0lpVOLD1Gn1ndjZB2TplIhRO9kG1k6xHSE3OIqBjjZR2Xt4Xz+Gfkh5VW1rDt6pzsi7NmObzV+ZmV4Ng4hOqGy1kKQqqbI3hxu9XwDOBI4nLTyrd0XmAN6ZI2bUuoipdTzSqmPlFJnezoeIYQXihsGwDB1lKOFFU6dWlFTh8+x9Vxa/hY/trxPbdZad0TYs9UnbiXZUHTUs7EI4aSy6jqCqQTf9r/Q5cWkM8R6kKpy7+ku0e2Jm1LqJaXUSaXUjhbb5yul9iqlDiilft9eGVrrD7XWNwOLgCvdGK4QoqcKjECHJTPG7xjf7HNuoeiNmfncY36F6sA4yv1iuKXyRU7JRL7NHd8KfrY59Y5IrZvoWSqq6gimCuXfdlMpgCVlKj7KSt7uNd0UWcc8UeO2BJjfdINSygw8CZwLDAeuVkoNV0qNUkp90uIW1+TUe23nCS
FEKyp+BGP8c1i19xRVtRaHzytZ9wZjTQfRc//IyfTfMMG0n+MZb7gx0h5GayNxG7YQ/MMkcRM9Tnl5KWal7S8w30TowGnUaRPVB1d3U2Qd6/bETWv9LdBy8a9JwAGt9SGtdQ3wJnCh1nq71npBi9tJZfgH8LnWelN3PwYhRA8RP5z4miPU1lSzZr+Di0XXlDPl0GPs9xlEwPgfkTDjRnZbU0nd9E+oq3ZvvD1FyTGoyIek8ZAySfq5iR6nvunTFBDW7nGpiXHs0P0IzPGekaXe0sctCWjaSSLbtq0tdwLzgMuUUre2dZBS6hal1Aal1IZTp5xrKhFC9AJxIzBZaxkVcJIvdp5w6JS61Y8Qbc3nh0G/BpOJwAA/Xg+/mYjqHFj7rJsD7iHq+7cljoHUKXBqN1R432LcQrSltqIUAJ/A9mvcwgN92WYeQVzJDqj1ju4S3pK4KTvbdFsHa60f01pP0FrfqrV+pp3jntNap2ut02NjY10SqBCiB4kfAcAVCSf4encudRZr+8cXZ6O+f4xPLFNIGjOncfuA2Xyrx6G//ReU57sx4B7i+FZQJuP5TT3D2HZUBnCInqO6wqhx8+sgcQM4Fj4OH10Lxza6OyyHeEvilg2kNLmfDOR4KBYhRG8ROxTiR7Kw7B1KKqpYl9lBrdDyB9FWK/+ou5r0tKiGzWNTIvlzzdVQUwbf/MPNQfcAx7dCzGDwC4akCWD2k35uokexVJYA4Bcc3uGx5fGTsKIg63t3h+UQb0nc1gODlFL9lFJ+wFXAxx6OSQjR05lMMOtugsuyuML3O77amdv2sUfXw/a3WRp8CeF9BhAW4Nuwa2xKBPt1ModSLoUNL0LegW4I3osd32o0kwL4BkCfcdLPTfQotVVlAAQEdZy4JSYksNeaguWwd4ws9cR0IG8AGcAQpVS2UupGrXUdcAfwJbAbeFtrvbO7YxNC9EJDz4fEsfza/0O+3nEUre30wrDUwue/Q4fE8+eic5jcL7rZ7v4xwYQF+PBm8HXgEwDL7u+m4L1QaS6UHm9M3ABSp0LOZqit7FLRNdlbKPvsfqgp72KQQrTPWmX0cfMPaX9wAkBadDBrrUMhe53xWeFhnhhVerXWOlFr7au1TtZav2jb/pnWerDWeoDW+iFXXU8ptVAp9VxxsfdMnieE6EZKwex7iKk7wfTyr9iWbeezYMVfIGcTByfcS0GdP5P6RTXbbTIpxqREsOaECc68C/Z+Cke8Z5RZtzqxzfjZMnGzdr0P0JH37iNk3aNYXzxHJvUVbqWrjcTNt4PpQADSYoJYax2Gua6icWCOB3lLU6nbaK2Xaq1vCQ/vuDpUCNFLDTqLuj7p3OnzAcu3H2m+b/8y+O4RmHADn1unAjApLapVEeNSIth7ooTycbdASIJR62av9q4zXFVOdzi+xfiZMKpxW+pk42dXmkury0gtzOAH6zCsBYfh+dlwRAY8CDepNppK8e94Kby06GDWW4cad7K+c2NQjun1iZsQQqAUPnPvJVEV4Lv11cbtJTnwwU8hbgTM/xtrDxcwNCGUyGC/VkWMTY3AqmH7qTqY9XtjFOXez7oeW3ke/LMf/C0FHp8Ai8+DdxbB57/3mlFszRzfClH9IaDJl+HASIgbbn+AgtUK/7sKPry93WJLtn+GH7U8XHsZr49aDP6h8PIC2Px6l8KtqrXYbx4XpzVV3xzv13GNW7C/D+awOE76pXrFAAVJ3IQQp4f+szgROYErq97mYM5JsNTBezcZczNdvoRakz8bswqZ3K91bRsYI0sBthwtgnHXQfRAWP6gUU5XZK6BykIYPL+xFuvEdtj0MrwwD774g3f1+Wo6MKGp1KlwdB1YW6xQsWkJ7Psctr4JpW3PpVey+X3ydBh7/EbwxYkwuOlro8yPbocv7zESQCcdK6pk8l+/5rUfspw+V/RuptoyqvEDs49Dx6dFB7PVPNyoVW75Gu9mkrgJIU4PSuE77z7iVRHHlz9lTOuR9R0s+C/EDm
b7sWIqay1M7h9t9/SoYD/6Rgex5UiR8WE/94+Qtxe2/q9rcR3JAJ9AuPBJuHwJ3PAZ3LkRfr0HJtwAPzwJT02BA1937TquUFEARUfaTtxqSiG3yTLUxdnw1f1GQqotsKWN56q2ipjjq/hGTeTiCalsOlJItV84XPseTLwZMp6A5c4NCNFac9+HOyiurCXjkMy9J5oz15VTpQIcPr5fTDDfVA2G6mI4ucuNkXWs1yduMjhBCFEvesRstviOZdzhZ+Hbf8HYa2DMVQCsPWTM8dZyYEJTY1MijBo3MNbpTJ4IK/8GNRWdD+pIBpXx47CafJtvDwg3ksobPgezP7x2CXxwq2dXKLA3MKFeX6N/YEM/N63hk18aCduVr0HfabD5Vbv9+fTBrwmwVnK8z9lMHRBNdZ3VGERi9oXz/gWTboHvH4fvHnM41E+3H2fFnpOE+vuw/Zh8/ovmfOoqqDYFOXx8v5hgVlQONO5kerafW69P3GRwghCiqYMj7yJYV1AbNZDc6X9mY1YBH205xifbchgYF0JMiH+b545NieBESRXHiyuN0arzHoTSHFjXyaWwqkvRJ7bzfFY8T65sY264vmfArWtgxm9h+zvwxd2du5Yr1I+oS7CTuIUnQ3hqYz+3bW/D/q9g7v0QmWY0Lxccstu5u3TzB5ToIGJHn9UwMOSHg7ZaMqVg/j9gxCWw7D7Y8kaHYRZX1PLAx7sYlRTOT2f252hBJcUVnp/GQXgPX0sFNWbHE7e0mGByiKE6JNnjAxR6feImhBBNjTvjLH5a80tmHb+Tyf/6gUufzuAXb25hZ04J549KbP/cVFs/tyNFxoa0aUbftNUPd64m7Og6lLay3jqEJ1cdIKeojXnQfANgzr0w6nLY/6X7+9jU1djffnwrhKdAsP3mZFKnGIlb2Un44v8geZJRWwYw/ELwD4NNrzY/x1KL/8EvWWYdzxmDE4kM9mNoQihrDzd5Pk0muPgZ6D8LPvoZ7Puy3fD/+tluCitq+PuloxiTEgHAjhypdRON/K0V1DqRuPWLCQYgN3KCMUDBgwNeJHETQpxW+seGMPbs61g4YzJ/vmgki2+YyLJfzmDXn87hl2cNbvfcYYmh+JlNjc2lYPR1qy6B1f9xOpaCPd9i0YrkUTPRGv7xxZ72Txgw1xjIkLPF6WvZU1Vr4YsdJ8jKL28cebnzQ/h7Kuz8oPUJbQ1MqNd3KpTlwlvXGgMqLnwCTGZq6qzUmgNg1GWw6yOoapJEZa7Gv66EzcEzSIky/pFO7hfFxqxCapuuLevjbzS5JoyCt69vc6qQjIP5vLXhKDdN78eIPuGM7GO0tkhzqWgqwFpBnU+ww8enRgWhFOz1H2UsfVec7cbo2ufYcAohhOhFbps1oFPn+fuYGd4njM31NW4A8cNhzNWw7nmY+bvm02R0oGD3Nxwnjd8snEBsdCaPrTjAdVP6NlsntZkBswEFB5ZD8oROPYamHl+xnydXHgQgJsSfc/pU8MdjP8PXUoX68GfGWq9xw4yDq0sh/yCMvrLh/K935xIW6Mu4lAh8zCZjgAIYU6XMvZ8jphRe+WQXb284ysikcF4/7zrUhpdg+7sw8UYALDs/olr74zt4XkO5k/tH83JGFtuyi5nQN7IxYP9QuOZdeOkceP0y4/mI6g9RAyB6AFVh/bjng72kRgVx11wjCY8M9iM5MpAdkrgJG601AbqSOl/HE7cAXzN9wgP50jSds37/c+OLhIdIjZsQQjghvW8kW7KLqKpt0lyZfgNYqjtswmvqeEExSWU7KYlLJzrEn1tnDSAhLIAHlu7Eam2jGSY4xlgX9MDyLj4KKK2q5ZWMLGYNieUvF41k9sBwrs1+kMo6K5dW/5EK/I2as/rasRM7AN1Q4/b62ixufHkDlz+TwYS/LOfONzbz/tFgrIExlEYO55YDZzDz3ytZ/H0m/WKC+f5gPqtKkyB+JGx6xSjTasGy6xNWWMcydUhSQ2z1A0TWHrYzGjQkFq77AP
rPhNydxqCFj++AxecS8OhQ+hd8y0MXjyTQz9xwysg+4ZK4iQY1FitBVKH9Op58t6l+McHsL6jzaNIGkrgJIYRTpg2KoabOyvrMJn2wktIhtI/RDOigz7/6ikBVw6D0swEI8vPh7vOGsuNYCe9ubKcZZuBcOLbBaDJtQ1FFDSdLqtq9/ms/HKG0qo7fnD2Ea6f05V+RHzLMegAueILEkTO5oexnxgoGH95uzKFWPzAhcSyr9p7k/o92MmtILE9dM56zh8eTcTCfX72znXOKfsfs43ey6VgZd8weyHf/N4d3bj2DlKhA/v3VPqxjrzVWXzixHY6uxa8qj2V6ElMGNPabiwnxZ2BcSMNI31Yi+xrNpnduhHty4edbyL/oDfJ0GHdEb2T6oNhmh49KDiczv4KSKhmgIKC82kIIVWgnatzAWPrqcF65xyd07vVNpUqphcDCgQMHejoUIUQvMLlfFH5mE2v25zUmCCYTDL8ANiw2mhT925+Nvaiihrxdq8AEMcNnNmy/YEwfXsnI4p9f7uHcUQmEBvi2PnngPGMqk0OrYMTFzXbtzy3lpe8O8/6mYwT6mfnqlzOIC209V1VVrYUX1xxm+qAYRiaFw94vjLnSJt5M+IRL+e8YC9eXV/O3oz/inj2vwncPQ94BCElgd1kgd/wvg8HxoTzxo/GE+Ptw3qhErFbNzpwSvt1/ij4RAZw3KhF/n8Zar1/OG8yv3t7Kct9ZnG32NwYpKBO1+JCfOIuwFo91Sv8oPth0jDqL1WiGbYvZB6L68f4OTZBlIldXZRiL3fsGNhwyMslovt5xrJgzBsS085cRp4Pyqlr6UOnQcldN9YsJoaSqjsKKWqLsrK7SXXp9jZtMByKEcKUgPx/G941g9f685juGX2g0l+7/qsMyXsnIYqzeTU1YGoTGN2xXSvHAwhHkl9fwxIo2pgdJSgf/8IbmUqtVs3LvSa57cS1nPfwt7286xoLRfaiosXDvBzvs1g68szGbvLJqbp81EIqPwYe3Gp3+z/4LYPTle+7H6ayOupzP9FT0ir/Avs+pjh3JT5asJ9jfzEuL0gnxb/zubzIpRiWH87PZA7l4XHKzpA3gwrFJDIoL4R/f5GIdej5sewvrro/41jKK9CGprWKc3C+a8hoLO3NKOnw+AT7ZfpzdkbMw1Va0mqx4ZJ8wAHYec6ws0btVVJRhVhqTAwvMN9Uvxhg8czivzB1hOazXJ25CCOFq0wfFsut4CXll1Y0bUyZDSHyHzaUVNXUsXnOIqb778es/rdX+UcnhXD4hmZe+O8yBk3b+QZh9jP5dB1aQX1rFFc9mcMPi9ezLLeW35wwh4+65/CfmEzYH3cGNB35G1ss/hR+eNpKZU/uoy93Ll6tWsbBPCVNCco1lv+pq4LIlxrQjNmEBviz5yWT+G3Anh3QSVBby/vEYSipreWnRRBLDA1vH1g6zSfHrs4dw8FQ5q0POhaoiTKU5fG6dxPRBrWvBJvdvp59bC0cLKth6tIiUcWdDQATs/rjZ/ugQf/qEB8jIUgFAVbnxOjB1UDPeUlq00bR6OK8LE267gCRuQgjhpDMHGonGdwea1LqZzMZqCvuXtbu26FvrjxJVlUWotcSY98yO35wzhEBfM1c+m8Hq/adaHzBwHpTm8Jun3mT7sWL+ceko1vzfHH42eyBRtbnw3SMERSYS7qeIyvwEvvi9sfLCkxPxeXoSr1X9nMcLbkU9PRWOfG+s0BDTujtJQngAz9w4g7v4LbusfXmnZBhP/Gg8I/p0rgXjnBHxjE4O554tUejwVCyY+cFnEmOSI1odGxcaQP+Y4Lb7uTXx2fbjAJw7JhWGnm80/baYi25kkgxQEIbqcqPm1TfQucQtJSoIs0mRmefZtYMlcRNCCCeNTAonPNCXNfaaS2sr2hz1WVlj4flvD3FFrG3wQeoZdo+LCw3g/dvPIDrEjx+/tI7/LtuHpclI020B6QCMrtrAG7dM4cqJqfjW9wNb/W/QGnX1//
C5ZRnpdS/w277voq//BOvFz/O3wN/w18DfYL30JbhsMdzwRcOyX/YMjAvlj4sW8mO//3LphRcze2icg89Sa0opfnvOELKLq1mZdhfP+lzD8AFpbfZhm9w/inWZBc0euz2fbT/O6ORwUqODjOS5uhgOf9vsmJFJ4RzKK6dUBiic9moqjMTNJzDMqfN8zSZSIgM5LImbEEL0LGaTYtrAaNYcyGvehyz1DAiKgV3Nm+pOllbxn6/2csbfvyanuIoLo48Yx0W3PZ/cwLhQPvzZNC4Zl8xjX+/nxy+t5VRpNZ9vP87lbxzhsErh1uRMxqc2mees4DBsfg0mLIKIVAbGhfLLeUN4Z28Nn5QMYIXvTJ4tHM+ws2/ANOpSGHlJ4xqj7UhPi2L9PXO5ZnJfZ5+qVs4cGMOU/lH8Yksy/yybb7eZtN7kftGUVtWx+3jbfdOOFlSwNbu4cdWL/rPBLwR2N2+yHmUboLDLwT5zoveqqSwFwD/IucQNjKWvPJ249fpRpUII4Q5nDozls+0nOHiqnIFxttFpZh+jqW7He1Bbxe68Gl5cc5iPt+RQa7Vy1rB4bp7Rn4SP/mA0kyrV7jWC/Hz4zxVjmNw/ivs+3MHZD39DUWUt41IiSEhdQODmF41mWT/btAbf/gtMPjD91w1l3Dy9H1/sOM79H+2gT0QgyZGBLBzdx+nHqzqI1ZlyfnvOEC592ljTtOXUHU019nMraBgZ2tKntmbS8+oTN98AGHwO7PkUzn/Y+JvQOLJ0+7FiJve3s2RXYZb9NSj7jIe4oQ49NtEzWOsTt+BOJG7Rwaw7XIDW2mXvCWf1+sRNpgMRQrhDfU3Rmv2nGhM3MJpLN73Muq/f5YpVkQT6mrl6Ugo3TOtHWkwwlByHwszGNTwdcEV6CqOTw7nrzS2cOSiWf102moAj1bDhachcYyQqeQdg6xsw+TYIa1xz1cds4l+Xj2HBY2vYmVPCny8c0f70Gt1gQt8ozhoez+G8cvpGt71eZGJ4IKlRQaw9lM+NZ/aze8yn244zJiWiYbksAIZdYCTPR76HfjMAiA31Jz7M3/4o1ZoKePFsKDvRep9vMPxsLUSkOPUYhfeyVBmvgYBOJG79Y4OpqLFwqrSauLDWU+10h16fuGmtlwJL09PTb/Z0LEKI3iMlKoi06CBW789j0bQmSUW/GeiACPLXvcPIpN/w2o2TiQhqMufTEaOmqa2BCW0ZmhDGF3fNaNyQegb4BBqjRQefA6v+Bj4BcOYvW507OD6U+xYM44PNx7g83TsSkMevHkd1nbXDWovJ/aJYtjsXq1VjMjU/Niu/nO3HivnDeS1qxAadZTw3uz5uSNzAaC61O7J07dNG0nb1m41LfAFU5MOSBfDZb+HqNzqsIRU9g642Rmv7daKp9PxRiZw5MIboEFnySgghepwzB8Xww6H85ouhm33ZFzmDaZZ13HNO/+ZJG8CRH4xanIR2Fmt3hG8ApJ1pDITI3WXUME26xVgSyo7rpqbx/u3TCPA1293f3QJ8zYQH2plguIVZQ+Ioqqjln1/ubbWvVTNpPb9gY4WJ3UuNVR9sRiaFc/BUGeXVdY3HVhTAmkdg8Lkw5FyITGu8JU2AWXfDvs9hzydOPb46i7G6hqdn2Ret6RojcVNOTgcCxtQy/WNDMJs8l8RL4iaEEJ105sBYymsszRadL66o5bHjIwhTFUxlZ+uTjmRAcnpD36suGTgPCg7C0p8bHfKn/aLrZXqZ80YlcM3kVJ755iAvrjncbN+n244zNiWC5Eg7za3DLzRq0Y5taNg0sk84WsOupoMdVv/HWO1i7v32A5hyG8SPgs9+B1WOD2x4ZPl+Ln8mg3c2NFm+zGo1Voyocu+0JHll1c2XZHNCncXKaz9k9e7Rt7bEDSfXKvUWkrgJIUQnTR0QjUkZ/dzqPbXqAMurh2LxC2s+GW9dtbEoeu4OSO14JKdDBs4zfmavh6m3Q1CUa8r1Ikop/nThSOaPSODPn+zioy3HAD
icV87OnBIWjE60f+Lgc8Dk2+xvMCq5cekrAIqOwrrnYczVED/cfjlmX1j4CJQeh5UPORTz1qNFPP3NQUwKHluxv7FGdv+X8PEd8NW9DpXTWf/+ci9XPpvB0QLnJ4r9ZNtx7v1wB283TTi7y94v4IlJxuhoNzLXllODD/h4btmqrpDETQghOik80JcxKRGstk3Ee6yoksXfZ7JgXD/MQ88zkoYlC+DhkfCXeHj6DNDWZv2uuiR6AESkQkA4TLndNWV6IbNJ8chVY5ncL4rfvLOV1ftPNU6627KZtF5AOAyYbayiYGuujA8LIDbUv7Gf26q/Gz9n/6H9AJLTYeJNsPZZOLax3UOrai385p2txIb48/CVY8kurOTdjbYkaNOrjT9ztnT0sDvFatUs330Sq6ZVDWVHtNY8v/oQAN8fyOvgaPu+2HGctzccxdrB3Ht2Lm7008zbC+/fDJa6js/pJHNtOVWq7UEx3k4SNyGE6ILpA2PYerSI4spa/vvVPgB+dfZgYy61oEijpq3vGTDz/+DiZ+HmlcZ9V1AKLngCrngFAiNcU6aXCvA11k8dEBvCra9u5H9rjzA+NYKkiHaW3hq2EIqOwPEtDZtG9gkzatxO7oGt/4NJNzs2YnTufRCaAEt/0W5S8fDyfew/WcbfLx3FBWP6MDYlgidWHKC6KAf2fQETboCgaGM1Czf0f9t2rJi8smriQv15a/1RCstrOj7JJuNgPjtzSogJ8Wft4QLqmvbddIDFqvnDBzv43bvb+PFL6zhWVOn4ydnrjb/ToLON37/5h1PXdoZPXQVVJueWbPMmkrgJIUQXnDkoFquGJd9l8v7mbBadkWYkE32nwl3b4aZlcMlzMPtuY4WCpPGuHZ3Yfyb0n+W68rxYeKAvL/9kEhFBfhwrquT8juajG3K+MdL21Uvgm39CZSGjksI5cLIMy/IHjT5OTea8a1dAOJz7DzixHdY+Y/eQjVmFPP/tIa6amMKsIXEopfjVWYM5VlTJzs+eBW2BqXcYSeCRDNj5vpPPQMdW7M7FpIxRu5W1Fl77Icvhc59ffYiYED/uPncoZdV1Tq/tui27iILyGhaMTmTTkULmP/wt72w46tgAjbXPgn+4sZrH2GuMFUCyvnfq+o7QWmOtKsXiIzVuQghxWhqXGkGwn5mHl+8j1N+H22e1vRqC6Lr4sABevXESV09K4ZJxSe0fHBwNN3wGKZOM/mkPj+Ti/Gc5S63HvO8zmPZz5/oFDrsABs+HlX+Fkpxmu6pqLfz2na0khgdyz/mNU4pMHxRDemoEMfvfwpoy1VgTdtx1kDAKvrrfmEOuKwozYePLDTfzllf5XexaJpv3MWtILC9nZFJVa+mwmP25pazce4ofT01j1hBjZPL3B/OdCmXV3lOYFPz5wpF88YsZDOsTxm/f3cbNr2zgZGlV2yeWnoBdH8K4a8A/xEiQI9PgvZuhstCpGDpy4GQZvpZyfJ1c7sqb9PrETSm1UCn1XHGxLC4shHA9X7OJKbaZ+H82e2Dr6T+Ey/WPDeFvl4wmMtiB5zppAvzoLbj1Oxg8n7R9i3nW72HKfKP5Ie5Kvt13ipV7T/L17ly2ZRe13zyoFMz/O1hrYfmDzXb9+8u9HMor5x+XjiY0wLfJKYoHxhSTqo/zfdh5xkaTGeb/A0qy4fvHOvMUGPL2w/NzjVHFttsvKh7n1pJHYfG5/HZgDnllNXyw+ViHRb2w+jABviaundKX6BB/hiaE8v1B5/q5rdp7kkcj3iby+7+QGh3EmzdP4b4Fw1m9P4/rXljXds3bhsVgtRj9CAH8Q+HSF4xRwZ/80nVNyhUFFK18jBGmLILDIjs+3kvJBLxCCNFFV05MobrOyvVnpHk6FNGWhJFw2Ysw+w+88eQf+bJ8GKuWbG91WKCvmbEpEUxMi2RCWhTxYf4cK6zkaEEF2YWVZBdWcnbAxVyy7U3uzp7MAb+haA0bjxRy7ZRUzrSz9uqIkx9RoYL4w97+fFVrMe
bSS5sGIy425pAbe43zKzMUHYVXLjKSyZtXQEgC723O5l9f7OXNmyaQ9uWNDM/4NdMT/8vzqw9xZXpKqwmM650sreKDzce4YmIyUbZkeNrAGF77IYuq+ng7kFdWTU3Odhb6fQjfAbFDMY39ETee2Y9Qfx9+9942NmYVkp7WooazrgY2LjYmTW66dm/SBJh9D3z9oDF6ety1zj0/9bQ2mqU3LoGdHzLRUs1O0yCGz2o9UXVP0esTNyGEcLezRyRw9ogET4chHKCiB3DmHc/Rt6CC20wKs0lhMinMSnGkoIKNWYVsyCrgiZUHaDkwMsDXRHJkEFWhVzK7chk/KX2WB2IfRisTF49L4u5zh7W+YFUxaudHlA66iCPbFK+vPdK4fNdZf4K9n8PyP8JlLzn+IMpOwasXQXUJLPoUEkcDsPTwMfyjU+g7YDhcvhj13Gz+G/Qkk4/fyfLdua1foxUFsPtjPsuKpdYKN57Zv2HXGQOieXHNYTYdKeSMAa2T0Za+3XeKW80fY/ENxpwwCj79NSSlQ+xgzh+dyINLd/L2hqOtE7ddH0FZLkz6aetCp/1/e3ceXlVx/3H8/b0JCSGBbOxhS9hF9k0UkII8xQIuFSkVtOjP+tOfbUVr61aLrbW1rZbirhUrrdRS94W2KlpEiogI/gQRKwRlJxA2gzEQMv3jnJCbS4IhC/ee5PN6nvuQM3Pu3Em+D+HLzJyZa2DDG96o21t3l69zzlszWFLi/elKvFG7SCWHvX3zEpvh+l/M1FXdadVtCLNy+n3l9xSrlLiJiEiD0j6jSfmzTX1926cxsa/3wENBUTHvb9rH3i8O0S49ifYZTchMTig7ouv9O0h//irmDd0Mfb9V+YetfhqKC2l15hWcfvAQDy5aT9eWKf7zKUl06Hk5HVffR17TU2g6ZBpJ6V/xH4Av98MT34T9W+Hi544mbV8cKmbphnymDe3o9bFlTxh/Fy1euJqfpHTikcWZ5RO3batg/iWwfxPTgbEp2WStnQa9J0F6J4ZkZxAXMpauz69S4rZm9SpuiVtGaND3vD0FHzoDnr4ULl9IcmISE/q05aUPtjFzYi+SE8NSj+UPQ0Zn6Dz62EZDcd6U6Zu/9jZJjmRx3j0W8l6hOKCCUcWsgdDrPDbsK2HpksX8OifY+x0qcRMREYmQkhhf4bTnUX2meJv3LpwJPcZ7i+orsurP0OpUaNuf68buZdJDb3PJY8uPVicxgHkJXRjw9u0cWfoL3g31ZmWzr7GjzVguGNGHU7NSy9o6XAh/mQJ5a71zVTuWbeS85JPdHCouYUzPlmX395sKG9/i0g/m89rmLrz3WU8GdkiDlXO981eTW/LawAd5c9lybmi5Gt643Xu1H0rTc++nT7tUf51b9+P+rI6UOE7Z+DglFk/c6Vd726ac/zDMmwSv3AITfsfkwe2Yv2IzC1ZvZ3LpeblbV3pbf4y7E0KVLLlv2homzDru51fVslzvCduh2Zm10l60KHETERE5UaGQ9/TjnLGwZJa3xUekHau9ka2zfwNmDOqUwavXjuRAYfnjpIpLRvGvLatJ+eQFOu18hcH7Z3N4333sWZfKkZRE4kpH+YoLoXCft1av69hybbz+UR5NE+MZHD4VaQbj78ZtfY978u/nt6/2pEvTl0hdN5+D7Uayfcx9/OLpXDKyLiTlqlmwf7N35u3Se+GvF3FmzkPcu2QnBUXFpCRWni58+PHHTHSL2Jp9AZ2a+qN6XcfC6d/32so5kwE9zyGnRTJPrdhclrgtf8Q7t7ffRSf606+WZbn5tG7WmI6Zwd0KBJS4iYiIVE/7IdB7specDLgE0juWr1/5Z4hLhN4XHi3q1qqSg81zRsHIUd7are3/z77lT7F45Ye0tsYM79K8bAKw29e9jYXDlJQ4Xl+Xx8juLUiIjxi5SkwhNHkuaQ99jdu3TCfRipldfD6z119Ayfo1ANwwroc3vZrWAYZf661N+9M5TGtyF78vmcryjfmM7tGq0h
/DobfuIY4SMsb+qHzF6J96e7G98H2sTV8mD2rPnf9YR+6uAnKSCr0kccAl3h55dcw5x7LcPQzvklk23R1QStxERESq66zbYN3L8PfrYdD/hFU4+GC+l2SdyF5xZtC2Hy3O68f+zFx+tOAjfjm8NxcN7VDpW0pPSzgrfJo0XKteuAmzOPL6Hbzb91a6tBnFvX5Vk8Q4RnVrUf7+7BEwZibNF87kikaZLF2fU3ni9sUeem17hqVJZzKybdfydfEJcMEceHgk3DeYK0KNmJZ4hEYPhiBUAkcOwZArqvZzqaENuw6yu6Do6NY9QabETUREpLpSs2Dk9fD6z+GTV4+tHzi92k1fdkY2iz7exe0vr+W0nAxyWlS8jq70tIRR3SpJ3ICEgVNJGDiVwVX98DOugS3vcsO6edywrhdMOKXC2wreeoAUvmTrqVdW3E5GNkx7FtY+Twj494c72F1QxLcHdSDUohu0OP76udryzkZvM+GhStxEREQauBE/hO7f8B4eCJeQXKPEJBQy7rqwL+NmL2bG/Pd55qrTaRR37CL+hR/lMbBjetU2JK4qMzjvAT6fPZwff34ne3eOJ71VxKhfUQEJKx7mtSMD6d1/WMXtALQf7L0Al7WDW554jzadBh13+rW2LcvdQ6tmiXQK+Po20MkJIiIiNdeyp3cObfirFkaTWqc25lfn9+aDLfv5/cL/HFO/bV8ha7cfYEzPOkiCGqeybdwfSKGQkr9d6m1BcmB72eudB0k4fIAnEybRq23VjpAa3aMlmckJPLViS+33txLe+rZ8TssJ/vo2aAAjbjo5QUREguzs3m2YPKgdDyzaQOvUJFqkJB6tK50CrHR9Ww11O3UINz/7v/wm/x6Ydex06XJ6kdnjjConRAnx3mbFc9/+lPyCIjLDvpe6snH3QXZ9Xj/Wt0EDSNxERESCbubEXqz4dC+3Pr/mmLrOLZLpXMn6t5qKjwuRn30O1+/I5K6zyj/9+Wn+Qa75VzNu7XFiSeOFg9rz6JKNPP/+trJTJOrQstw9AAzNDvbGu6WUuImIiMS45MR4FvxgBBt3Hzymrm1a4zqdAhzWOZNfrOvKdV1G0zYt6Wj5U6+sIy+UyxldvvpkhXDdWzelb/s0Hn5zA1v3FtK5pZd4dm6RQvOUhFr/Xpbl5tOyaSLZzZNrtd1oUeImIiISAEkJcZxSxbVktan0yKuL57xD08aNjpbn7ipgYMd0UpMaVfbWSt18dg9+/vJanly+icLDZWeMNm4UKttw2JfSOJ47L+jD17qf+HRwfVvfBkrcRERE5Dh6tG7KRUM7sGVv+adm+3dIZ/rpnarV5tCcTBb8YAQlJY7tB75kQ14BG3YVsG1fIc6Vv3fJ+t18d+4K7p7cl3P7ZZ3Q53ya/wV59Wh9GyhxExERkeMIhYxfnt+7ztrOSksiKy2JkZEbAfsOfHmYy+euYMb89zlQeJiLh3WqcvvLckv3b6sf69ugAWwHIiIiIsHVrHEj/nTZEMb0aMWtL3zI7IWf4CKH5YBDxSUcLCou9/r3+t20aJpITj1Z3wYacRMREZEY17hRHA9NG8ANz6xm1sL/sKvgS/pkpbF+VwHr87zX5r1fHDPNCjCxb9t6s74NlLiJiIhIAMTHhfjtpD6kNWnEnCUbgU0kxIfIaZ5Mn3apnN8/i+TEuHLvMYyv92odnQ7XESVuIiIiEgihkPGT8T355oAskhPiaZ/RhLhQ/RlNqwolbiIiIhIYZkavtqlffWM9pYcTRERERAJCiZuIiIhIQNT7xM3MJprZI/v37492V0RERERqpN4nbs65l5xzV6SmNtz5cBEREakf6n3iJiIiIlJfKHETERERCQglbiIiIiIBocRNREREJCCUuImIiIgEhBI3ERERkYBQ4iYiIiISEErcRERERAJCiZuIiIhIQJhzLtp9OCnMbBfwWR1/THNgdx1/hlSPYhObFJfYpLjELsUmNtVFXDo651pEFjaYxO1kMLMVzrlB0e6HHE
uxiU2KS2xSXGKXYhObTmZcNFUqIiIiEhBK3EREREQCQolb7Xok2h2QSik2sUlxiU2KS+xSbGLTSYuL1riJiIiIBIRG3EREREQCQolbLTCzcWb2sZmtN7Mbo92fhszM2pvZv8zsIzP70Myu8cszzOw1M/vE/zM92n1tiMwszsxWmdnL/rXiEgPMLM3Mnjazdf7fnWGKTfSZ2bX+77E1ZvakmTVWXKLDzB4zszwzWxNWVmkszOwmPyf42My+Xpt9UeJWQ2YWB9wPnA2cAnzbzE6Jbq8atGLgh865nsBpwNV+PG4EXnfOdQVe96/l5LsG+CjsWnGJDbOBfzrnegB98WKk2ESRmWUBPwAGOedOBeKAKSgu0fI4MC6irMJY+P/mTAF6+e95wM8VaoUSt5obAqx3zuU65w4BfwXOjXKfGizn3Hbn3Er/68/x/gHKwovJXP+2ucB5UelgA2Zm7YDxwKNhxYpLlJlZM2AkMAfAOXfIObcPxSYWxANJZhYPNAG2obhEhXNuMbAnoriyWJwL/NU5V+Sc2wisx8sVaoUSt5rLAjaHXW/xyyTKzKwT0B94B2jlnNsOXnIHtIxi1xqq3wM/BkrCyhSX6MsBdgF/9KexHzWzZBSbqHLObQXuAjYB24H9zrlXUVxiSWWxqNO8QIlbzVkFZXpUN8rMLAV4BpjhnDsQ7f40dGY2Achzzr0X7b7IMeKBAcCDzrn+wEE0/RZ1/nqpc4FsoC2QbGbTotsrqaI6zQuUuNXcFqB92HU7vOFsiRIza4SXtM1zzj3rF+80szZ+fRsgL1r9a6DOAM4xs0/xlhOMNrMnUFxiwRZgi3PuHf/6abxETrGJrrOAjc65Xc65w8CzwOkoLrGksljUaV6gxK3m3gW6mlm2mSXgLUh8Mcp9arDMzPDW6nzknPtdWNWLwHf8r78DvHCy+9aQOeducs61c851wvs78oZzbhqKS9Q553YAm82su180BliLYhNtm4DTzKyJ/3ttDN6aXcUldlQWixeBKWaWaGbZQFdgeW19qDbgrQVm9g289TtxwGPOuTui26OGy8yGA28BqylbS3Uz3jq3vwEd8H4hXuici1xoKieBmY0CrnfOTTCzTBSXqDOzfngPjSQAucCleP+xV2yiyMx+BnwL72n5VcDlQAqKy0lnZk8Co4DmwE5gJvA8lcTCzG4BLsOL3Qzn3D9qrS9K3ERERESCQVOlIiIiIgGhxE1EREQkIJS4iYiIiASEEjcRERGRgFDiJiIiIhIQStxEJDDMbLqZOTPr4l/PMLNvRrE/aWZ2m5kNqKBukZktikK3RKQei492B0REamAGsARvV/loSMPbz2kLsDKi7v9Oem9EpN5T4iYiEsbMEp1zRTVtxzm3tjb6IyISTlOlIhJI/rmnHYGp/vSpM7PHw+r7mtmLZrbXzArN7N9mNiKijcfNbIuZDTOzpWZWCPzGr5tiZm+Y2S4zKzCzVWb2nbD3dgI2+pd/COvDdL/+mKlSM+tuZs+Z2T6/T8vMbFzEPbf57XQ1swX+Z39mZj81s1DYfSlmdq+ZbTKzIjPbaWYLzaxHTX+2IhK7lLiJSFCdD+wAXgGG+a/bAfw1Z0uBDOC7wAVAPrDQzAZGtJOKd/D9k8DZwF/88hy8A9enAucBLwGPmtmVfv12oHR93a/C+rCgos6aWVu8ad2+wPeAycA+YIGZnV3BW54D3vA/+3ngZ5Sdiwgwy2/jZ8BY4ErgfbzpWxGppzRVKiKB5JxbZWZFwG7n3LKI6t/inR042jl3CMDMXgHWALfiJUOlUoBpzrlyh3U7535Z+rU/0rUIaANcBTzknCsys1X+LbkV9CHSdUA6MMw5t95v9+94B7rfAUSeZXi3c+6P/tcLzWw08G2gtGwYMM85NyfsPc99RR9EJOA04iYi9YqZJQFnAk8BJWYWb2bxgAELgZERbykGXq6gna5m9qSZbQUO+6/Lge7V7NpIYFlp0gbgnDuCN9LXz8yaRdwfOXK3Bu
8w61LvAtPN7GYzG2RmcdXsl4gEiBI3EalvMoA4vJG1wxGv7wHp4WvFgDw/gTrKzFKA1/CmNW8ERgCDgceAxBr0a3sF5Tvwksr0iPI9EddFQOOw6+8DDwOX4SVxeWY2y8yaVLN/IhIAmioVkfpmH1AC3A/8qaIbnHMl4ZcV3DIM78GHEc65JaWF/shdde0BWldQ3trvQ2SidlzOuQLgJuAmM+sITALuBA4BN9SgnyISw5S4iUiQFQFJ4QXOuYNm9hbeaNnKiCStqkpHrQ6XFphZOnBuBZ9PZB8q8SYww8w6Oec+9duMA74FrHLOfV6NfgLgnPsMuNvMpgKnVrcdEYl9StxEJMjWAiPMbALelONuPym6DlgMvGJmc/CmKJsDA4A459yNX9HuUuAAcL+ZzQSSgZ8Au/GeQi21E+9p1Slm9gFwENjonMuvoM1ZwHTgNb/NA3ib9HYDxp/g942ZvQ28CKwGCvDW9fUF5p5oWyISHFrjJiJBdhPwMfA3vHVetwE451birUnLB+4BXgVmA73xErrjcs7twttuJA5vS5BfAY8CT0TcV4L3wEI63oMP7wITK2lzGzAc+BB40G83AxjvnPtnlb/jMovxtgOZh/cgwyTgWufc7Gq0JSIBYc5VtLxDRERERGKNRtxEREREAkKJm4iIiEhAKHETERERCQglbiIiIiIBocRNREREJCCUuImIiIgEhBI3ERERkYBQ4iYiIiISEErcRERERALiv3UiOZ4MEHiDAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmYAAAGYCAYAAADoXC5+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/av/WaAAAACXBIWXMAAAsTAAALEwEAmpwYAABu20lEQVR4nO3dd5ycZbn/8c81M7ubTa8kIQESJCBNJCYUEQgBpQsqTQWJih7xoKKen8cGBuzHBoiIiBILgihdUEBIKBJKSGihlxCSQBKSbArZzU65f3/cz7M7OzvlmbK7s8n3/Xrta3bmaffO7M5ec113MeccIiIiItL3Yn3dABERERHxFJiJiIiI1AkFZiIiIiJ1QoGZiIiISJ1QYCYiIiJSJxSYiYiIiNQJBWYBM1tiZi74OrbIfk8H+8zIeXxe8PisEteZE+w3O+fx2VnXX2Zm8SLneH/Wvs7MJkX4EYu1aaKZ/dTMnjKzTWbWFrThUTP7lZmdFOEcn81qz7kl9p2VtW+rmQ0vsu+UnJ91Rs722TnbnZmlzewtM7vHzD5pZpZzzIxgvyWlfq6sY3KvUehrUp5jP2Bmfw+e0y1mtsHMXjKzO8zsfDPbM2o7ekvOa1Toa1yR4/c3sxvNbFXw+/Simf2fmQ0rcd1+cVy5sv7unZn9pMh+l5Z4f5hT4jrh6zYv5/EZ1vW1e2eRcww1s81Z+86K8jMWOFe+v8/cr5tyjlkS4Zhu7wU55xiW9TM8nmd71Gtkf80Ljs37HOe5Rt73GTObVOD8m8zsWfPvuTsXOW/Fz4+Z7WdmfzSzV4Lf901m9qr5/18/MLMDiv1MfaGMn/f8PMfGzexsM/uPmbWYWTL4W/+nmZ1YRhv2Mv/e7czs6Zr+gHkkevoC/dQPzeyfzrlMH11/AvB+4F8Fts+q1YXM7BDgH8AQYA3wILAaGAG8G/g8cCrw9xKn+lTO9xdFbMIA4DTg8gLbZ0U8z8vAA8H3TcBuwGHB1wlm9hHnXDriuYq5HthUZHuXbWZ2EfCl4O5iYAHQCuwIHAx8ABgK/E8N2tYTsp/XXK35HjSzjwJ/AuLAf4DlwAHA/wM+ZGYHOedW9dfjauAcM7vYObesB84d1Szg6wW2nQo01/h6xX6PFhZ4/A7gzSLnLLbto3T+DPuY2VTnXPZ1/g6MzjlmMPCR4Ps/5Dnnc0WuV6ns62wP7Id/z51lZh9wzv2nyLFlPT9m9hXgp4ABrwB3ARuB8cBU4FBgV6DkB/Felu+1Co0Ejg++n5u9wcwSwD+BI4At+N+/t4CdgaOAo8zsIufcl4tdPDjPHKChwvaXzzmnLz/J7hLAAW8Ht2cU2O/pYPuMnMfnBY/PKnGdOcF+s3Menx08/mhwe22B44cCm4Fn8EGAAyZV+DM34f8ZOfwf7IA8+7wH+GGJ8+wenGMT0BJ8P63I/rOCfRYCKeChAvvFgNeBtcCLBZ738Hmbk+f4DwHpYPtZWY/PCB5bUsZz5cp9rvFvGA7YAByWZ/tA4BTg4339+1/kNer2vJY4bmLw+5kGTsh6PAFcG5zzxv56XBXPZ/h3H76//K7AfpdS/P2h6OuR9brNy3k8/J1/CViH/7uPFzjHg8Hf5UIivKeVaE+kduccsyTf33qZ130kOMey4PbSCMdMCv/OK3mO8+yX932m2HWAMcD8YPsztXp+gH2C3/UkcFqe7Q3AMcB/1+L3vbe+gK8Fz8XzebZ9Ntj2GrBjzrYjg+fCAVNLXOP88HcouH26p38ulTK7uyS4vcDMGvvg+g8Dz+KzPMPzbD8N/0lwTg2udTD+U9oK59z/OOfacndwzj3mnPtGifN8Orj9G/6fGnTNoBWyAv+pbX/LX1o5Av8P9Fr8J56yOOduBK4O7p5c7vE1cGpwe6lzbm7uRufcZufcdc65q3O39WPn4n
8//+Ccuzl80DmXwr9RbgBONLM9+ulx1boCaAPOLPA739Pa8H9P2+Oz8l2Y2a7AgcCd+L/Pfsd814Dp+A+Ks4KHP2ZmTX3WqIicc6vpzJ7vXqykWaaT8R90/+acuzZ3o3Mu6Zy73Tn3qxpdr7eE/2d+n2fbYcHtr51zS7M3OOfuoDPDVrB8a2bvAr4N3EDpqlHNKDDr7nr8p63JwOf6qA1z6Czx5fok/pPPn2pwne2C29WVniBI854e3L2Kzj+Qj5rZgAinuCq4nZVn2ydz9qnEguB2pyrOUanw+e2Jcli9OjG47RZsOuc2ALfm7NffjqvWcuCX+PLpD2p87qh6+m+ur3V8UHTO/Rv/QXcEPoPeHzyR9f3YGp1zq3svMrOD8F1WUsAf8+wS9cP8WwXO34AvNW/Cl5d7jQKz/MK+F98ys8F9cP0/4YOvT2Y/aGa74aP7fznn3qjBdcJPEXuZ2eEVnuM4/JvHK8D9zrlH8GXW4cCHIxx/M760coZlDXgIsoUnAoudc49W2DbwpV+oIONWA+HzO8tq3Jm8F+1iZt8zsyvMDxD5WKG/CTMbCrwjuFvoNQsf37e/HVdDP8SX/D9kZvv3wPmLyvobPTE7K29mMeAMfNeBW3q7XbUQ/DMNPyjOCW7DIDNKFr8eZL9XrKzROcP3opPMbEKNztnXwtfz9gL/D8M+2meb2Y7ZG8zsSHxGbQVwe4Hzfxvfz/pc51ytXodIFJjlEZSd7sR/yvhqH1z/DXzHzv3MbPesTWGgNqdGl3oQeBz/6f0uM5trZt82s2PMbEzEc4R/HHNcUJCnjDdC59wW4Bp8aeUDWZtOw2cN50RsRzfBP5oTgruPV3qeKlyB/zS3D/BaMBrqbPMjASsuk5cxSin7a0mFlzsI+BbwGfzfwtXAUss/UndScNsSZJ3yCf9BTO6Hx9WEc24d8H/B3R/V+vwRzcH3Mf1o1mMfwA88uib4u+yPjsf303oZuD947E/4v8PDc/9B16mwM/ti4NUanfOP+MzPBOBFM/ubmX3JzA42s4GVntQ6ZyMo66vaH8bMBuH75wL8rsBufwWuxA+0esHM7jKza83sYXzQ9ggw0znXbTCXme0LfBOfBMmXjetRGpVZ2DfwfTC+amaXBbX/3nQVvjPmLOB/g2xSTT/NOucy5qcGmYP/WWcEXwCYH2b+G+C3Ls+IRjMbCxyN7xCZPbroT/iswEwzm+ScW1KiKVcRjETCj6IBH4SmgD+X9UP5djXiU9znA9Pw2cdLyz1PAa9a19k3sj3hnHt3eMc590gQwFyGDzzPCL4A2szsduBHFWQEi41SKiRvur6IN4Dv4X/XXsG/FrvjO9t+CPirmR0T9NUIhZm0t4ucN3wTHNIPj6uli4EvADPM7Mic57E3/AlfSp0F/Dp4rCfLmGea2Zn5NjjnCv1BzS3yt7beOTc8z+PdPig65940s3/iA55ZwIVRG92bzGx7fBt/DKwHPp31YTefyM+Pc25pkCWaA0zBj7wMP1wlzU/98RPn3F1lNvtf+MEIve0U/N/xmxTIeAXP3WfM7Bn8c3pE1uZ1wN3k6UcZ/P/4A37U+Wdr2+xoFJgV4JxbaGbX4TtwfwvfWbg33YIPws4ws2/iP81uj+9I3l6rizjnVgAfCDo5fhDf8fc9+PLku/Fv2h8xs2PzXPdM/O/Qv7M7VzrnVgZBxwfxb4SzS7Rhgfm5YU4wsxHAOPyw8X8454oNB+/SlgJv/BuB/6qyHJqt2HQZS3MfcM7dbGb/wgewh+M7Je+DzwZ+GP8zf845d2XUBjjnenxqjSBQyA0WHgI+bGY/A74C/Cxnn/C/RLmfiPvLcTXjnNtsZhfi/75+aGZ3lvgnXOvrv2lmdwDHBgMc3sBnl592zj3WA5csNl1GIcWmg9ic+4CZjcdPg5Ch+3QXVxEEZmb23d58rospkD16DTjUOfdaicPLen6ccw8GFZiZ+P8n++NL9YPxH8zfHzw33eYDK8Q511cZ37Af4R
+DATvdBF0W/oL/2b6H/5D/Jj4w/QZwHvBBMzvYObcx69Dzgb2Bs51zr/dQ+4tSYFbct/Hz2nzOzH5R4g8l/AMr+BEmZ3vRNwbnXLuZ/QU4Bz+0d1awqUc65TrnngSeDO+b2T747MjH8J80vgTkToxZrE1XEQRmZnZBhDfCOfgpOz5KZ6mpnJ81+40/je/D8wRwi3OupYzzlPI/ETKAXQRloZuCL4LSwVH4rOKuwK/M7F+ub+e1Ksf38L8Pe5rZjllBefjmVqxfZrgt+42wvxxXa1fiy8P74j8Adhstl6Wm7y+Bq4Bj8X/HS/ClzZ7q9P+Ac25Wmcf8yDk3r4z9zyTolpHnH+o/8IOcJuP7Ft1TZluy1fK1CAPIBnzbDsAPVLrGzA4rUVIu9/khqHzcFXyFffJmAN/Hf2g8z8z+EfRDrEvmRw4fFNzNNxoz9DP87/c3cgLIJ4DTgiTAB/CjYL8TnPs9wP/ip7/6TW1bHp0CsyKccy+Z2ZX40ZkX4v/wCwk/oQwqcdrwDb/YJKWhq/CB2bnAIcCTruskiT3GOfcE8PGgo/cH8R3xOwIzM3svvrQF8GUzyx21Ev5u7YTPFP27xCX/jO9v8yl8ZvAt/JtpVJW88fcJ59xm4Iagr8ML+PnMjgZ+G+V4M/spFZQya5Vpc86tM7NV+IkpJ9CZKVwS3A43s6EF+m/tkLNvfzquppxzKTP7Nj4g+56ZXV9k9554f7kVP6n06fiSTkVdB+rIrOB2NzPLl50L35M+RXWBWc1ei9z3LDM7EJ8JOxD/Aej/VdbEaJxzSXz/4vn4CXQn4DOnkQIzM/s6UPa0L1W+V4fl6gecc88XaFfY9QfyjLwO/AUfmB1BEJjhs6oJfMUot1Q8PLidbJ2rPpzlnHupzPaXpMCstAuBTwCnW5GlVPAToQLsUuJ8U3L2Lygopz5J53xDc0od0wPuxAdmuYMBsjv2Tytxjk9RIjALyp9hPxCAS2pZsq1HzrnlQf+HaXR/fos5ifKn/3iNGq0uELzphSPHOv7pOOc2mNnL+BGP0/F9OHLtF9wu6m/H9ZDr8JnpqcBZRfbrifeXMCv/BXyQfYvrmZUOepyZvQ/frxR8Z+9infw/bGbDnHPrK7xc+Ny+w8ysSDUg8msRcs7NN7+k3e+AL5rZ5c65lytsZ2TOuU1BcHYS5b0XHYVfMaBcsyo4Jnzv+URwt1Cnf/AD98J56wq9zi3B7cg823anM/GQayCdP3OPzNqgUZklBCMkL8Y/V8XmHQo/gX3Q/Nxe3Zif7mJvfKnt/nz75HEF/lPtKmr8adaK9BzNEr7BdZTZckbEHOycs3xfQLgG5IesyHqYWa7E/6xr6N/zKAGln9/gTWb74G7kMqZzblKh57zI16QqfpRcx+HfnDbSfZmacLLWj+ceFPT5CAPvG/vpcTUV/FMPJ3A+H/+85nMv/n3jXWaWNzgLylIfDO52m9C4gN/T+TcXuZ9jHQo/KP6u2N8BfhqUZrqORi3XQvw/++H4/lqFhJ3ro74Woavwo8gb8f2gqlbpe30pzrkZFbwXRWlLIcfgP0RsxE9oXsgaOqdIKjSB7IHBbcfIV+fc7CJtDiesXZz1+OMV/yRFKDCL5sf4jvjHU3j4/I34fk6T8X2GukyuGoy4CQOrP7uI85A5537lnBvtnBvraj8y9Hgzu8HMZpqfWiK7vWZ+kddzgof+mrX5ZPxotVfxawwWavsz+DexAfi+akU5524JftbRPfUL38t+Z2YXWv6FzYfiR4puj3+TKTSXTq8zs4Hmp/Xo9mnQzI6hs+T6q6AUku0i/GimM83sg1nHJfB9NoYCNwW/G/3uOOu6APUkasA5dyf+g904OleLyN3nDXxJxoA/Bx3ds9vVjB9IsBN+2aWbIl778ay/uVtLH1F/gg+K4coepSbeDrdXPKdZkMm/OLh7edDnKbs9CTO7AP+Pfy3FMzv5zu/onEvzdDObUmz/iL5nZheZXxWhCzMbYH4B8P3wwX
+vzXBfgfB1u9Y5V3BUdfAahb/Pl5jZO7K3m9kH6BzQV6xvZ59QKTMC59x6M/sRfu6hvJ9onXNJM/swfvjwZ/EjGefjP1mNx3dWbMLPHfbFXml4aTH81AcfAtaa2SJ8Zm4osAedQeg1dO3/FP5x/LlIGj/0J3yZ5lP4aSPqyXgze6jI9oXOudy+cz81s2L9dy7J6gc4Ej8FwXlm9gJ+BvLN+H/A0/Fp8DbgTOdcudNZ9KRG/Gv1czN7Fl+KSeNT+2F/khvwGZ4unHOvm9mn8a/7TUFfnxV0dmp+Cfiv/nocXT/M5gal1fgGfjm2YnNKfQFfynwv8IqZPYgfTTkMHwSMwv8MH84TMPdHXzezWUW2/yUIak/F/y0tBe4rcc5rgZ8D081sL+fc0xW27bvAuwgmwQ76i76Gf/32w3/g2gCc7JxbW+7JnXN3mNlcfJbmPDrLd9miPj/g+8N9CfiSmb0GPBW0bzv84JNR+L/xc/N8iKkLZrYdvjM/RAt2v4zvJrIbna/RSvzfUDhx9DXUY79KVwcLkdbDF50Lw+ZdfBuf9XmdzsWsZxTYbxR+eohH8UFZEj8a6G78JJ0NBY6bTcTFdrOOqXYR8wH4PgI/xQeMr+EDhVb83FXXAcfmHPOOrOdg1wjXGEvnYrHvCh6bFdz/RxltLbR4fPi8zSnjXDOyfoZiX/OyjomyvwNOzDpmAj4wuxo/4nUVvoP1enwm8afA5L7+3c/z/DTi//HcEfxdbALa8csJ3Yz/x1/qHPvjszar8SWFl/AfbIb15+Pw5SmHn3iynOd0TnDc/xTZ5+9Zv0ezC+zTgH8fuTtoczL4fXo0+FsYVeJ3PvICzPjBN46+W8S81Ne5wf73B/d/EPH8twb7/zzn8UnhuSOex/BdOm7DB8jt+Oz3k8Hf9g4Fjot0HXyA5/DvGbtmPV7W8xMcMwpfvv0d/r3njeB3ZyP+vfXXBO/P9fqFH8Hs8KXEqMcMw3fsX4APRFPB382dwEfLvH7Zf0OVfllwQRERKcHMfoMPjKa6raPcLiJ1RoGZiEhEZvYK8KBz7vSSO4uIVECBmYiIiEidUOf/rYiZzSlj9yudc+UukSIi26gyJxN9wJWxzJiIdFJgtnU5s4x951H+2nUisu0qdzJRBWYiFVApU0RERKRObBUZs9GjR7tJkyb1dTNERERESnrsscfecs7lXf5qqwjMJk2axIIFC/q6GSIiIiIlBRP95qUlmURERETqhAIzERERkTqhwExERESkTigwExEREakTCsxERERE6oQCMxEREZE6sVVMlyEiIlKJDRs2sGrVKpLJZF83RbYCDQ0NbLfddgwdOrTicygwExGRbdKGDRtYuXIlEyZMoLm5GTPr6yZJP+aco7W1leXLlwNUHJyplCkiItukVatWMWHCBAYOHKigTKpmZgwcOJAJEyawatWqis+jwExERLZJyWSS5ubmvm6GbGWam5urKo33amBmZjuY2Vwze9bMFpvZl/LsM8PM1pvZ48HX+b3ZRhER2XYoUya1Vu3vVG/3MUsBX3XOLTSzIcBjZnaXc+6ZnP3ud84d18ttExEREelTvZoxc8694ZxbGHy/EXgWmNCbbaipBy+F9cv6uhUiIiKyleizPmZmNgnYF3g4z+YDzewJM/unme3Zuy2LaPNauPNb8MzNfd0SERGRss2ZMwczY8mSJR2PzZ49m3vuuafbvrNmzWLixIm92DpfEpw9e3bH/dmzZ3cpE7a0tDB79mwWLlzYq+3qaX0SmJnZYOB64Fzn3IaczQuBnZxz+wC/BG4qcI7PmtkCM1uwevXqHm1vXulk11sREZF+5Nhjj2X+/PmMHz++47ELLrggb2DWF+bPn89ZZ53Vcf+ss85i/vz5HfdbWlq44IILtrrArNfnMTOzBnxQdrVz7obc7dmBmnPudjO7zMxGO+feytnvCuAKgGnTprkebnZ3mSAgy6R6/dIiIiLVGjNmDGPGjOnrZhR0wAEHdLk/ceLEXs
/a9YXeHpVpwO+AZ51zPy+wz7hgP8xsP3wb1/ReKyNKKzATEZH6sGDBAsyMBx54oOOxX/7yl5gZ3/72tzsee/HFFzEzbr/99m6lzLBM+P3vfx8z61ZKBFi0aBEHH3wwAwcOZMqUKVx++eVltXPevHmYGddffz2zZs1ixIgRDB06lI9//OOsWdP1X32xUuaSJUuYPHkyAJ/5zGc62jtnzpyy2lOPeruUeRBwBjAzazqMY8zsc2b2uWCfk4CnzewJ4BLgNOdc72fESgkDMgVmIiLSx6ZOncrw4cO7lCHvuecempubuz0Wj8c5+OCDu50jLBPOmjWL+fPndyslbtiwgY997GOcfvrp3HzzzUyfPp2zzz6buXPnlt3ec889FzPjmmuu4fvf/z633HILJ510UuTjx48fzw03+KLbN77xjY72HnvssWW3pd70ainTOfcAUHSCD+fcpcClvdOiKoQBmfqYiYhsNS64dTHPrMjt+ty79th+KN85vrxxb7FYjEMOOYS5c+dy/vnnk8lkuPfeezn77LO55JJL2LRpE4MHD2bu3LlMmzaNIUOGdDtHWDqcMGFCtzIiwMaNG7nssss47LDDADjkkEO48847ueaaazoei2rPPffkqquuAuCoo45i5MiRnH766dx9990cfvjhJY9vampi3333BWDnnXfO297+SjP/V0qlTBERqSOHHXYY8+fPp62tjccff5yWlha+9rWv0dTUxP333w/4UuLMmTMrOv/AgQO7BGBNTU1MmTKFpUuXln2uU045pcv9k08+mVgs1qVz/7ZKi5hXqqPzf7pv2yEiIjVTbqaqnsycOZMtW7bw4IMPsmjRIvbZZx/Gjh3L+973PubOncuOO+7IypUry85uhUaMGNHtsaamJtra2so+19ixY7vcb2xsZMSIER0LgG/LFJhVKh32MVMpU0RE+t7ee+/N6NGjueeee1i0aFFHZmzmzJlcd9117LDDDjQ2NnLQQQf1cUth5cqVXe63t7ezbt06Jkzov3PO14pKmZXSdBkiIlJHzIxDDz2Uu+66i/vvv79LYLZo0SJuvPFG9t9/fwYOHFjwHI2NjbS2tvZ4W6+77rou9//2t7+RyWQ48MADI5+jqakJoFfa25sUmFWqY4JZBWYiIlIfZs6cySOPPMLmzZs7Rl5OnTqVoUOHMnfu3JJlzD322IPbbruNu+66iwULFrBixYoeaefixYv55Cc/yR133MEvf/lLzj77bA499NBIHf9DY8eOZdSoUVx77bXce++9LFiwoNuUG/2RArNKaboMERGpM2HgNW3aNIYOHQp0jtjM3l7IpZdeyqBBgzj++OOZPn06V1xxRY+08+KLL8Y5x6mnnso3v/lNjjvuOP7+97932y97CaZcsViMK6+8knXr1nHEEUcwffp0br311h5pb2+yepwirFzTpk1zCxYs6N2LPnc7XPtR2OskOOl3vXttERGp2rPPPsvuu+/e183YpsybN4/DDjuMu+66iyOOOKLgfhs2bGDYsGH88pe/5JxzzunFFtZGqd8tM3vMOTct3zZ1/q9URx8zdf4XERGplccee4xrr70WgP3337+PW9P7FJhVKq3pMkREREKpVPGuPfF4PNJ5PvOZz7Bq1Sp+/OMfM3369Fo0rV9RYFapMCDTzP8iIiI0NDQU3X7VVVcxa9YsSnWhWrhwYS2b1e8oMKuUpssQERHp8OijjxbdHi46LsUpMKtUWn3MREREQtOm5e3LLmXSdBmV6pguQ33MREREpDYUmFVKi5iLiIhIjSkwq1RYwlTnfxEREakRBWaVUsZMREREakyBWaW0JJOIiIjUmAKzSiljJiIiIjWmwKxSmsdMRET6sTlz5mBmLFmypOOx2bNnc88993Tbd9asWUycOLEXW+fNmzeP2bNnk8lkeuwaM2bMYMaMGV2uaWbMmzev47GLLrqIG264ocfakE2BWaXSqa63IiIi/cixxx7L/PnzGT9+fMdjF1xwQd7ArK/Mmz
ePCy64oEcDs8suu4zLLrus4/7UqVOZP38+U6dO7XisNwMzTTBbKfUxExGRfmzMmDGMGTOmr5vR5/bYY48u94cOHcoBBxzQR61RxqxyGc38LyIi9WHBggWYGQ888EDHY7/85S8xM7797W93PPbiiy9iZtx+++3dSplmBsD3v/99zAwzY/bs2V2us2jRIg4++GAGDhzIlClTuPzyy7u15ZFHHuGII45g8ODBDBo0iMMPP5xHHnmkyz655cPQpEmTmDVrFuDLqhdccAHg1+EM2xSVmfGtb32L73//+0ycOJHm5mYOOeQQHn/88aJtyS1lTpo0iddee42rr766ow1hG3uCArNKqfO/iIjUialTpzJ8+PAuZch77rmH5ubmbo/F43EOPvjgbueYP38+4PuTzZ8/n/nz53PWWWd1bN+wYQMf+9jHOP3007n55puZPn06Z599NnPnzu3Y58knn+TQQw9l3bp1zJkzhz/+8Y9s2LCBQw89lCeeeKKsn+mss87i05/+NAAPPPBAR5vK8cc//pHbb7+dSy+9lDlz5rBy5UoOP/xw1q5dG/kcN954I+PGjePII4/saMN5551XVjvKoVJmpTLqYyYistX559fhzaf6tg3j9oajf1TWIbFYjEMOOYS5c+dy/vnnk8lkuPfeezn77LO55JJL2LRpE4MHD2bu3LlMmzaNIUOGdDtHWL6bMGFC3lLexo0bueyyyzjssMMAOOSQQ7jzzju55pprOh678MILaWpq4u6772b48OEAvP/972fSpElccMEFZfXTmjhxYseAg/33359EovyQpbW1lTvvvJNBgwZ1nGfKlCn84he/4Lvf/W6kc+y77740NTUxevToXilxKmNWKWXMRESkjhx22GHMnz+ftrY2Hn/8cVpaWvja175GU1MT999/P+DLdDNnzqzo/AMHDuwIwACampqYMmUKS5cu7Xjsvvvu47jjjusIysD32frgBz/IvffeW9kPVoVjjjmmIygDX5Y84IADys689SZlzCql6TJERLY+ZWaq6snMmTPZsmULDz74IIsWLWKfffZh7NixvO9972Pu3LnsuOOOrFy5sktwVY4RI0Z0e6ypqYm2traO+2vXru0yyjM0btw41q1bV9F1qzF27Ni8jy1evLjX2xKVArNKhSXMTBKcgzI6JIqIiNTa3nvvzejRo7nnnntYtGhRR2Zs5syZXHfddeywww40NjZy0EEH9VgbRo4cyZtvvtnt8TfffJORI0d23B8wYAAbNmzotl85fb+iWLlyZd7HJkyYUNPr1JJKmZXKHo3pem5+FRERkSjMjEMPPZS77rqL+++/v0tgtmjRIm688Ub2339/Bg4cWPAcjY2NtLa2VtyGQw89lNtuu42NGzd2PLZx40ZuvfVWDj300I7HdtppJ1544QXa29s7Hrvvvvu6HAc+IwdU3Kbbb7+dt99+u+P+kiVLeOihhzjwwAPLOk9TU1NVz0s5FJhVKp3M/72IiEgfmTlzJo888gibN2/uGHk5depUhg4dyty5c0uWMffYYw9uu+027rrrLhYsWMCKFSvKuv55551Ha2srhx9+ONdffz033HADRxxxBJs3b+b888/v2O+0005jzZo1fOpTn+Lf//43v/3tb/mv//ovhg0b1q09AD/72c94+OGHWbBgQVntaW5u5gMf+AA33XQTf/3rXznqqKMYOnQoX/7yl8s6zx577MH999/PP/7xDxYsWNBltYRaU2BWqey+ZepnJiIidSAMvKZNm8bQoUOBzhGb2dsLufTSSxk0aBDHH38806dP54orrijr+u9617uYN28eQ4cO5cwzz+SMM85g8ODB3Hvvveyzzz5d2nn55Zfz8MMPc/zxx3PVVVfx5z//ucugAYDjjjuOz3/+81x22WUceOCBTJ8+vaz2fOITn+DYY4/lnHPO4cwzz2TMmDHcfffdXcqqQMn50X74wx+y2267ccoppzB9+vRu87vVkjnneuzkvWXatGmu3Ci6ar
87El5/yH//v69B8/Devb6IiFTl2WefZffdd+/rZkgPCSeY/d73vld0v6lTp7Lzzjvz97//vWbXLvW7ZWaPOeem5dumjFmllDETERHpt1555RWuuuoqnnzyyT5dgimXRmVWKrvzvwIzERGRXpFOpylW7YvFYsRipfNOl1xyCX/605/4+Mc/zuc///laNrEqCswqlT3jvzr/i4iI9IrDDz+86GS1Z555JnPmzCkavAFcdNFFXHTRRTVuXfUUmFVKGTMREZFe95vf/KbbtBrZRo8e3YutqT0FZpVKJwEDnAIzERGRXrLbbrv1dRN6lDr/VyqTgobmzu9FREREqqTArFLpJCQG+O8VmImI9Etbw5RRUl+q/Z1SYFapTBIagmUt1PlfRKTfaWho6LVldmTb0draSkNDQ8XHKzCrVDoFDWHGLN23bRERkbJtt912LF++nM2bNytzJlVzzrF582aWL1/OdtttV/F51Pm/UpkkJJo7vxcRkX4lXLJoxYoVJJN6H5fqNTQ0MHbs2I7frUooMKtUOpmVMVMfMxGR/mjo0KFV/RMVqTWVMiulUZkiIiJSYwrMKpFJA66zlJlWYCYiIiLVU2BWiXAUpkqZIiIiUkMKzCoRdvYPp8tQ538RERGpAQVmlQgzZppgVkRERGpIgVklwkCsY4JZBWYiIiJSPQVmlVAfMxEREekBCswqEfYpS2i6DBEREakdBWaVCEuXDZr5X0RERGpHgVklMrmlTK2VKSIiItXr1cDMzHYws7lm9qyZLTazL+XZx8zsEjN7ycyeNLOpvdnGSNI5pcy0MmYiIiJSvd5eKzMFfNU5t9DMhgCPmdldzrlnsvY5GpgSfO0P/Dq4rR8dozLV+V9ERERqp1czZs65N5xzC4PvNwLPAhNydjsB+KPzHgKGm9n43mxnSbnTZSgwExERkRrosz5mZjYJ2Bd4OGfTBOD1rPvL6B689S1NMCsiIiI9oE8CMzMbDFwPnOuc25C7Oc8hLs85PmtmC8xswerVq3uimYWFnf/jjWAxBWYiIiJSE70emJlZAz4ou9o5d0OeXZYBO2TdnwisyN3JOXeFc26ac27amDFjeqaxhYTTZcQbIJZQ538RERGpid4elWnA74BnnXM/L7DbLcAngtGZBwDrnXNv9FojowgzZrEExBqUMRMREZGa6O1RmQcBZwBPmdnjwWPfBHYEcM5dDtwOHAO8BGwGPtnLbSwtzJCFGTMFZiIiIlIDvRqYOeceIH8fsux9HPDfvdOiCnVkzBogrsBMREREakMz/1cit4+ZAjMRERGpAQVmlcjtY5ZWYCYiIiLVU2BWiaCP2TGXPsTmNMqYiYiISE0oMKtEEIitejtNeybWmUETERERqYICs0oEgVmSOBlTHzMRERGpDQVmlQhKmSnipIlDJt3HDRIREZGtgQKzSmSyAjOLa+Z/ERERqQkFZpVIh6XMRJAxUylTREREqqfArBKZJA4jQywIzJQxExERkeopMKtEOomLNQCQMvUxExERkdpQYFaJTAoX86tZpVEfMxEREakNBWaVSCc7ArOU+piJiIhIjSgwq0Qm6ecvA1JOgZmIiIjUhgKzSmRSOFPGTERERGpLgVkl0ikyHaXMmAIzERERqQkFZpXIJP3EskDSqfO/iIiI1IYCs0qkk2TMT5eR1JJMIiIiUiMKzCqRSZEJMmYpp1KmiIiI1IYCs0qkk6SDzv/tTjP/i4iISG0oMKtEJkmacLoMZcxERESkNhSYVSKd3fk/1rGouYiIiEg1En3dgH4pk+rImLW7ODgFZiIiIlI9ZcwqkU6SCvqYJV1MfcxERESkJhSYVSKT9DP+A1sywcz/zvVxo0RERKS/U2BWiUyadBCYJTH/mMv0YYNERERka6DArBLpnIxZ8JiIiIhINRSYVSKrlJnMBE+hpswQERGRKikwq0Q65ZdiAtpdUMrUAAARERGpkgKzSmSSpJwfldkWljK1XqaIiIhUSYFZJdLJjoxZOAhApU
wRERGplgKzSmQ6A7PwVp3/RUREpFoKzCqRTpF0QcbMqfO/iIiI1IYCs0pkkn4pJiAVrmqlwExERESqpMCsEumkX4oJSKGMmYiIiNSGArNyZdKA68iYpdXHTERERGpEgVm5gsxYmDFLalSmiIiI1IgCs3IFmbEtuRkzzWMmIiIiVVJgVq5ghv/2TIzGeCwrY6ZSpoiIiFRHgVm50r5k2e7iNDfGSTuVMkVERKQ2FJiVK8iMbcnEGdgY7xyVqc7/IiIiUiUFZuXq6GMWo7kxTkp9zERERKRGFJiVKyhZbsnEgoyZ+piJiIhIbSgwK1c6q5TZkMgKzNTHTERERKqjwKxcQWasLRNjYFNcgZmIiIjUjAKzcqXDwMy6ljLTCsxERESkOgrMyhVkxlIuTnNDImuCWQVmIiIiUh0FZuUKl2QizqCmOCkXLmKuzv8iIiJSHQVm5QpKmSmXCKbLSPjHlTETERGRKikwK1eQGUsSjsoMM2aax0xERESqEzkwM7N9zewGM3vLzFJmNjV4/AdmdlTPNbHOBJ38U8RzOv+rlCkiIiLViRSYmdn7gPnAO4G/5ByXAT5X+6bVqSBjliKeM/O/SpkiIiJSnagZsx8BdwB7Al/J2bYQmFrLRtW1dFjKTDCwMZ41KlMZMxEREalOIuJ+U4EPO+ecmbmcbW8BY2rbrDqW6SxlNsRjENNamSIiIlIbUTNmbcDAAtvGA+ujnMTMfm9mq8zs6QLbZ5jZejN7PPg6P2L7ek+6s/N/QzxGIp4gQ0x9zERERKRqUQOzB4BzzSye9ViYOfs0cE/E88wBSg0UuN859+7g68KI5+09YR8zF6cxYSTiRsYS6mMmIiIiVYtayjwP+A/wBPB3fFB2ppn9HHgPMD3KSZxz95nZpAraWT/CecxIkIjFaIzHSLs4CQVmIiIiUqVIGTPn3BPAIcBK4FuAAecEmw91zj1fwzYdaGZPmNk/zWzPQjuZ2WfNbIGZLVi9enUNL19CRx+zWFDKNDLElTETERGRqkXNmOGcWwgcbmYDgJFAi3Nuc43bsxDYyTm3ycyOAW4CphRozxXAFQDTpk3LHZDQc7I6/zcmjIZ4jHRagZmIiIhUr+yZ/51zbc65FT0QlOGc2+Cc2xR8fzvQYGaja32dqmRNl5GI+axZmrg6/4uIiEjVImfMzGx34CRgB2BAzmbnnDuz2saY2ThgZTAtx374wHFNteetqZzpMhriRtoSmi5DREREqhYpMDOzTwC/x3f6XwW05+wSqZRoZtcAM4DRZrYM+A7QAOCcuxwf+J1tZimgFTjNOdd7ZcoogsxYmlhnKZOYSpkiIiJStXJGZd4MfNo511LpxZxzHy2x/VLg0krP3ysyST89BkYiFiMRj5EioZn/RUREpGpRA7NxwOeqCcq2GukkmVgDAA2JGI1xU8ZMREREaiJq5///ALv3ZEP6jUwqyJhBQ9xnzVIkIK3ATERERKoTNWN2DnCDma0B7gTW5e7gnMvUsmF1K53sDMxiMRoSMVLKmImIiEgNRA3MlgGLgD8X2O7KOFf/lkmSDlamakjEaIgZKRdXHzMRERGpWtRg6rfAqfgJX5+j+6jMbUe6aymzIR4jpZn/RUREpAaiBmYnAP/POXdxTzamX8jOmMX8kkxJYprHTERERKoWtfP/28AzPdmQfiOTIk2CeMyIxYzGeIyk08z/IiIiUr2ogdlVwMd6siH9RjpJ2hI0xA3AlzKdSpkiIiJSvailzNeAj5rZXcC/yD8q8/e1bFjdyqRIE6ch5mPaRNyCUZlb+rhhIiIi0t9FDcx+HdzuBByeZ7vDL9m09UsnSVmchoQPzBriMdqd+piJiIhI9aIGZpN7tBX9SSZJmuxSppHMqJQpIiIi1YsUmDnnXuvphvQb6RQp4iRiWRkzYur8LyIiIlWL2vlfQpkkSRI0JsI+Zn5UplPGTERERKpUMGNmZq8AH3LOPWFmr+L7kRXinHPvqHnr6lE6SYqmjlJmY9z8dB
kKzERERKRKxUqZ9wIbsr4vFphtOzIpUgzsUspME9Mi5iIiIlK1goGZc+6TWd/P6pXW9AfpJEnXOSozESzJ5DJJrI+bJiIiIv1bpD5mZna+mW1fYNt4Mzu/ts2qYxnf+b8h1lnK1FqZIiIiUgtRO/9/B5hYYNv2wfZtQybImMW7ZsxM85iJiIhIlaIGZsWqdCOAbWfa+3SKdrpOMOszZpouQ0RERKpTbFTmDGBm1kP/ZWbH5ezWDBwLLK55y+pVJknSxTpKmQ1xI00cy6TAOTD1NBMREZHKFBuVeSjw7eB7B3wyzz7twDPAF2vcrvqVTtKeVcpsCOYxA/yyTPGoiymIiIiIdFWwlOmcu8A5F3POxfClzAPC+1lfA5xzU51z83uvyX0sk+oyKtNPlxHv2CYiIiJSqahLMmmFgFA6Sbt1jspMxI1UGN8qMBMREZEqqO5WrkySLRbrKGU2hp3/g20iIiIilVJgVo5MBlwmGJUZZMxilhWYacoMERERqZxKlOUIMmJbXKxzSaZEVsYsrYyZiIiIVE6BWTmCwGtLJk5j2Pk/ll3KVB8zERERqZwCs3IEgVd7JkZDPJjHLGGkw7ER6mMmIiIiVSirj5mZHQYcCEwAlgPznXNze6JhdSkMzFyMQUEpMxGLkQqfRvUxExERkSpECszMbCTwN2AGfrLZdfilmMzM5gEnO+fW9lAb60dQykyR6Chl+lGZmi5DREREqhe1lHkJMB04A2h2zo3BL8f0CWAacHHPNK/OBKXKJPEupUx1/hcREZFaiFrKPB74hnPuL+EDzrkkcHWQTfteTzSu7oQZMxfvGJWZUOd/ERERqZGoGbM08GKBbc8H27d+QeCVonNJpq4TzCowExERkcpFDcxuBk4tsO004KaatKbepcNSZoLGePaSTArMREREpHpRS5m3Ar8ws9vwgwBWAmOBU4A9gS+Z2cxwZ+fcPbVuaF3IhJ3/O0uZDfEYaafATERERKoXNTD7e3C7A3B0nu3XB7eGH7UZr7Jd9SndvZTZEDeS6vwvIiIiNRA1MDusR1vRX2SNygxLmWYGMa2VKSIiItWLFJg55+7t6Yb0Cx2jMhMdpUwAYg3+VjP/i4iISBXKnfl/NHAAMAq41Tm31swGAO3OuUxPNLCudIzKjHWUMgGIJ7psFxEREalEpFGZ5v0EWAbcAvwemBRsvhn4Vo+0rt5kT5cRlDIBLBYEZupjJiIiIlWIOl3GN4BzgAuB/fGd/EO3AsfVuF31KWtJpoZ4voyZ+piJiIhI5aKWMs8CLnTO/dDMckdcvgS8o7bNqlNdlmTqDMyso4+ZSpkiIiJSuagZswnAQwW2tQODatOcOpcuUMqMq/O/iIiIVC9qYLYc2KvAtn2AV2vTnDpXKGMW1wSzIiIiUr2ogdnfgPPN7KCsx5yZ7Qp8Fbi25i2rR1nTZXTpYxZrDLYrMBMREZHKRQ3MZgPPAffRuZj534Cngvs/qnnL6lGXJZk6S5mxhPqYiYiISPWiTjDbamYzgI8BR+I7/K8Bvgtc7ZzbNiKSICOWJE5j1jxm8YRKmSIiIlK9yBPMOufSwJ+Cr21TVsas66jMxi7bRURERCoRdYLZtJntV2Dbe8xs25jAK51VyoznK2VuG0+DiIiI9IyofcysyLY44GrQlvoXBF4p4jRmZcwa4jHSxDTzv4iIiFSlaCnTzGJ0BmWx4H62ZuBo4K0eaFv9KVDK9IFZnLj6mImIiEgVCgZmZvYd4PzgrgP+U+Q8l9WyUXUrnSRtccyMeNaozETcfBZNgZmIiIhUoVjGbF5wa/gA7Xf4RcyzbQGeAf4R5WJm9nv8upqrnHPdJqw1MwMuBo4BNgOznHMLo5y7V2SSZCxnDjOgMciYaVSmiIiIVKNgYOacuxe4F8DMHPBb59yKKq83B7gU+GOB7UcDU4Kv/YFfB7f1IZ0ibYku/cvAlzJTCsxERESkSpE6/zvnLsgOysxsmJlNM7OJ5VzMOXcfsLbILicAf3
TeQ8BwMxtfzjV6VCZJmkSXEZnQWcpU538RERGpRsHAzMyONLNuM/qb2beAVcDDwGtm9hczizwfWgkTgNez7i8LHsvXvs+a2QIzW7B69eoaXb6EoI9ZbinTZ8ximi5DREREqlIsY/Y5YNfsB8zs/fjZ/p8DzgV+A5wKfKlG7ck3LUfeqTicc1c456Y556aNGTOmRpcvIVOolGkkXUITzIqIiEhVimW69sUHYdk+CbQBRzrn3gTw/fX5GPCzGrRnGbBD1v2JQLX92monnb+U2ZkxUx8zERERqVyxjNl2wMs5j70feCAMygK3kZNZq8ItwCfMOwBY75x7o0bnrl4m2W0OM4BEPEbKxXAKzERERKQKxTJmG4FB4R0zmwKMAh7K2W8Dfvb/kszsGmAGMNrMlgHfARoAnHOXA7fjp8p4CT9dxiejnLfXpJOk8wRmjXEjRQKXShZdIkFERESkmGKB2XP4UZK3BfdPwPf3ujNnv8nAyigXc859tMR2B/x3lHP1iUw6yJjljsr0pUxlzERERKQaxQKzXwA3mNlIfOA1C3iK7isAfAh4okdaV28KlDLDecwyqWS01KGIiIhIHgX7mDnnbsKPvJwOfAJfwjw5yGoBEMxjdhi+BLn1SydJkuiWMWsM5jFTxkxERESqUXT+MefcJcAlRbYvA4bXuE31K5Mq0vk/jtMEsyIiIlKFSDP/SyCdJFmklElaGTMRERGpnAKzcmSSJF33zv8NcSOtzv8iIiJSJQVm5UinSBInkTdjltBamSIiIlIVBWblCDJmuUsyJWKmmf9FRESkagrMypFO0p6vlJkI+phprUwRERGpggKzcmSSJF2sWymzMez8n0n3UcNERERka6DArBzpFO0ulreUmSauUqaIiIhURYFZOTIpki5OIta9lJl0cUyBmYiIiFRBgVk5Mkm2ZOI0JHJGZcZipNX5X0RERKqkwKwMLp0KOv/nBGYJI0lCGTMRERGpigKzcmSCmf9zSpmJIGNmToGZiIiIVE6BWTnSSb9WZiLfqMwEMWXMREREpAoKzKJyDnPpvIuYNyT8BLPmNF2GiIiIVE6BWVTBcktJl+g2wawvZcaJuRQ41xetExERka2AArOogln982XMGuN+ugy/n7JmIiIiUhkFZlGlCwdmiXgwwSxoygwRERGpmAKzqIKAK0metTLjMZIdgZnWyxQREZHKKDCLKgjM0vk6/8fNTzCbtZ+IiIhIuRSYRRV2/s8TmJkZGUsE+ykwExERkcooMIsq7Pzv4iRySpkALqY+ZiIiIlIdBWZRBZmwFAka492fNhcLMmYKzEREpCe1ruv5GQAyGXj7rZ69huSlwCyqTOFSJgDW0GU/ERGRmku1w8Xvhsfm9Ox1nr0FfrGnDwKlVykwiypruox8pUximsdMRER6WOtaaGuBta/07HXWLYFUG2x4o2evI90oMIsqa7qMvKXMeJAxSytjJiIiPWTz2q63PaVtfXCdNT17HelGgVlUHRmzRP5SpvqYiYhITwsDpZ4OmMLArLWHA0DpRoFZVJnipUxTYCYi0r8sfbh+qhzOwbIFpfcLA6WeDpg6MmZ9EJhteAPeerH3r1snFJhF1bGIeYFSpgIzEZH+o2Up/P4D8MzNfd0Sb+l8uPJwWP5Y8f22hVLmXefDX8/o/evWCQVmUWXC6TLyj8q0sI+ZAjMRkfq3caW/3bCib9sRCtuxflnx/Xq9lNkHozI3vQkblvf+deuEArOosgKzvKXMeDjzf52kxUVEpLCOjFCdzNXV1uJvS80dFgZKbet7dhaAvixltrbAlg2QbOv9a9cBBWZRZU2Xoc7/IiL9XEcgVCejDltb/G2pQKhju+s8pif0ZSmz47VZ1fvXrgMKzKLKypjl62Nm8cYu+4mISB0LM0/1ljEr1Z7sQKkng6a+HJUZXvvt1b1/7TqgwCyqjkXME3lLmTGtlSki0n9ELR32ljD7VbKUuRYs3vl9T0htgVSr/763S5mZDLRt8N9vUm
AmxWQvYh7L18dMnf9FRPqNjtJhnZQyI2fM1sKISZ3f90hbgsAoluj952fLesD575Uxk6LCTv3xBGZ5ArOEZv4XEek3OgKhOgnMOjJmJdqzeQ2MntL5fU8IS4nDduj5QQa5svvNqY+ZFBVmwsLMWI54OCpTa2WKiNS/MADYssGX7vpalIxZJu0DpVG7+Ps9VcoMA7ORk+nxQQbdrp11rXopM/cyBWZRBZkwi+UPzDoCtowyZiIidS8MPqCyzNNzt8ELd9SuPa1ZoyCdK7BPC+B8Jive2HMZsy1hYLZzZ5vK9Z9LYM3L5R+XHQRuUsZMigkDrkIZs4T6mImI9ButLWDBv8BKMjP3/h/c95PataetBTBIt8OWjfn3CQOkgaP8V4/1MQsCsxGT/W25mbnWFrjrPHj8LxVcu8XfNg5RHzMpIe0DLisQmMXU+V9EpP9oa/GZJ6hsyozNa2HTytq0JZP2JdXhJdoTBkgDR0DzyJ6blb9LKZPyA8AwgNz4RvnXDjNmo3dRYCYlBBmzWKHALBHO/K/ATESk7rWth1Hv8N9XknlqXeuXdSpUdiy3LdDZd6zQAICwnc0jYeDInu/8H2bMyr1OGDBWstxVmDEbvatKmVJCJkWaGIlE/qcsroyZiEj/EGaoOgKhMjNmqS3QvgnSW7r2VatUGIyE7SmUMetSyhzZs6VMi3dm8MotZVabMYs1wPAd/Xm2wQF1CsyiSidJWyLvrP8AsYSf+d9pugwRkfrWkRGa5PuZlVvKzA6IalHODMt3pQLFjlLmyKCU2YOB2YBh0DjYB0lllzKD/TdUEJi1rYfm4TBoO8DVz3QmvUiBWVSZFGkS+dfJBBJB5/+MAjMRkfoWltqagwCn3IxZa40Ds46MWVhaLVLKjDX4gCnMmNWilNqtPUFgZhYMMii3lBk8P1vWQ/vbZV67BQYMh8Fj/P1tsJ+ZArOo0klSFs+7HBNALAzMUgrMRETqWhgINQ+HQaMryJhlBSoba5gxGzoBEgOKlzIHjuoMmFy6NqXUXGFgBj4ALHeQQfbzU27WrLXFX3vQdv7+NtjPTIFZVJlk0YxZQyJOysVIq/O/iEh9CwOhAcNh4OjKS3UAm96svj1hoBi2p1Dn/9Z1PlACn+mDnilnZgdmzRX0Zcvef2OZAwDaWoKAOcyYbXuTzCowiyqdIkW8YB+zxriRJq5SpohIveuSMRvV96XMMFBsHh6UKAtlzNZ2BmRhgNYTAwByM2aVlDLjvt912SMzW1tySpnKmEkhmSQpCpcyE/EYSeJkUu293DARESlLl4zZqMpLmUPG16aU2dbiA5nEAF9aLRQobl7TGZANHBU81guBWdmjMtfCmN389+UGZmHGbMBw35+u3D5ma16GS/erbOBBnVBgFlU6SapYKTMeI00Mp1KmiEh9y86YhaXMcqZl2LzOd8AftkNtSplhlsgsaE+hUubarFLmiM7Hai1fKbOcQQab18LwnaBpWHlTZmQywbWH++di0BjYVGZgtmwBvPU8LH+svOPqiAKzqDJJkkVKmQ1xI0lCpUwRkXrX2tI1Q4Urr4N7a1BSHDK2Np3TwywRBIMR8gRmzhUoZdZ4OolUOyQ3++AIKhtk0LrWB45Dx5eXMWvfCC7T9bkot5QZBsrrlpR3XB1RYBZV0MesUClTGTMRkX4iOysTlgTL6We2eY1fFmnwONhYw4wZ+Pa0b4JkW/c2u3Rne5uG+Ulga13K3LLB32aXMiF6Zs65zpLrkPHlZcyyS8wAg7crv5QZBsotr5V3XB3p9cDMzI4ys+fN7CUz+3qe7TPMbL2ZPR58nd/bbcwrkyTlYoXnMYv5jJkCMxGROpeboYLyMk+b1/oAafBYf67Ulh5oT06gmD25LEAs5rNStS5lhpmx7FIm+PJtFO1v+4XYB46CoduX19cru8QMlZUyNypjVhYziwO/Ao4G9gA+amZ75Nn1fufcu4OvC3uzjQVlUiSLTpcRI+1imvlfRKTe5WaooLwBAN
mlTKh+ZGaYwctuT24GLwyMwu3QM+tl5gZmHc9PxOuEgWJzkDHbtDJ6/73cjNmgMT5jVk7/tvC1UGAW2X7AS865V5xz7cC1wAm93IbKpFO0uxgNhUqZsRgp4sqYiYjUu+wM1cAgQ1V2KXOkz5hB9SMzW/O0JzdQDAOjMIMFwYjSnsqYDQ2uUWYps2M9z5G+j5lLR++HF147O2OW3tJZXo2iIzB7zQ8m6Id6OzCbALyedX9Z8FiuA83sCTP7p5nt2TtNKyGTJOniRUZlGinikFHGTESkruXNmEXMCKVTPoAIS5lQXcYseyQiZJUycwKh3FImBOtlljkrfyndSpkj8renkHC/gaNgyPb++6iTzHZMtBtce3Aw+385QfPGlZBo9gFdLUbM9oHeDszypZtyc5QLgZ2cc/sAvwRuynsis8+a2QIzW7B6dS+spZUuEZglYqSVMRMRqX/ZGbNEo+9IH/Wff0c/qJEwZJz/vpoAYMsGwGVlzAqVMvMEZgNH9Hwpc8DwYKH3qKXMrHVIh47330ftZ9atlBkEqVEzbslWvz7nhKn+fj8tZ/Z2YLYM2CHr/kSgSyjtnNvgnNsUfH870GBmo3NP5Jy7wjk3zTk3bcyYMT3ZZn+9TJJ2Fy9aykwSxyljJiJSvzIZaNvQ+c8f/Oz/UQOP7FLdwNGAVVfKzF6OKby1eP5SpsV8EBkKS5m1XMg8NzArd5BB9vPTkTGLGJi1tfifvWmIvx+ulxl1ZGaYudxhf3+7rn+OzOztwOxRYIqZTTazRuA04JbsHcxsnJlZ8P1+QRtr/JGgAmk/j1nhjJmRJubT3CIiUp+2rKdLhgrKm/0/O3MVTwQjB6sIzLKXYwIfCA3Ms0xUODdYLOt/UPNIX7JLbq78+rna1vsAsHFw1+uUVco0H2AOGgOxBGxYHu3YcAFzCxIgHaXMiBmzMECeOM23oZ9mzBK9eTHnXMrMzgHuAOLA751zi83sc8H2y4GTgLPNLAW0Aqc5V8uPA5VxwTxmhafLiAV9zBSYiYjUrdxyGfjM1/plEY/PGnUIwSSzNcyYQRAo5uQjwik6smVPMts4qPI2dGnP+q7BUaH2FNK61h8fD8KLweOilzKzS8zhdSH6lBnh6zBsBxg2UYFZVEF58vacxy7P+v5S4NLebldJ6WQwXUb+UmZjPEbKKTATEalruXNlgS9lvvF4tOOzS3XgBwBUM8lsbsYM8q+XuXlN1xGZ0HW9zOE7Vt6GbNnLMXVcZyS0LI12fPZ6nuD7mUXt/J89KAMg3uB/5nJLmYPHwohJ/TYw08z/EblMklSRzv+JYFSmKTATEalfuX2owGfM3n4rWl+t7FGH4DNC1SzLlDsSMTx3boaqdV3XgAc6A7VaTjKbLzArt5SZndkrZ5LZ3IwZBHOZRXx+N630ZdhBo2HETgrMtnrpZMklmVTKFBGpc/lKmYNG+6mOosyX1boW4k3QMNDfH7ydDxwqnTOrUHu6df5f2z0w6yhl9nBgFk5kGyVwDSffDQ3ZPnrn/9yMGQTPb8T+fxvf9AMGYnEYPsmPlk22Rju2jigwiypTvI9ZY9xPl6GMmYhIHctXyixnLrOwVBf2wRoyzn8grzRr1dbiO8hn9xEbONpnyMLBZOH6k8VKmbVSKDCLOsggN4AcOt6v/dkWIejNmzEbHT0juWll54CBEZP8bdQSbB1RYBaRBYFZY5FSZpK4n+VYRETqU6HO/wBvRwnM1nUt1XXM/l9hP7MwS5Td2T6cvysM9pKbfWCUmzELf4beKGVCtAAwt5QZdcoM57pOtBsaVEbGbNPKzrnlwsCsH5YzFZhFlfHTZRQrZaaJEdM8ZiIi9StfhmpQGetlhtNWhDpm/68wMMuXJcqdZDa3X1sonvBBVC0nmc0XHEXNKCbbIPl21+enY5LZEgMA2t/2mcfcoHDQGD/FSbKtZNPZuLLz9VBgtvWzEtNlhEsymTJmIiL1K1+Gqpz1MnNHHXYsZF7hAIAogVC+dTKz96
1VKTOd8oFVvlImlM7M5Vs2akgQmJXKmOUrMQMMDiaQLxU0Z9J+9GYYmA0aDQ2DFJhttZzDXIoUiYKlTDMjTUJ9zERE6lmhfkwQLWOWW6qrRSmzVHvyBTyh5pG1K2WGgx+ahna/BpQOAPNl9oYGpcxSGbN8JWbwGTMoHfhuXuO7EoWlTLN+O2WGArMogmAr6QqXMgEyFiemjJmISP3KN/KvYSAkBpQu1WUyvlN+duaqcRA0Dql8ktm2PO3JzeAVKmVC54jJWsg3dUf2dUsFZrmT7wI0NPufr1RgVihjNijiQuZhYBx2/od+O2WGArMo0r7fWLFSJkAmliDm1MdMRKRu5cuYmQVzmZUIcLas91mZ3MxVNbP/58uYZc/oD50BUcFS5rrKrp0r3xxv0NlnrFRmLnfy3dDQCFNmFMqYhaXMUnOZhRm1weM6HwszZn2/eFBZFJhFEXToL7ZWJgCmPmYiInUtX8YMgoXMS2RlCmWuBo+tbCHzQiMR4w3+sbdzSpnZnepDtSxlFgrM4gm/eHolpUzw/cwqzphFLGWGgy/CPn/gA7Pk5uijOuuEArMogrlkUkWWZAKVMkVE6l7b+u7//KFz9v9iWoPMVG7manCFGbMtG30GLl97sieZ3Zyz/mS2gSP8PGGpLeVfP1ehwAyilUzzlTIhWJapwoxZ4yDfiT9yKTMnMIN+V85UYBZFJlop08UaiDt1/hcRqUsdGao8gUe+2fZzFSrVDRlXWWCWbwHzUPayTPkmlw2VM8dYyfaUCMxKljLXQuNgSDR2fXzI9j7jlS7S1aetBbDuAw8gWDs0QimzaZjv0xZSYLYVS0crZTplzERE6leYoSoUCJXqY7a5wOjIwdv5rNWWTeW1J98C5h3tyerz1ro2f8d/6Hy8FuXMYoFZlPUy8y0bBcHITFc8eG1t8deN5fkfO3i70guZb3qza8d/6FzYXYHZVijImKVdvGgp08USxEn3u46GIiLbhEL9mMAHOMm3i6+tWKhUF3Y4LzdrVmgUJHTt81Yo4IHuAwWq0bbeLwLeODjPdSLMl5a7TmaoY8qMIuXMfIMyQoPGwKZSgdmqzqkyQg3N/rVRYLYVyvgsWOmMWaLL/iIiUkcK9WOCrLnDigQ4m9eAxbsHUh2TzJYZmBVrz8DRnQuHby4Q8EDtS5lNQ/NnrSKVMtfkDyA7JpktMgAg3yCI0KAxpTNmG9/s2r8s1A/nMlNgFkXE6TJcvMF/o0lmRUTqT9GMWYTZ/8PMleVUTiqdZLZYewaN9v9L2lp6t5SZL3sHPgAsNcggd/LdUJSMWVjKzGfQGJ89LJT0cEGZtFBg1vJa4evWIQVmUXTp/F+4lInFu+wvIiJ1JFLGrEhgVqhU11HKLHNZplIZM/DBTPsmP/oyn1qXMgsFRwMjZOYKPT8DR0G8sUTGrKVwKXPwduAynaNic7Vv8tNiDCkQmK1fBqn2wteuMwrMohi7F1e9bx7zMu8uPo9ZLCxlKmMmIlJ3SvUxg+IDAAr19Woe4d//y13IvK3Ff6BvGlK4PW+9EFyjQCkz0eSnk6jFJLNRArNCmbl0yh+f7/kx8/2/is1lVmh+OegMmgsFvuEccoPHdd82YhLgYP3rha9dZxSYRRGLszk+hHYaigdmYSkzrcBMRKTuFM1Q5Swcnk+hUl0sFsxlVmbGLAyEckuj4Dv/A6x5sWv78hk4qndKmVA4YxZmswq1c8j2hUuZzpXo/B8uy1Tg+d2UZzmmUMeUGa/mP7YOKTCLqD2VASheyoyFpUwFZiIidadYhmrAcL+tWClz85r8s+9DMPt/mRmzfMsxhcJS5uogY1ZoVCb4MmdNSpkbCmetSgWuxVYngGCS2QIZs2QrpNsLX3twifUyw0EXuaMyoV/OZabALKJkOkMiZli+TzahWNj5X33MRETqTrEMVSwWzGVW4J+/c0En/AIBUiWz/+dbwDwUlu9KlTLDbbUalTkgzwSvULqUWWjy3VCYMcs3nVSxEjOUXp
apo5SZp4/Z4LEQb1JgtjVKZVzxMiZ0LpehjJmISP0plqGCYPb/Ahmh9rd9Vqdgqa6CwKxYexqafd+xt3qplJlOQfvGykuZhdbJDA0dD6nWziAsW7ESc/h4LFF4yoxNK/3ggnzZulgMRuykwGxr1J7KkChWxgSsI2OmecxEROpOW0vhwAOKZ8zCgK1Q5mrwOH9sOX2Mi2XMwPczS74dtK1YKTPCOpalbNngbws9Pw0DgkEGhfqYFZh8NxTOZZavn1mpjFksFqyEUKiPWTBVRqGK1ohJsK7/TJmhwCyiVCZDY4mMmYUZs2LrgYmISN8oNvIPgtntCwRmYeBRsJS5HeBKr+mY255iGbww+9QwyI++LKR5pC9DVjPwrNhyTB3tKTLJbKlSZjiXWb5+ZqUyZgCDxxQOmjfmWY4pWzjJbD9ZlUeBWUTJVOlSpqmUKSJSv4qN/IPipcxSpbohZS7LFI5ELBooji5+zY79wklmq5gyI0pg1jyieCkzMQAaBubfXjRjFly76GszpnAfs02r8k+VERoxyWcEq3l+epECs4iS6QilzLg6/4uI1K2SGbPR/p93vszT5hKluo7Z/yMGZu1v+w/xpQJFKDy5bKhUx/woImXMRhUPXJvzrIoQ6liWqUgps2hZd7siozLfzD+5bKifjcxUYBZRMuNKljJjca2VKSJSl0rNlQWdgVC+AKdkKbPM9TKLLWAeCjNhxUZkQmen92pGZlZbyiw2YhV8H7VBY2BtnvnEOkqZRa49ZKwPwNrf7vp4OumDxXwjMkMjJvvbcIRrnVNgFlEySuf/eKP/RqVMEZH6EmaoSvUxg/yZmc1rASs911bUwCxKv6pBZZYyqxkAEKmUWWRajkKrImTbYX9Y8kD3vl5tLcHi6fHCx+5yhB8V+/w/uz4eljeLBWZjdvPP85L7i7evTigwiyiZzkTuY+bS/WdNLhGRbUKpkX+QFeDkC8zW+KAlrIzkSjT5zFXUSWYjtScMzEoEPL1ZymxrKVDqXVM6s7fzDFi/tPss/KVKzAA7HeTnQnvq710f75j1v0hgFovD5IPhlXv7xQAABWYRJSPMYxYP/mDTKWXMRETqSjkZqnyZp1KlOvAd0GuZMQsDxVIZs1JzjEXRth4waMyzKkJHe4Lr5J2LLMLzs/MMf/vKvTnXboHmIgEh+OBqrw/DS3d1/TnDjFmxPmbhtde/DmtfKb5fHVBgFlEylSm+HBNgCV/KTPWjVexFRLYJUUb+hRmqQqXMUgFSOZPMRsmYhYFiqUxU4yA/u321pcwBQ/2cYYUUCgAzGT9ootTzM2oXn/V6ZV7Xx6NkzAD2PtmXo5+5ufOxMENZbFQmwM6H+dvca9chBWYRRSllxoNRmemURmWKiNSVKCP/Bo7yAU6+TuJRSnXDdvDHRvlw3lE6LNKeEZP8FBRjdi1+LjM/s364SkAlii1gHhoajKzMfX7aWsBlSj8/Zj5z9ep9PpjLPr5YgBoavw+M3rVrOTMMhMNlmwoZuTMMnQiv3lt8vzqgwCyiKKXMWCIoZSYVmImI1JWwdFgsAIgnYNcjYfFN3ftRta4rXap753E+wHllbsT2mO/0Xsjg7eB/l3SWAIvZ7Vh4+e7K5+qKEpjtsL8Pvhbf0PXx8Jqlnh+AnQ/1Zc+VT2Ud3xItY2bms2avPQDrl/nHNq30AXVQsSp6bEdQWN8zJygwiyhKKTOWCDJmmvlfRKS+RMmYgf/H//YqWHJf18ejlDLfMdMPAHjqb9HaM2BY8dIh+DUzo9j7JD9q8dlbo+2fa8sGaCoRmMUbYM8PwXO3w5ZNnY+Xmnw32+RD/W12STFqxgxgr4/426ev97cbVxbv+J9t5xk+iHzzyWj79xEFZhFFKmUGEXtGpUwRkfoSJUMFMOUDfp/sclmyza9ZmW+R7GyJRtjjRHjutu7zbeVrT9RgJIrt94WR74Anr6vs+CgZM/CBa6rV/4yhUuuIZhs6Hsa8s3MAQLINUm3RMmYAo94BE9
7TGfxuKiMwm3yIv80dfFBnFJhFlIoyKjMsZSowExGpL20tpTu3g58IdfcPwjO3QLLVP1Zqctls7zoFkpu7z7eVtz3DS58vqrDMt+QB2JBnPcpSogZmO+zv+9JlZwU7np8SgWto8qHw2oOQ2hJtmo5ce58Cbz4Fq57zgdmQEh3/Q0PGwnZ71P0AAAVmEbVHmGBWGTMRkToVtR8T+LJg+0Z44Q5/v5xS3Q4H+E7mpcqZtc6YgQ/McPD0DSV37SZqYBaL+efn5Xs6R6+W8/yALymmWmHZo1mjUyMGdeDLqRbzz/GmlcUXMM937aXzfaauTikwiyiZzpRckikclZlRHzMRkfpSTj+myYf48lgYXJVTqovFYO+PwEv/hreLTF9R64wZwOhdfEkzSh+3bJm072MWNWu198ng0rD4Rn9/8xqIJUqXiUOTDvKB1Svzos3nlmvIWJ91W/gH36+u1FQZ2SYf6kunyx6JfkwvU2AWUaRSZkMQmCljJiJSX8rJmMXivpP5i3f6zuLllDIha76tm4q3p9YZs/Dabzxe3tQZWzb426iB2dg9fUkwDABb1/qMV6EFzHMNGOb7ib0yL9p8bvm86xR4e7X/vtTkstkmHQQWr+typgKziKKslZkIS5n5lqsQEZG+U07GDLqOciy3VDd2L9/BPXf5oFC4oHo5/aqi2vPDgBW+dj4V9fM6GV5/GNYtiTZiNdfkQ2H5QmhZGlx7eHnHv/M4P+ccRO/8D9A0BCZOV2C2NWiPUspMqJS5VdqyEe76DrxR30OsRaSItvXl/fPffqqflPSpv3UGZlFKmRB0xD8Jlj4ILa93355s9UFfrUuZ4Ec9Tj7YtzvfupCLb4L/XNJ1WyWBWfa0FZvXRn9uQjvP8OXQcHRnuRmzAUNht6P89+WUMsHPpbZiUWcZtc4oMIsoSikz0RhmzBSYbTVa18EfT4T/XARzjoWlD/V1i0SkEuWWDs386L9X74dVi6FxcOlJTLPtdZK/DefbylZp+S6qvU+GtS/74CPbw7+Bv50Jd50Ht3yhc6LVSgKzETv5gQ5P/i3aOpm5dtgPEs1+wtdyrx167xdhypEwfMfyjtt5hl+pYMkD5V+zFygwiyCdcaQzrmQpsyHo/O/Ux2zrsGkVzDnOT0Z43C/8yJ8/fciPRhKR/iPZCukt5Weo9j4JcL6cWW7gMXKyL5nlKylW0uG9HLsfD/HGzms7B/f9FP75NV8CPPirsOhPcP1ZkE5WFpgBvOtkWP2sX6Kp3Ocn0QQ7HeizZo2D/eS15Zo4DT5+XXkBM8CEadAwqG7LmQrMIkim/ZpepTJmDQ0xUi6Gy6iPWb+3fhlcdTSsfQU+9leY9in45D99aeMvp3adXFFE6luU5ZjyGT0Fxr/bd+Qvt1QHPuO28incyme6Pt7TGbPmEX6i3Kev90tL/Xs23PNdeNepcPIf4PDz4f0X+qWV/np653qT5QZme3zIj8as9PkJVwGoMEB9fe1mfjX3JdpTmdI7Z0s0wk7vrdt1MxWYRZDK+Fp8qT5miViMFHGcSpn929pX4PdH+4zZ6Tf4ZVbAZ8zOvBXG7Q1/PcOn8EWk/kVdjimfvU8GoL2pgmP3PJGMxbnm979g7dtZC5tHWcC8WnufDJvehD+d6LtiTPsUnHi5Xw8U4KAvwbE/93O13Xl+0J7OwKwtmeahV9bg8vVTCw0a1fn+WG7GDDrWAG1viDjNRpb2VIb/+tNj/OSO5/nRP5+r7NpvvQDrl5d/bA9L9HUD+oNkEI2XKmU2xn1gtsPSm+CKhfk7XnYMJw5vg31y9zXr3Mcs2J5n37z7Befttl9wXTPftyCT9p90Mkl/i/lPPx1fcT/XTLltdJmca8c69+2yX9bPlN22jlty9s29NnR5HrudL885S14bP8oIgzNv8XMCZRs4Ej5xM1zzUbjhM/DQr7q2s9C1I7Ux/Hlc19exkp+l4zmK8nsGZT2PXfYr1MZ858w9ZTl/CzXcr2PffPtV8/tYxX
4Ff54Ibeyyrytx7axjavZa1/r3sdDzWODaJc9H5wz+FfRjenX8Uezkvs2DKxwHZxzxWMQpIYB1NpzFmb34YNstrLv4MUaMGIA5OgOzEhmzls3tXHDrM5x18GT23L7Mtu96JK5xMLbkft8X6/0Xdp/OYvqnfRnxprP9/aYhHZsu/Mcz/OXhpfz8lH348NSJha+z9yl+WpEKMmbrh++OYzAvr4GdNm1h9OCmyMf+4t8v8MwbG9h/8kh+/59XmT5pBEfvPT76xcOF4eccA41D6PL7teeJcMj/i36uGlNgFkEyE62UmYgbv0sfzclDV7P9wOYI/ygcpf+RBd/nvmHl+weevR90fbPuuG6wXywOsYbgNgjEIAjUUp1Bm0uX10aL0fVNM9jmMgV+lgJv6h37Qek3Ydf9fLk/b5Rrh8cM2wEO+yZstzt5NQ2Bj/8N7jof1r7a9bku+I+5RBu7fJ/z8xZ9fvL9LHTdt8vvQ85rWPbzmGe/3Dbm/izZz0+HYvtFbWMl+2U/N1n7dWtjnr+ZYr+PuT977msTaT/y7FvgZyn2Whe6dvbPk/tzVPpa1/r3sdTz3e3aBZ6ffM/h2D393FllyGQcX7tjFe/OnMHTm3bg9Ydf44wDJ0U+/vL7Xuax5IlcOOpulre0EhsxjAkjB/qNux0Fw3cqevxP7nieGxct54nXW/jHF9/HwMbo/7K3WCO/TJxFQ6ydTx58HkMLzTG2z6k+QFz2qP9/ALywciPXPrKUxkSM79y8mP0mj2TiiIH5j9/9ODjgv2HK+yO3LXTZvFdYl/wYbTaQzdc/xW8/8R4swlxoD7+yhsvvfZnTpu/AhSfsxSm/mc/X/v4ku48fyqTRg6JdfLs9YL//8t1WQuG1y536o8YUmEWQTPs/7lKlzIZ4jF+kTmbke/bijAOK/8FJP9fQDMf8pK9bISI96OqHX+PRJes45aSvsXjRcn5yx/Mcs/d4RkXI7Kzc0MYfHlzCMfvMZLeTv8wFv32IxSs28M9TD2aHkQWCnCxPLVvPXx5ZykG7jOLBl9fwg9uf5Xsn7h257b+65yUuXbsf8Zix8JrH+d2Z00gU+h+265H+K/CD259lcFOCP5+1Px/77cN89bon+MtnDsifLWxohqN+ELldodfXbuaq/yzhg+8+nXeNG8L3bnuWvz76OqftV3yE5Ya2JF+57gl2HDmQ847bg8ZEjF99fCrHXnI/Z1+9kBs//14GNMRLNyAWg2P+r+x29wb1MYsgaimzIdieLLcjooiI1JXlLa386J/PcfCU0Zz0nolceMKebG5P8+N/RevPdOk9L5FKO849YlfiMeNnp+yDAV+57nHSGVf02EzGcd7NTzNqUBO/Pv09nPW+yfz5oaXMfW5VpGs/+8YGLpv3Mh+eOoHvnrAX976wmu/d9mykY+9/cTXznl/NF2ZO4V0Th/Od4/fg4VfXcuX9r0Q6Pqof/+s54jHjfz6wG586aDLvfccoLvzHMyx56+2ix82+ZTFvrG/l56e8m0FNPrc0YXgzvzj13Tz7xgZm37K4pu3sCwrMIkhFLGWG28NRnCIi0v845/jWjU/hgB98aG/MjF22G8KnD57MdQuW8dhr64oev3TNZq55ZCmn7bcDO47y2bGJIwZywQl78uiSdfzmvpeLHv+3x17n8ddb+OYx72TogAa++oHdeOe4Ify/vz/Jmk1bih6bzji+fv2TDGtu4Lxj9+Bj++/Ipw6azJwHl/Dnh14reez3b3uWHUcO5BPv9VWfk94zkaP2HMdP73yeZ1ZsKHp8VAuXruMfT77BZw7ZmXHDBhCLGT89eR8SMePL1z1OqsD/0NufeoMbFi7nnMN24T07jeiy7bDdtuO/D3sH1z76Otc/tizv8f2FArMI2lP+003UwCxV4tOQiIjUr5sfX8G851fzPx/YrUvZ8YszpzBu6ADOv/npolmvi/79AvGY8YWZU7o8/qF9J3Ds3uP5xV0v8PTy9XmPbdnczo//9TzTJ43gQ/
tOAGBAQ5xfnPpuNrQm+cYNTxUdKXnVf17liWXrmf3BPRkxyM/v9a1jd+ew3cbwnVsW88CLbxU89m8LXue5Nzfyv0e9k6aELweaGT/48N4MH9jIl//6OG3JdMHjo3DO8b1/PMOYIU381yE7dzy+/fBmvnviXixa2sJl817udszzb27kmzc+xT4Th/GFw6fknhaALx+xKwfsPJJv3PgUH/n1g3zhmkX88J/P8sf5S5j73Kryp9XoIwrMIhg1uJEvztyFXbYr3qkwLGX2lxdfRES6WrNpCxfcuph9dxzOme+d1GXboKYE3z5udxav2MBfHs6ffXph5UZufHw5s947ibFDB3TZZmZ8/0N7MXJQI1+8dhFPvN7S7fif3vm8H435wb26dITfffxQ/ufIXbnzmZX8bUH+jNDSNZv56Z3Pc8Tu23HcuzpHKMZjxiUf3Zddxgzm81c/xsurN3U79u0tKX521wu8Z6cRHLN31yWORg5q5P9OehfPr9zIT+94Pu+1o7r9qTdZuLSFr75/145SZOiEd0/gg/tsz8V3v8gtT6zg9w+8yuevfoz9fnA3R150H1uSGX5+6rsLJkkS8RiXfmwqJ71nIk2JGE8ta+GqB5Zw/s2L+eScRznq4vu4/8XVVbW/N/R6538zOwq4GIgDVzrnfpSz3YLtxwCbgVnOuYW93c5sY4cO4Csf2K3kfmZGImYdpU8REekf0hnHc29u4Bd3vcCmLSl+/JF35e3sfuze47lml6UFBwL87M7nGdyY4HOHviPvdYYPbOSiU/flc39+jBN+9R/et8toPn/YOzhw51EsXrGBqx9eypkHTmKP7bvP7XXW+3bmnudWccGtixkxqJH9dx7J0AHBijPO8fUbniQRi/HdE/fqNrpxyIAGrjxzGif+6j+ceOl/OHrvcZy47wQOmDyKWMz4zb0vs3rjFn5zRv6RkYftth1nHLATVz7wKs+8sYET953AUXuN67h+FFtSaX70r2d557ghnDxth7z7fPeEvXh0yVq+eM0iwPcfO+gdo5g+eSSHTBlTcuDE6MFN/OBDnYMkMhnHmrfbeey1dfzwn89yxu8e4ei9xvHt4/ZgwvDmyG3vTVZ08rhaX8wsDrwAvB9YBjwKfNQ590zWPscAX8AHZvsDFzvn9i923mnTprkFCxb0WLvL8c7z/slB7xjNR94zkcZ4jMaE/wr/wAs93dl/B/mGGOT+nUR92cp5dbMvUen1Cp7bcq/Q5eyRzlFOGzoGyec5JvzZCrUmmXas2tjGsnWtrGhpZXlLK2+ub2NYcwMTRjQzcXgzE0cMZMKI5mD0jyt5rezt4W5vb0mxvMVfY0VLG8vXtbJxS4rxwwaw/fABTBg+kO2HD2Ds0AEkypg7KVvu85Dv7z37TTjCSHW/X77Hihxbm9+fso8q8Hj5jemJt8l8pyx2nezf23y/V6WO6zw6PM7hguOzfy/MzM98U6Kt3a4TYZ/ibYsuyvtoFFuSGZ5Ytp5HXl3DgtfWsbHNr9ry9aPfWTCwAnhp1SaOvvg+JgxvZvfxQ5kwvJkJI5pJxGOcd9PTfOX9u/LFAuW20KYtKf7y8Gv89v5XWb1xC1N3HE5rMsPqjW3c/dUZDGvOH/Asb2nlhEsf4K1N7ZjBO8cNZb9JI2hMxPjt/a/y/Q/txcf3LzwrwAsrN/Kbe1/hX0+/wdvtacYPG8Cxe4/nzw+/xhG7j+XSj00teGxbMs1v7n2FGxctY8mazTQmYhyx+3Ycs/d4Mg6WZ71nvrG+jSFNCSaMaO54P3tx1Uau+s8S/vip/Thk1zEFr/Py6k08+8YGpu44gu1rGDy1JdNcef8rXDr3JQDOOWwXDtjZT42R/Ss1ZnBT9Gk3KmRmjznnpuXd1suB2YHAbOfckcH9bwA4536Ytc9vgHnOuWuC+88DM5xzbxQ6bz0FZjN+Mpclazb3dTOkRkYOavTB0ZABrG9NsryllZUb2qhlN8J4zBg3dAATRjQzuCnBm+vbWN7Syv
pWrSAh0ht22W4w+00eyX6TRjJ98shImZSbFi3n+oXLWN7SyvJ1rWwJurCMGtTIvV87jMFN0QpSbck0f39sGZff+zLL1rXys5P34SPvKTKhK7C5PcXjS1t4ZMlaHl2yloWvtdCaTLPf5JFc+5kDiEX4INfanuauZ1dy86Ll3PvCamIx4+6vHBppKg/nHI+/3sLNj6/g1idWsCZrVYNhzQ1MGN7MuGED2LQlxfJ1rby5oa2jT96M3cYw55P7lbxGT1q2bjPf+8ez/Gvxm3m3n37AjmVNTVKJegrMTgKOcs6dFdw/A9jfOXdO1j7/AH7knHsguH838L/OuYKRVz0FZm9vSbFq4xbaUxnaUxm2pNK0pzJd/pEXy0a5PJ9Nq/10aBE+y2Zft2t7Oj8Jl8qIFNruXHbmxuVNk0f9oFvOJ+Lw586fterMcuWeM27GdkN91irfhI7JdKYjeArfjPM9R/mex86shzGgIcb2w5sZO3RA3pLJpi0pVrS0smrDlm6/F8We71y5z0OhTEj0LGy039Hs353s65er0O9j0WPynqfzd683sztFz5mnJfmuk/17G/49RXluc5+7zscdsY7MWGeGrCODlufvo9j7SKHfiShZ1Kivae6+xd9Ho50zHjPeOW5IpHnJirbL+XLZ8nWtjBzUGCm4yZVMZ3hx5SZ2Hz8k0iSrucc+/+ZGdhw1sKzSYmjNpi1sbEtVlCVKpjM8uWw9g4PsWL6ANJ3xVYgVLW1MGTu4ojb2hMUr1ndZKiv8HR83bAC7bDe4R69dLDDr7T5m+X7bcv+io+yDmX0W+CzAjjsWn5CuNw1qSjA54icl6Z8a4jF2GDmwojffcgxuSrDr2CHsOnZI6Z1FpM+YGaMHN5W1pFCuhngsb7+yqMfuNWFYxdceNbip4uC0IR7rNnVFrnjMGD+smfHD6qtPV9nLXPWS3h6VuQzI7vE3EVhRwT44565wzk1zzk0bM6ZwrVpERESkv+jtwOxRYIqZTTazRuA04JacfW4BPmHeAcD6Yv3LRERERLYWvVpzc86lzOwc4A78dBm/d84tNrPPBdsvB27Hj8h8CT9dxid7s40iIiIifaXXO0M5527HB1/Zj12e9b0D/ru32yUiIiLS1zTzv4iIiEidUGAmIiIiUicUmImIiIjUCQVmIiIiInVCgZmIiIhInVBgJiIiIlInFJiJiIiI1AkFZiIiIiJ1QoGZiIiISJ0wP9F+/2Zmq4HXeuFSo4G3euE6Uh69LvVLr0190utSn/S61K9avzY7OefG5NuwVQRmvcXMFjjnpvV1O6QrvS71S69NfdLrUp/0utSv3nxtVMoUERERqRMKzERERETqhAKz8lzR1w2QvPS61C+9NvVJr0t90utSv3rttVEfMxEREZE6oYyZiIiISJ1QYBaBmR1lZs+b2Utm9vW+bs+2ysx2MLO5ZvasmS02sy8Fj480s7vM7MXgdkRft3VbZWZxM1tkZv8I7uu16WNmNtzM/m5mzwV/OwfqdakPZvbl4L3saTO7xswG6LXpG2b2ezNbZWZPZz1W8LUws28EMcHzZnZkLduiwKwEM4sDvwKOBvYAPmpme/Rtq7ZZKeCrzrndgQOA/w5ei68DdzvnpgB3B/elb3wJeDbrvl6bvncx8C/n3DuBffCvj16XPmZmE4AvAtOcc3sBceA09Nr0lTnAUTmP5X0tgv87pwF7BsdcFsQKNaHArLT9gJecc68459qBa4ET+rhN2yTn3BvOuYXB9xvx/2Am4F+PPwS7/QE4sU8auI0zs4nAscCVWQ/rtelDZjYUOAT4HYBzrt0514Jel3qRAJrNLAEMBFag16ZPOOfuA9bmPFzotTgBuNY5t8U59yrwEj5WqAkFZqVNAF7Pur8seEz6kJlNAvYFHgbGOufeAB+8Adv1YdO2ZRcBXwMyWY/ptelbOwOrgauCEvOVZjYIvS59zjm3HPgpsBR4A1jvnLsTvTb1pNBr0aNxgQKz0izPYxrK2ofMbD
BwPXCuc25DX7dHwMyOA1Y55x7r67ZIFwlgKvBr59y+wNuoNFYXgv5KJwCTge2BQWZ2et+2SiLq0bhAgVlpy4Adsu5PxKebpQ+YWQM+KLvaOXdD8PBKMxsfbB8PrOqr9m3DDgI+aGZL8OX+mWb2Z/Ta9LVlwDLn3MPB/b/jAzW9Ln3vCOBV59xq51wSuAF4L3pt6kmh16JH4wIFZqU9Ckwxs8lm1ojv8HdLH7dpm2Rmhu8r86xz7udZm24Bzgy+PxO4ubfbtq1zzn3DOTfROTcJ/zdyj3PudPTa9Cnn3JvA62a2W/DQ4cAz6HWpB0uBA8xsYPDedji+36xem/pR6LW4BTjNzJrMbDIwBXikVhfVBLMRmNkx+P4zceD3zrnv922Ltk1m9j7gfuApOvsxfRPfz+w6YEf8m93JzrncTpzSS8xsBvA/zrnjzGwUem36lJm9Gz8goxF4Bfgk/kO5Xpc+ZmYXAKfiR5wvAs4CBqPXpteZ2TXADGA0sBL4DnATBV4LM/sW8Cn8a3euc+6fNWuLAjMRERGR+qBSpoiIiEidUGAmIiIiUicUmImIiIjUCQVmIiIiInVCgZmIiIhInVBgJiJ1w8xmmZkzs12C++ea2Yf7sD3DzWy2mU3Ns22emc3rg2aJyFYs0dcNEBEp4lzgAfys6H1hOH4+o2XAwpxtn+/11ojIVk+BmYhsU8ysyTm3pdrzOOeeqUV7RESyqZQpInUpWHdzJ+DjQXnTmdmcrO37mNktZrbOzFrN7D9mdnDOOeaY2TIzO9DMHjSzVuD/gm2nmdk9ZrbazDaZ2SIzOzPr2EnAq8Hd32a1YVawvVsp08x2M7MbzawlaNNDZnZUzj6zg/NMMbPbgmu/Zmbnm1ksa7/BZvZLM1tqZlvMbKWZ/dvM3lntcysi9UuBmYjUqw8BbwJ3AAcGX98FCPp8PQiMBD4DfARYA/zbzN6Tc55h+IXVrwGOBv4SPL4zflHvjwMnArcCV5rZ54LtbwBh/7YfZrXhtnyNNbPt8WXXfYBzgFOAFuA2Mzs6zyE3AvcE174JuIDOdfkAfhGc4wLg/cDngMfx5VUR2UqplCkidck5t8jMtgBvOeceytn8E/zadTOdc+0AZnYH8DRwHj7YCQ0GTnfOdVkM2jn3g/D7IFM1DxgPnA1c7pzbYmaLgl1eydOGXF8BRgAHOudeCs57O37R8O8DuWvp/cw5d1Xw/b/NbCbwUSB87EDgaufc77KOubFEG0Skn1PGTET6FTNrBg4F/gZkzCxhZgnAgH8Dh+QckgL+kec8U8zsGjNbDiSDr7OA3Sps2iHAQ2FQBuCcS+Mzde82s6E5++dm3p7GL5YcehSYZWbfNLNpZhavsF0i0o8oMBOR/mYkEMdnxpI5X+cAI7L7agGrggCpg5kNBu7Clx2/DhwMTAd+DzRV0a438jz+Jj5oHJHz+Nqc+1uAAVn3vwD8BvgUPkhbZWa/MLOBFbZPRPoBlTJFpL9pATLAr4A/5tvBOZfJvptnlwPxAwsOds49ED4YZN4qtRYYl+fxcUEbcgOxopxzm4BvAN8ws52Ak4AfAe3A/1bRThGpYwrMRKSebQGasx9wzr1tZvfjs10Lc4KwqMKsUzJ8wMxGACfkuT65bSjgXuBcM5vknFsSnDMOnAoscs5trKCdADjnXgN+ZmYfB/aq9DwiUv8UmIlIPXsGONjMjsOXBN8Kgp6vAPcBd5jZ7/AlxNHAVCDunPt6ifM+CGwAfmVm3wEGAd8G3sKP4gytxI/2PM3MngTeBl51zq3Jc85fALOAu4JzbsBPQrsrcGyZPzdmNh+4BXgK2ITvV7cP8IdyzyUi/Yf6mIlIPfsG8DxwHb6f1WwA59xCfJ+wNcAlwJ3AxcDe+ICtKOfcavx0HHH8lBk/BK4E/pyzXwY/IGAEfmDBo8DxBc65AngfsBj4dXDekcCxzrl/Rf6JO92Hny7javxAgZOALzvnLq7gXCLST5hz+bpfiIiIiEhvU8ZMRE
REpE4oMBMRERGpEwrMREREROqEAjMRERGROqHATERERKROKDATERERqRMKzERERETqhAIzERERkTqhwExERESkTvx/pf24TVdGUlYAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "print(\"num_samples:\", NUM_SAMPLES)\n", "print(\"num_features:\", NUM_FEATURES)\n", "print(\"num_features:\", NUM_CLASSES)\n", "print(\"maxiter:\", MAXITER)\n", "print(\"stepsize:\", STEPSIZE)\n", "print(\"linesearch (ignored if `stepsize` > 0):\", LINESEARCH)\n", "print()\n", "\n", "errors, step_times, compile_time = run()\n", "print('Average speed-up (ignoring compile):',\n", " round((step_times['without_pjit'] / step_times['with_pjit']).mean(), 2))" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "provenance": [ { "file_id": "18Ea9Vl8fRiCpya0DXitfklcf4kaKj_kN", "timestamp": 1671533866034 }, { "file_id": "1RaTFt691iv9n0PVyb3lcp9tMWCUje7bK", "timestamp": 1664566635253 } ] }, "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 0 }