{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "SDYtl1XHEEjK" }, "outputs": [], "source": [ "# Copyright 2022 Google LLC.\n", "#\n", "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", "# You may obtain a copy of the License at\n", "#\n", "# https://www.apache.org/licenses/LICENSE-2.0\n", "#\n", "# Unless required by applicable law or agreed to in writing, software\n", "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." ] }, { "cell_type": "markdown", "metadata": { "id": "nCt2V9cCR-CF" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jaxopt/blob/main/docs/notebooks/distributed/custom_loop_pmap_example.ipynb)" ] },
{ "cell_type": "markdown", "metadata": { "id": "N3RErasEJpMo" }, "source": [
"`jax.pmap` example using JAXopt.\n",
"=========================================\n",
"The purpose of this example is to illustrate how JAXopt solvers can be easily\n",
"used for distributed training thanks to `jax.pmap`. Concretely, we implement\n",
"data-parallel training of a multi-class logistic regression model on\n",
"synthetic data. Aspects to pay attention to include:\n",
"* How to use `jax.lax` reduction operators such as `jax.lax.pmean` or\n",
"  `jax.lax.psum` in JAXopt solvers by using custom `value_and_grad` functions.\n",
"* How `jax.pmap` can be used to transform the solver's `update` method to easily\n",
"  write custom data-parallel training loops.\n",
"\n",
"To obtain the best performance on Google Colab, we recommend:\n",
"1. In `Change runtime type` under the menu `Runtime`, select `TPU` for the `Hardware accelerator` option.\n",
"2. Connect to the runtime and run all cells.\n",
"\n",
"NOTE: this example can be easily adapted to support TPU pod slices (e.g. `--accelerator_type v3-32`) as well as hosts with one or more GPUs attached." ] },
{ "cell_type": "code", "execution_count": 2, "metadata": { "cellView": "form", "id": "LD0juyi9JmIZ" }, "outputs": [], "source": [
"#@markdown The number of optimization steps to perform:\n",
"MAXITER = 100 #@param{type:\"integer\"}\n",
"#@markdown The number of samples in the (synthetic) dataset:\n",
"NUM_SAMPLES = 50000 #@param{type:\"integer\"}\n",
"#@markdown The number of features in the (synthetic) dataset:\n",
"NUM_FEATURES = 784 #@param{type:\"integer\"}\n",
"#@markdown The number of classes in the (synthetic) dataset:\n",
"NUM_CLASSES = 10 #@param{type:\"integer\"}\n",
"#@markdown The stepsize for the optimizer (set to 0.0 to use line search):\n",
"STEPSIZE = 0.0 #@param{type:\"number\"}\n",
"#@markdown The line search approach (either `'zoom'` or `'backtracking'`), ignored if `STEPSIZE > 0.0`:\n",
"LINESEARCH = 'zoom' #@param{type:\"string\"}" ] },
{ "cell_type": "markdown", "metadata": { "id": "9cSpF3jQMezP" }, "source": [ "# Imports and TPU setup" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": { "id": "AkeC41JXNTv0" }, "outputs": [], "source": [ "%%capture\n", "%pip install jaxopt flax" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": { "id": "6UypR4hUNbve" }, "outputs": [], "source": [
"import functools\n",
"import time\n",
"from typing import Any, Callable, Tuple, Union\n",
"\n",
"# Activate TPUs if available. The import can fail on recent JAX releases, and\n",
"# `setup_tpu` raises `KeyError` outside a Colab TPU runtime.\n",
"try:\n",
"  import jax.tools.colab_tpu\n",
"  jax.tools.colab_tpu.setup_tpu()\n",
"except (ImportError, KeyError):\n",
"  print(\"TPU not found, continuing without it.\")\n",
"\n",
"from flax import jax_utils\n",
"from flax.training import common_utils\n", 
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"import jaxopt\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import numpy as np\n",
"\n",
"from sklearn import datasets" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iMxI2ZU0JnJO", "outputId": "97d0bdd8-5b35-4766-ff0d-f020de40f79c" }, "outputs": [ { "data": { "text/plain": [ "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "jax.local_devices()" ] },
{ "cell_type": "markdown", "metadata": { "id": "8ZReXwUIO4Tt" }, "source": [ "# Type aliases" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": { "id": "5A_PBZ7OO5rO" }, "outputs": [], "source": [ "Array = Union[np.ndarray, jax.Array]" ] },
{ "cell_type": "markdown", "metadata": { "id": "gbqilnNSNpw1" }, "source": [
"# Auxiliary functions\n",
"A minimal example of how all-reduce mean/sum ops (such as `jax.lax.pmean` or `jax.lax.psum`) can be introduced into a JAXopt solver's `update` method by overriding `jax.value_and_grad`. More complex wrappers are of course possible."
] },
{ "cell_type": "code", "execution_count": 7, "metadata": { "id": "awxxl498NwMI" }, "outputs": [], "source": [
"def pmean(fun: Callable[..., Any], axis_name: str = 'b') -> Callable[..., Any]:\n",
"  \"\"\"Applies `jax.lax.pmean` across `axis_name` for all of `fun`'s outputs.\"\"\"\n",
"  maybe_pmean = lambda t: jax.lax.pmean(t, axis_name) if t is not None else t\n",
"  @functools.wraps(fun)\n",
"  def wrapper(*args, **kwargs):\n",
"    return jax.tree_util.tree_map(maybe_pmean, fun(*args, **kwargs))\n",
"  return wrapper" ] },
{ "cell_type": "markdown", "metadata": { "id": "clQeXb77ObhK" }, "source": [ "A small utility to shard `Array`s across the local devices:" ] },
{ "cell_type": "code", "execution_count": 8, "metadata": { "id": "rMntHIEFOeyO" }, "outputs": [], "source": [
"def shard_array(array: Array) -> jax.Array:\n",
"  \"\"\"Shards `array` along its leading dimension across the local devices.\"\"\"\n",
"  # `common_utils.shard` reshapes the leading dimension into\n",
"  # `(jax.local_device_count(), -1)`, so its size must be divisible by the\n",
"  # number of local devices.\n",
"  return jax.device_put_sharded(\n",
"      shards=list(common_utils.shard(array)),\n",
"      devices=jax.local_devices())" ] },
{ "cell_type": "markdown", "metadata": { "id": "4Pph9LUAOmhi" }, "source": [
"# Custom loop\n",
"The training loop below uses data parallelism. Passing `use_pmap=False` disables it: the solver's `update` method is then `jax.jit`-compiled and run on a single device. We use this later to benchmark the impact of parallelism." ] },
{ "cell_type": "code", "execution_count": 9, "metadata": { "id": "tv1n_kkJO0Pf" }, "outputs": [], "source": [
"def fit(\n",
"    data: Tuple[Array, Array],\n",
"    init_params: Array,\n",
"    stepsize: float = 0.0,\n",
"    linesearch: str = 'zoom',\n",
"    use_pmap: bool = False,\n",
") -> Tuple[np.ndarray, np.ndarray, float]:\n",
"  \"\"\"Fits a multi-class logistic regression model for demonstration purposes.\n",
"\n",
"  Args:\n",
"    data: A tuple `(X, y)` with the training covariates and categorical labels,\n",
"      respectively.\n",
"    init_params: The initial value of the model's weights.\n",
"    stepsize: The stepsize to use for the solver. If set to `0`, linesearch will\n",
"      be used instead.\n",
"    linesearch: The linesearch algorithm to use. Ignored if `stepsize > 0`.\n",
"    use_pmap: Whether to distribute the computation across replicas or use only\n",
"      the first device available.\n",
"\n",
"  Returns:\n",
"    The per-step errors and runtimes, as well as the JIT-compile time for the\n",
"    solver's `update` function.\n",
"  \"\"\"\n",
"  # Value and grad of the objective function for the solver.\n",
"  value_and_grad_fun = jax.value_and_grad(jaxopt.objective.multiclass_logreg)\n",
"  # When `jax.pmap`ing the computation, use JAXopt's option to provide a custom\n",
"  # `value_and_grad` function to include the desired reduction operators. For\n",
"  # example, here we decide to average across replicas.\n",
"  if use_pmap:\n",
"    value_and_grad_fun = pmean(value_and_grad_fun)\n",
"\n",
"  # To override `jax.value_and_grad` in a JAXopt solver, set the flag\n",
"  # `value_and_grad` to `True` and pass the custom implementation of the\n",
"  # `value_and_grad` function as `fun`.\n",
"  solver = jaxopt.LBFGS(fun=value_and_grad_fun,\n",
"                        value_and_grad=True,\n",
"                        stepsize=stepsize,\n",
"                        linesearch=linesearch)\n",
"  # Apply the `jax.pmap` transform to the function to be computed in a\n",
"  # distributed manner (the solver's `update` method in this case). Otherwise,\n",
"  # we JIT compile it.\n",
"  if use_pmap:\n",
"    update = jax.pmap(solver.update, axis_name='b')\n",
"  else:\n",
"    update = jax.jit(solver.update)\n",
"\n",
"  # Since in this demo we are *not* using mini-batches, it pays off to transfer\n",
"  # the data to the device(s) beforehand. Otherwise, host-to-device transfers\n",
"  # occur in each update. This is true regardless of whether we use distributed\n",
"  # or single-device computation.\n",
"  if use_pmap:  # Shard the data and move each shard to its device.\n",
"    data = jax.tree_util.tree_map(shard_array, data)\n",
"  else:  # Just move the data to the default device.\n",
"    data = jax.tree_util.tree_map(jax.device_put, data)\n",
"\n",
"  # Initialize the solver state. If using `pmap` for data-parallel training,\n",
"  # model parameters are typically replicated across devices. The state is then\n",
"  # initialized inside `jax.pmap` as well, since `init_state` may evaluate the\n",
"  # objective, whose `pmean` requires the `'b'` axis to be bound.\n",
"  if use_pmap:\n",
"    params = jax_utils.replicate(init_params)\n",
"    state = jax.pmap(solver.init_state, axis_name='b')(params, data)\n",
"  else:\n",
"    params = init_params\n",
"    state = solver.init_state(params, data)\n",
"\n",
"  # Pre-compiles update, preventing it from affecting step times.\n",
"  tic = time.time()\n",
"  _ = update(params, state, data)\n",
"  compile_time = time.time() - tic\n",
"\n",
"  outer_tic = time.time()\n",
"\n",
"  step_times = np.zeros(MAXITER)\n",
"  errors = np.zeros(MAXITER)\n",
"  for it in range(MAXITER):\n",
"    tic = time.time()\n",
"    params, state = update(params, state, data)\n",
"    jax.tree_util.tree_map(lambda t: t.block_until_ready(), (params, state))\n",
"    step_times[it] = time.time() - tic\n",
"    errors[it] = (jax_utils.unreplicate(state.error).item()\n",
"                  if use_pmap else state.error.item())\n",
"\n",
"  print(\n",
"      f'Total time elapsed with {linesearch} linesearch and pmap = {use_pmap}:',\n",
"      round(time.time() - outer_tic, 2), 'seconds.')\n",
"\n",
"  return errors, step_times, compile_time" ] },
{ "cell_type": "markdown", "metadata": { "id": "1sEssbjlPOvt" }, "source": [
"# Boilerplate\n",
"Creates the dataset, calls `fit` with and without `jax.pmap`, and plots the results."
] },
{ "cell_type": "code", "execution_count": 10, "metadata": { "id": "jlz80AwbPRej" }, "outputs": [], "source": [
"def run():\n",
"  \"\"\"Boilerplate to run the demo experiment.\"\"\"\n",
"  data = datasets.make_classification(n_samples=NUM_SAMPLES,\n",
"                                      n_features=NUM_FEATURES,\n",
"                                      n_classes=NUM_CLASSES,\n",
"                                      n_informative=50,\n",
"                                      random_state=0)\n",
"  init_params = jnp.zeros([NUM_FEATURES, NUM_CLASSES])\n",
"\n",
"  errors, step_times, compile_time = {}, {}, {}\n",
"\n",
"  for use_pmap in (True, False):\n",
"    exp_name: str = f\"{'with' if use_pmap else 'without'}_pmap\"\n",
"    _errors, _step_times, _compile_time = fit(data=data,\n",
"                                              init_params=init_params,\n",
"                                              stepsize=STEPSIZE,\n",
"                                              linesearch=LINESEARCH,\n",
"                                              use_pmap=use_pmap)\n",
"    errors[exp_name] = _errors\n",
"    step_times[exp_name] = _step_times\n",
"    compile_time[exp_name] = _compile_time\n",
"\n",
"  plt.figure(figsize=(10, 6.18))\n",
"  for use_pmap in (True, False):\n",
"    exp_name: str = f\"{'with' if use_pmap else 'without'}_pmap\"\n",
"    plt.plot(np.arange(MAXITER), errors[exp_name], label=exp_name)\n",
"  plt.xlabel('Iterations', fontsize=16)\n",
"  plt.ylabel('Gradient error', fontsize=16)\n",
"  plt.yscale('log')\n",
"  plt.legend(loc='best', fontsize=16)\n",
"  plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n",
"            fontsize=22)\n",
"\n",
"  plt.figure(figsize=(10, 6.18))\n",
"  for use_pmap in (True, False):\n",
"    exp_name: str = f\"{'with' if use_pmap else 'without'}_pmap\"\n",
"    plt.plot(np.arange(MAXITER), step_times[exp_name], label=exp_name)\n",
"  plt.xlabel('Iterations', fontsize=16)\n",
"  plt.ylabel('Step time (s)', fontsize=16)\n",
"  plt.legend(loc='best', fontsize=16)\n",
"  plt.title(f'NUM_SAMPLES = {NUM_SAMPLES}, NUM_FEATURES = {NUM_FEATURES}',\n",
"            fontsize=22)\n",
"\n",
"  return errors, step_times, compile_time" ] },
{ "cell_type": "markdown", "metadata": { "id": "v6Oo8E5hPbyd" }, "source": [ "# Main" ] },
{ "cell_type": "code", "execution_count": 11, 
"metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "byYCECkWPcbe", "outputId": "48b17428-0dff-4573-a8c4-d8c59f0f6c74" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "num_samples: 50000\n", "num_features: 784\n", "num_features: 10\n", "maxiter: 100\n", "stepsize: 0.0\n", "linesearch (ignored if `stepsize` > 0): zoom\n", "\n", "Total time elapsed with zoom linesearch and pmap = True: 5.88 seconds.\n", "Total time elapsed with zoom linesearch and pmap = False: 20.39 seconds.\n", "Average speed-up (ignoring compile): 7.08\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAm4AAAGYCAYAAAD7i26KAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd3xUVfr48c+TnpBOQq9KVRARUFwQCaCyAva+qNjXLa5u9WtbsKzbXNey6qqrrOVnwe5aEKUrIlWqFJFeQgnppJ7fH+dOmCQzk5lkJjMJz/v1mleYe+8598wwc+e5p4oxBqWUUkopFfmiwl0ApZRSSinlHw3clFJKKaVaCA3clFJKKaVaCA3clFJKKaVaCA3clFJKKaVaCA3clFJKKaVaCA3c/CQiW0XEOI+JPo5b4xwzus72uc72KQ2cZ7pz3NQ626e6nX+XiET7yGOc27FGRHr48RJ9lamLiPxdRFaLSJGIlInIThFZIiL/EpFL/MjjJrfy3NHAsVPcji0VkXQfx/au81pH19k/tc5+IyJVInJARGaLyHUiInXSjHaO29rQ63JLU/cc3h49PKQ9W0Tect7TMhEpEJHvRWSmiNwnIif6W47mUuf/yNujg4/040XkMxE5JCIlzvfmbhGJb+C8p4nIuyKSKyJHRGSTiPxVRNIaSNdXRF4Rkd3Oe7xNRJ4WkY4NpOvkHLfNSbdbRF4WkT6+36HAuH3mjIjsEJEEL8cN9fTZFJEervR+nMvjZ1FqX+P+3EAer7gdO7ehc/rIp4f4971Jd0vj6Tvt6TG9gXO/6nbs4Dr7/D2Hx++3r+97nfO43vPRdbbP9ZB3ufP5+1BEzveRZ6PfHxHJcNIvFXsdKheRPSKyUkSeF/u99/rbEw4BvN4tXtIPF5E3xF5/y8X+xq0Qe+1N8bMMUSKywO1cQ4P7KmuLCWXmrdifRORjY0x1mM7fCTgL+NTL/uuCdSIRGQX8D0gBDgJfAfuBDOBk4GfA5cBbDWR1fZ3yPepnERKAK4BnvOyf4mc+3wMLnX/HA32BHOdxvohcbIyp8jMvX94Ginzsr7VPRP4J/Mp5uhZYCpQC3YAzgLOBVOC3QShbKLi/r3WVetooIr8H/gJUAXOBPOBM4EFgooiMNcaUeEh3JfAyEA18CewChgO/Ay4UkRHGmFwP6c4EPgESgeXAfGAQ8FPgYhEZaYzZ6CFdf2AB0Bb4DngX6ANMBi4SkbONMV96ee1N0QX4BfD3EOTtr6tF5G5P3wkRSQUuCsE5/+tjX7mHbb4+e/jaJzbQv9Bt0/XAL92er/RSnvFAe+znb7OH/b6++43hfp4UYCAwEfs9edwY8yuvKQN8f0TkBOBzoCP2d
XwD7AOSnfPe4DzeIvivsym8/V+5nIf9vZpTd4eIXA88DwiwGvv7lgqMAKYBP3GuD/sbKMPtwEjAOHmFljFGH348gK3Of0qx8/dqL8etcfaPrrN9rrN9SgPnme4cN7XO9qnO9iXO39e9pE8DSoB12C+XAXo08jXHY38cDfZHJMHDMUOAhxvIp7+TRxFw2Pn3MB/HT3GOWQ5UAl97OS4K2AEcAjZ5ed9d79t0D+kvxAYPBrjRbftoZ9vWAN4rE+h7DUxy0hQAOR72JwGXAT8J9+ffx/9Rvfe1gXRDgWrne3Sa2/ZkYJ6T56Me0nVxPtdVwPlu22OA151073pI1wbY4+z/RZ19f3e2LwPEw2frW2f/3+rs+6WzfReQFKT30/WZK3H+HgTSvLx/9T6bQA/XZ7Cxn1WOXuNc15jxXtLf7Oz/xvk7twmv2+9yu6Xx+p0OII9bnTx2ur3f8X6km4t/13G/rgdu7/lof86DDQrudsv/1GC9P873wACvAqke9vcD/gokBuMz3xwPbCVHpfO6flRnXzpHf8/rvs9tgRXOvscaOEcf53v7P7f/z6GhfF3aVBq4x52/00QkLgznXwysx9YSeWpCvBxbszA9COc6A/vB322M+a0x5kjdA4wxy4wx/9dAPq7athnYH1n3bb7sBmYBp4lIPw/7x2F/0F8HyvzIrxZjzLvYixTApYGmD4LLnb9PGmPq3Q0aY0qMMW8aY16tu68FuxP74/MXY8xi10ZjTBG2JrYa+JmHz/bt2M/1f40x77ulq8QGEgXABU6tgbvrgA7AHGPMk3X2/QFbK3EK8OM6+84FTsLWdtzpvsMY8wT2h7UT/tf4+msL9gcgE/h9kPP213Tn7xQv+6dgA+iXm6EsoeK6/tyNrU3NBC4IX3H8Y2yk8CfAVUPstdtOIESkF/Z7UAncbIwp8HDu74wxvzfGeKxJj1DXYmvovzPGfFVn32nYm+ONxpjp7juMMQexQSrA6d4yF5Eo7PelAluD3yw0cAvc29g7zZ40439UHdM52oRY13UE76LazvnbUDWxVyISA1ztPH0ReMH59xXipR9PHS86f6d42HddnWMaY6nzt3sT8mgs1/tbr3mvNXJudFwBUr1g1BizBVgExGEDJ3euH1VP6QqAD+sc50+6Ko7eSHhL97rx3IT+ap3jgukubAB7u/joJxhCXm8ORaQv9odsJrYms8URkQHYmssibLOf6/rhz81k2DnB22rnafsgZeu6FhUZY4qDlGckcP1G/MfDPn9v9g/42Pcb7Pfht8aYnYEUrCk0cGscVw3T3SKSHIbzv4wNzmr1ZXMuqsOBT40xwbiobnf+DhCRsY3MYwL24rIFWGCM+QbbjJsOXOxH+vexfaCudu8U6/ygXACsNcYsaWTZwPZngEbU2AWB6/2dIg10ro9gvUTkQRF5VuwAlqt8fCf6Yu9wDxljvvdyjOv/sqazuNOn6vg6+xtMV+d5c6VrMmPMamxgmATcF+z8/fQi9ubwyjrbp7jtb6lucP7OcIIU1/V0nIh0DV+xAuK6XuwLUn6ua1G6NDCArqUQkTOA3tjasJc8HLIMG5T1qfuaRaQtR2u8n/OSf3/gfmC2McbjMaGigVsjGGNmA59h71J+E4bz78He8Z7qfHhcXIHc9CCd6itsx89oYJaIzBGRe0TkXBHJ9jMP113sdOdOEQK4wzXGlAGvYZulznbbdQX2h2W6n+Wox6nmdo3OWtnYfJrgWewPxiBgm9jRireKHTnZ6GZ4qT060N/H1kaebgS2uekm7HfhVWC7eB5p3NP5u93DPurs6+m2rYfz97CnJhxv6ZyAL9N5ui2A87k/byhdVohu3u7Ddsi/UUSOb+jgEHAFM1NcG5wbp2uwfUo/CEOZmkxEYrGDS8C5DjnX00+xv4dTwlMy/4lIO+BU5+n/gpGnU1vkqrV+UUS+cW7ILhCRLo3NV/wbfe7pMSUIL8v1+/I/42HQkjGmEPt5LsK+5lXO6NJPsH3VO
mD7Pr/r4XVFY397qrHXvmalo0ob7/+wIzt/IyJPmYZHnQTbi9jmpCnAH5wP0tXYTrZBuagaY6pFZAL2A3oWtgP1aNd+EVkJ/Bt4zlNzkoi0d8pYTe1RPy8DDwM5ItLDGLO1gaK8iB29OgU7OhBskFpJI5qEnaCoL/bHcSj2B6pu/6fG+kHE66Cib40xJ7ueGGO+EZGLgaewgelkjv6oHBGRj4E/N6JG8S0gK8A0vpoDPNmDHQX6AbY2tRI7COX32EEfb4jIucaYmW5pXAGOr6YY12g192H4TU3nK62ndP6c031UXQpBHmVnjNkqIs8At2Hf57o1XyFljNkrIp8CE0SkvzFmPfbGqRO2T2a5j895o4j3qUyuq9sHyXGtiFzrI8sLjTHv1dl2Hva7sdkYs8Bt+wvY1oEpIvKg201mxBA7NcXJ2EE1qdhrwzc+kgT6/lyNHb1/OTDMebjOvRH7Hj0eYB+3zfge8ekrXaM575Wr37KnZlIAjDGfiJ054Q3syNmBbrtnYgcoePI7bPB8u9PFo1lp4NZIxpjlIvIm9kN+N7bzdHP6AHvnO1lE7qLORTVYJzHG7AbOFpGTsBe907EjSdtjLyJPY6dUmODhvNdgP2OfG2O2u+W5zwlKzsMGYH9soAxLRWQNts9NBvZO6FTsnZS/TQXeLmKFwC1NbG5152s6kHo1TcaY950fyB8DY7EXy0HY2sSLsK/5p8aY5/0tgDEm5FOHOAHZzDqbv8ZOk/EI8GvgEQ/HKP89iP1+XC4ifzXGePsRCZXpOMEMdiDHFLftoeDtB97bj3hD0114qtmtaQGos/1D7M3Lcdib03qDhcLkRRGp2yxtgBuMMS94SuAmoPfHGJMPXCki92K7oYzADljohh05+Wdn/2hjzGF/Cm+MWdhAGULlcuyI8t14nzYLEfkpdsDhXGzguhY7dciF2O/fBBE53xjzmVuaE7EjdxcBT4Sm+L5p4NY092D7af1URB41xnhrVgH7ZYOG53hx7fd5x+fc8f4/7HxP5xDivifGmFXAKtdzERmErV25Cju681fA3+ok8zV44EVs4HatiEwzDc+JNx17p3klR5vOAnmt7hexKuy0JN8CH/h7EfLTb/2oQazFaQ5+z3kgIknY+aIexl4w/yUinzZn59cmehD7eThRRLq5Be2ugLaNj7Sumq5Ct21NTedKm+9nOlfaDB/ndK/Nq5s2KIwx+0XkH9gbm4exnwmvh7v+ISLircZIaleTNVSr9AG2Bv9qEfkbtlvBamPMMn/KHyhjzJQAkywMJI2IdMJeK6up0+fJGFMhIq9iP7fXE7zAranXe/d53LKwI/1TgcdFZL0xZpGPvAN6f1yMMZux19q/Q03f6Z9jf2sGAQ85zyOZK0D/r5fBRYjICGzFw2pggjGmwtlVhH1/D2C7fjwjIr2NMVVuTaQA1/vxuxUSGrg1gTFms4g8jx1dej926LE3rglFff34wNEfBH+aXl7EfpluB0YBq4wxy/1I12TGmG+xkxMmYwOwC3AL3ETkdGzTGcAdIvKzOlm4PnvdsTVNsxo45SvYO77rsTWLBzjaJ8MfjbqIhYOxk8++IyKLscP+k7A1cn51gBWRv9OIptJg1dQZY/JEJBc7kWdnjt7Zb3X+dvOR3NU5fKvbNtcNUbqIpHrp51YvnTGmQETysAFYd9xuPBo4n+u5K923PtIddKYyCZVHsN0EzhE7u763c7lPWNzGx3HuAafPcrvdHP4Se62Jp2UPSnBNDVEGvOahqbet8/diEfm5j/6U/ijBfm+ber1/3r2Z2Om3+S4wBvsaTjAeJqsOJmPMBuA2pyn7Nuy13q/ATURGAjc24rTPO7V1ARM7dZRrCg9ftZJTnL8z3II2dzOwQVpPbE3sJuz3fih2wNwzHj5DrlHg/xaRQuAtD9MQNZkGbk13P7ZJcLJzV+qN68erVwP59Xb+7mjoxE5z7Wps/zMIXROGL59hA7e6gxXcBx40tPzH9TQQuDnNq
59gJ60F29fC05et1TDG7BKRddj3z9/BIACXEPj0JtsI0uoMzl2pa9Sb+w/Sd9jVFDJF5HgvI0tdna5rmgWNMfki8j12ZOkw4At/0jmWc7QJ2lPg5ivdYCedpz6j3tIFlTGmUEQeAv6JvXH5hZdDD2H747XBXmO8DbZxXV+KsD8+DZmODdwmYvsxtuQ5BV0tAPHYZkBvErE1+/9uwrm2Yyes7YWdlL0eEcnk6OCZBq/3UHMzcgX2u9Qd2yXhwSaUMxCfYQO3QK5FvfBdoeHNXBrfxOr67Znn1B5647qB9FQT76qFLcbOgJBZZ3cGdrUXb05x/oZk0JuOKm0iZ0TSY9j38k8+DnVVvZ8ndm6zepwq6YHYprwFno7x4N/Y5oxcbK1U0IiH2wkPXB/+mmY8p6nPNbnsKGOMeHoArslSL3D6rjXkeexrPUjLvvMHGn5/nQCok/PU72ZSY0wPb++5j0ePJryUuiZiaxsKsT8wrnKVc3RwyU/qJhKR47B3yuXAR3V2uybd9ZQulaMBfd0RYL7SRXN0LkRv6a4Qz2szuvKrN+IsBJ7G1gCehpelppzmoPnOU1/T7LhG+873p5nHqcH/Evudm+FpdF5LIEenhigC2vi4JrlaBpo6p5vreu/P/8W6QN5XZyDcA87T34qPtZz91dhrfUOMMdMbcS0SL4NR/HkdMdiKFPAxKMGx2/k73Ete/bBBGzg18saYrb7KzdHWgWHOtpD0fdfALTj+ir17nUT9aQVc3sX2s+qJ7bNUa/JZp/+FK/B6xfg5D5sx5l/GmCxjTHsT/JGtk0TkHREZI3bqDPfyiohcwNEagDfcdl+KHWm3FR93TcaOVFuG7Yh/VUOFMcZ84LzWLGNMOKbvCLb/iMj94nnh+TTsSNdO2ADo4+YtmncikiR22pJ602A4o5BdTbr/8lAr+mdsf54/iMipbumSsc0aUcBTHvod/hNbW3etiJznli4Ge/OSCrxnjFlXJ92LwF7sCOa6zTt/xtbireBoQOnyEbaGrhe2f5n7a/wFtgP7bjzUcsvRBcKn1t3XGE7A6xrA42ttStcSXr9x/h/qlmsScAdHl7Dz9/wjne9cg9/RCOYKxN5uoGnxDeyNw6lOJ/TGehzbJPsTEbmh7k6nK4nrRt9XS403T2ODhDTs/2lTnSQis0VkktgpU2pxAl/XZ/CNuvsjiGve0HwaXj/btf8KEZnsvkPsxNeuZtY5xv9BcM1Cm0qDwBhzWEQexgZwSV6OqRCRi7AjXG7G9qNYhP2AdcRW3cdj5067rVkK3rAo7OiaC4FDIrICW7OXiq0tcwWpr1G7/5XrIvmKt07Sbl7GjlK9HvhXkModLB1F5Gsf+5cbY+r23fu7iPjqO/S4Wz/ETGzzzb1ih9uvx/aN6YBtoksGjgDXGmMCna4jlOKwU5j8Q0SWY5t54rB9Gl1Lk72Dh8ljjTFLRORO7CLzX4nIbOxAkTOx8yIuxo7Srptuh/MD+DLwnogsxAZOw7FNRpuBWzykK3Kalj4BnhSR67B9VQY55T0AXFn3c2rsVDhXYmuxficiE7F93XpjP6+lwOVeggDXTU4wm/JfwU5BMMDbAcaY2SLyB+x7+z8RWYsdJQdwovMwwB+MhyXWWqCRIjLdx/7txpj7nJsC19QQPqcPMsYcEjvi/QLsNalR83QaY74Tu4D5i8DzYkf+u9Ze7oX9DAl2FgBfr8Fb/mUich92JO6vROSfxpi6Td9+vT/OvwXIcR5Fzvd6N7bZuDdHW0fm0HxNs43h+u15zTQwbYkx5iOnj/qNwMsi8n/YyeEzsLXbydj3oNnnaWuQiYCFYFvCgwYWj8XWGu3g6OK/o70c1xY7lHgJNmirwC4p9QX2AxLrJd1UJ98nAyhzUxeZT8COZPs7NqDchg0kSrFzd72JHY3jnuZ47KgtA/Tx4xztnPfAACc526Y4z/8XQFnXeHrfacSCyxxd8Luhx1y3NP4cb4AL3NJ0xgZur2Jrd3KxF/Z87
EX+70DPcH/2Pbw/cdimmpnO96IIW0uxC9vEeJEfeYzH9mvMcz5Pa7EBm8+FvrEX1Pec70wZNmD7Kx4WZK+Trq/zPu910m3HzlnVsYF0nZzjtjvp9mCDKI+fbWzn98PO96R7Iz5za3wcM9Htc7TVx3GnYkdNbnHeW9f39SU8LErulm4rASyQjW3qq/U9aMRnqYfrNQWQZir+fddWOsff4DzfCUT5kf9FzvH7qHM9xs9F5t2O7+98fjZg+yCWYX8n3gLO9pGuwfNgbxBWOcc90Nj3x0kTg715egCY53xeSpzP8Q5sP8+r/Hn/wvXA1rS5fkuGBZDuImwN+14nfRH2Ju0hoG2AZQjoO9TYhzgnU0op1UROE9hXwKPGmF+HuzxKqdZH+7gppVTwnAUUYO/WlVIq6LTGTSmllFKqhdDBCceQBjqq1tXoCRCVUseeQCdbNS1kQmylIo3WuB1DxPsizp54W9hZKaXqEZEpBDC/orHzXimlAqSBm1JKKaVUC3HMNJVmZWWZHj16hLsYSimllFINWrZs2QFjTL0lxo6ZwK1Hjx4sXbo03MVQSimllGqQiGzztF2nA1FKKaWUaiE0cFNKKaWUaiE0cFNKKaWUaiE0cFNKKaWUaiE0cFNKKaWUaiE0cFNKKaWUaiGOmelAlFJKtX4FBQXk5uZSUVER7qIoVU9MTAwJCQlkZ2eTkJDQuDyCXCallFIqLAoKCti3bx+dO3cmMTEREV1VS0UOYwyVlZUUFRWxfft22rdvT1paWsD5aOCmlFKqVcjNzaVz584kJSWFuyhK1SMixMbGkpGRQXx8PHv37m1U4KZ93JRSSrUKFRUVJCYmhrsYSjUoMTGRsrKyRqXVwE0ppVSroc2jqiVoyudUAzellFJKqRZCA7cgWfz6wyz6z2/DXQyllFJKtWIauAWJ5K5l8Pbp5O3fE+6iKKWUOgZNnz4dEWHr1q0126ZOncrs2bPrHTtlyhS6dOnSjKVTwdLqAzcRmSQiz+bn54f0PO3PuoMEqeC7//0zpOdRSimlPJkwYQKLFi2iY8eONdumTZvmMXBTLVerD9yMMR8aY25uzJDbQHTvP4RvE4bRe9vrlB0pCem5lFJKqbqys7MZPnw48fHx4S6KCqFWH7g1p6gf/ZIsDvPtJ/8Jd1GUUkq1YMuWLUNEWLhwYc22J554AhHhnnvuqdm2adMmRISPPvqoXlOpa+TiQw89hIggIkydOrXWeVasWMEZZ5xBUlISvXv35plnngmonHPnzkVEePvtt5kyZQoZGRmkpqbyk5/8hIMHD9Y61lX2Rx55hO7du5OUlMSECRPIzc0lNzeXyy67jLS0NLp27cpf/vKXWmn379/PLbfcQp8+fUhKSqJr165cddVV7Nq1q9ZxU6dORURYvXo1OTk5JCUl0bFjR+677z6qq6sDem2RSgO3IBowchJbonqQvfp5TCv5gCillGp+gwcPJj09vVYz5+zZs0lMTKy3LSYmhlGjRtXLY9GiRYDtz7Zo0SIWLVrEjTfeWLO/oKCAq666ismTJ/P+++8zbNgwbr31VubMmRNweW+//XZEhNdee42HHnqIDz74gEsuuaTecS+//DKzZ8/mqaee4sknn2TBggVcc801XHjhhZx00km8/fbbnHvuudx55518/PHHNekOHTpEQkICDz/8MJ9++il/+9vf2LRpEyNGjODIkSP1znPBBRcwbtw43nvvPa666ioeeOAB7r///oBfVyTSlROCSKKiOHjSTQxbeTerF37IwFHnh7tISil1TJv24VrW7S4IaxlO6JTKHyedGFCaqKgoRo0axZw5c2pqi+bNm8ett97K448/TlFREcnJycyZM4chQ4aQkpJSL4/hw4cD0Llz55p/uyssLOSpp54iJycHgFGjRjFz5kxee+21mm3+OvHEE3nxxRcBGD9+PJmZmUyePJkvvviCsWPH1hwXHx/P+++/T0yMDT/WrFnDo48+ygMPPFBTkzh69GjeffddZ
syYwbnnngtA3759eeyxx2ryqaqqYsSIEXTr1o1PPvmECy+8sFZ5brrpJu68804Azj77bAoKCnjkkUe4/fbbSU9PD+i1RRqtcQuyk8ZfzwHSqf7qCY/7V3z2CoufuJbKivJmLplSSqmWZMyYMSxatIgjR46wcuVKDh8+zO9//3vi4+NZsGABAHPmzAk4yHJJSkqqlTY+Pp4+ffqwffv2gPO67LLLaj2/9NJLiYqKqqn1cznrrLNqgjaAfv36AXDOOefUbIuJiaFXr17s2LGjVtqnn36aQYMGkZycTExMDN26dQNgw4YNDZbniiuuoKioiDVr1gT82iKN1rgFWXxCEst7XMnpW59m2/pldO8/pGbf4jf/xrC1DxElhs3rbqLXoJFhLKlSSrV+gdZ0RZKcnBzKysr46quvWLFiBYMGDaJ9+/aMHDmSOXPm0K1bN3JzcxkzZkyj8s/IyKi3LT4+3mPTY0Pat29f63lcXBwZGRn1+qDVPWdcXJzX7e7leOKJJ7jtttv49a9/zd/+9jcyMjKorq5m+PDhHstbtzyu53XL0xJpjVsI9JvwK0pNHPtmPQqAqa5m0Qu/47R1D7I+fgAAB9bODWMJlVJKRbqBAweSlZXF7NmzmT17dk2ANmbMmJptcXFxjBgxIswlhX379tV6Xl5eTl5eHp07dw5K/q+//jpjx47lkUce4eyzz2bYsGG0a9fO7/K4ngerPOGkgVsIZGR3ZFXWuQw6+CkH9m7nm39dx+nbn2VJ+o/p89sv2Es2cbsXh7uYSimlIpiIMHr0aGbNmsWCBQtqBW4rVqzg3Xff5dRTTyUpKclrHnFxcZSWloa8rG+++Wat5zNmzKC6uprTTz89KPmXlJQQGxtba5urT50/5Xn99ddJTk5m4MCBQSlPOGlTaYh0POcO4v/fe+Q/cyancYhFHScz/KYnkKgodqYOokfBUkx1NRKlsbNSSinPcnJy+PnPf050dDRnnHEGYEecpqSk1Axc8OWEE07go48+Yvz48WRkZNCpUyc6deoU9HKuXbuW6667jiuuuIKNGzdy9913M3r06FoDE5pi/Pjx/OUvf+FPf/oTp556KrNnz+att97yevxzzz1HdXU1w4YNY+bMmTz//PNMnTqVUM/p2hw0agiRbn1OZmXicNpxiK97/5rTb/lXTZBW1WU4WRxm99b1YS6lUkqpSOYaPDB06FBSU1MBiI6O5swzz6y135snn3ySNm3aMGnSJIYNG8azzz4bknI+9thjGGO4/PLLueuuu5g4cSIzZswIWv733Xcft9xyC48++igXXnghq1atYubMmV6Pf//995k1axbnnXcer7zyCvfccw/33ntv0MoTTmKMCXcZmsXQoUPN0qVLm/Wcefv3sH/HBvqcMrrW9h/WLaHnm+NYcvJDDLvgF81aJqWUaq3Wr19P//79w12MY8rcuXPJyclh1qxZjBs3LtzFYerUqUybNo2Kiopao1cjUUOfVxFZZowZWne71riFUEZ2x3pBG0D3vqdQQBvMtkX1EymllFJKeRHZ4WgrFRUdzQ+JA+iQvyLcRVFKKaU8qqys9Lk/Ojq6mUqi3GmNW5iUdDyNbtW7OJTb8ueUUUop1frExsb6fPz3v/9l9OjRGGMiopkUbFOpMSbim0mbovW+sgiX0fcM2PI421bOIfPsyeEujlJKKVXLkiVLfO7v2bNnM5VEudPALUx6DhpJ2cexlH2/ENDATSmlVGQZOu1yiCoAACAASURBVLRev3gVAbSpNEziE5LYEteHjIPLw10UpZRSSrUQGriF0eHsIRxXsZnS4sJwF0UppZRSLYAGbmGU1OsMYqWKLSvnh7soSimllGoBNHALox4n51BthIKNGrgppZRSqmEauIVRWmY226K702af75E7SimllFKggVvY5WaczPGla6msKA93UZRSSikV4TRwC7PoHj+ijRxh6zqtdVNKKdV406dPR0TYunVrzbapU6cye/bsesdOmTKFLl26N
GPprLlz5zJ16lSqq6ub/dythQZuYdbl5LEAHFg3L8wlUUop1ZJNmDCBRYsW0bFjx5pt06ZN8xi4hcvcuXOZNm2aBm5NoBPwhlmHrr3YSzaxuxaHuyhKKaVasOzsbLKzs8NdDBViWuMWAXamDqJb0bcYvQNRSikFLFu2DBFh4cKFNdueeOIJRIR77rmnZtumTZsQET766KN6TaUiAsBDDz2EiCAiTJ06tdZ5VqxYwRlnnEFSUhK9e/fmmWeeqVeWb775hnHjxpGcnEybNm0YO3Ys33zzTa1jRo8ezejRo+ul7dGjB1OmTAFss+20adMAuw6qq0z+EhHuvvtuHnroIbp06UJiYiKjRo1i5cqV9coycuRIPv30U04++WQSExMZPHgwixcvprKykrvuuouOHTuSmZnJlClTKC4urpX+j3/8I6eccgqpqalkZWUxZswYvv7661rHzJ07FxHh7bffZsqUKWRkZJCamspPfvITDh486PdragwN3CJAVZfTyCaPPds3hbsoSimlIsDgwYNJT0+v1cw5e/ZsEhMT622LiYlh1KhR9fJYtGgRYPuzLVq0iEWLFnHjjTfW7C8oKOCqq65i8uTJvP/++wwbNoxbb72VOXPm1ByzatUqzjzzTPLy8pg+fTovvfQSBQUFnHnmmXz77bcBvaYbb7yRG264AYCFCxfWlCkQL730Eh9//DFPPvkk06dPZ9++fYwdO5ZDhw7VOm7z5s387ne/484772TGjBmUlZVx3nnnceutt7Jnzx6mT5/Offfdx6uvvloTTLrs2rWLO+64g/fff5/p06fTrl07Ro0axerVq+uV5/bbb0dEeO2113jooYf44IMPuOSSSwJ6TYHSptIIkNShN6yDw3u20KlH33AXRymlWo9P7oS99X9wm1WHgfDjPweUJCoqilGjRjFnzhzuu+8+qqurmTdvHrfeeiuPP/44RUVFJCcnM2fOHIYMGUJKSkq9PIYPHw5A586da/7trrCwkKeeeoqcnBwARo0axcyZM3nttddqtt1///3Ex8fzxRdfkJ6eDsBZZ51Fjx49mDZtGu+8847fr6lLly41AyJOO+00YmICD0FKS0v57LPPaNOmTU0+vXv35tFHH+WBBx6oOe7gwYN89dVXHHfccQBUV1dz/vnn88MPP/D5558DcM455zB//nxmzJjBX//615q0zz//fM2/q6qqGD9+PCeeeCLPP/88jz32WK3ynHjiibz44osAjB8/nszMTCZPnswXX3zB2LFjA359/tAatwiQlN4egCP5+8JcEqWUUpFizJgxLFq0iCNHjrBy5UoOHz7M73//e+Lj41mwYAEAc+bMqQmyApWUlFQrbXx8PH369GH79u012+bPn8/EiRNrgjaA1NRUzjvvPObNa/5Bdeeee25N0Aa2KXb48OH1au769OlTE7QB9OvXD7DBmrt+/fqxc+dOjDE12z7//HNycnJo27YtMTExxMbGsnHjRjZs2FCvPJdddlmt55deeilRUVEB1yQGokXWuIlIG+ApoByYa4x5NcxFapLUrE4AVBTsD3NJlFKqlQmwpiuS5OTkUFZWxldffcWKFSsYNGgQ7du3Z+TIkcyZM4du3bqRm5vLmDFjGpV/RkZGvW3x8fEcOXKk5vmhQ4dqjVJ16dChA3l5eY06b1O0b9/e47a1a9fW2lb3tcXFxXndXllZSVVVFTExMSxfvpxzzz2Xc845h//85z907NiR6Ohobrzxxlrvi7fyxMXFkZGRwa5duxr1+vwRMYGbiLwATARyjTED3LaPBx4DooHnjTF/Bi4C3jLGfCgibwAtOnBLa9sBgOoiDdyUUkpZAwcOJCsri9mzZ7NixYqaAG3MmDG8+eabdO3albi4OEaMGBGyMmRmZrJ379562/fu3VsrCEpISKCgoKDecXX7njXVvn31W6b27dtH586dg5L/22+/TUxMDO+88w6xsbE12/Py8mrVOnorT3l5OXl5eUErjyeR1FQ6HRjvv
kFEooF/AT8GTgCuFJETgC7ADuewqmYsY0jExSdQQBJRJQfCXRSllFIRQkQYPXo0s2bNYsGCBbUCtxUrVvDuu+9y6qmnkpSU5DWPuLg4SktLG12GM888k48//pjCwsKabYWFhXz44Ye1RpF2796djRs3Ul5+dBWg+fPn10oHtkYPaHSZPv7441qjQLdu3crXX3/N6aef3qj86iopKSE6OrrWaNfZs2fXaj529+abb9Z6PmPGDKqrq4NWHk8iJnAzxswH6obmpwKbjTFbjDHlwOvA+cBObPAGPl6DiNwsIktFZOn+/ZFdm5Uv6cSUBffORCmlVMuWk5PDN998Q0lJCWeccQZgR5ympKQwZ86cBptJTzjhBD766CNmzZrF0qVL2b17d0Dnv/feeykpKWHs2LG8/fbbvPPOO4wbN46SkhLuu+++muOuuOIKDh48yPXXX8/nn3/Oc889xy233EJaWlq98gA88sgjLF68mKVLlwZUnsTERM4++2zee+893njjDcaPH09qaip33HFHQPl4M378eIqKipgyZQpffPEFTz/9NJMnT/Zag7Z27Vquu+46Zs6cyRNPPMGtt97K6NGjQzYwASIocPOiM0dr1sAGbJ2Bd4CLReRp4ENviY0xzxpjhhpjhkb6pITFMenElzd/fwGllFKRyzV4YOjQoaSmpgIQHR3NmWeeWWu/N08++SRt2rRh0qRJDBs2jGeffTag85900knMnTuX1NRUrr32Wq6++mqSk5OZN28egwYNqlXOZ555hsWLFzNp0iRefPFFXnnllXrNixMnTuRnP/sZTz31FKeffjrDhg0LqDzXXHMNEyZM4Be/+AXXXnst2dnZfPHFF2RmZgaUjzfnnHMOjz/+OF9++SUTJ07khRde4KWXXqJXr14ej3/ssccwxnD55Zdz1113MXHiRGbMmBGUsngj7iMpwk1EegD/c/VxE5FLgPHGmBud51cDpxljfhFo3kOHDjWBRvbNacVff0z6kV30vG9VuIuilFIt0vr16+nfv3+4i6FCxDUB74MPPhjuojB37lxycnKYNWsW48aNa1QeDX1eRWSZMWZo3e2RXuO2C+jq9ryLs63VqUhoS0p1friLoZRSSqkIFjGjSr1YAvQWkZ7YgO0K4KrwFik0qhLbkm4KqK6qIio6OtzFUUoppZpFVVUVvlr/oqKiiIqK9Hqm5hMxgZuIvAaMBrJEZCfwR2PMf0TkF8BM7HQgLxhj1vrIpsWSNlnESDWH8/aTntUh3MVRSimlmsXYsWN9TuZ77bXXMn36dJ/BXXMbPXp02MoTMYGbMeZKL9s/Bj5ubL4iMgmY5K1jYaSISbaDJ/IP7tXATSml1DHj3//+d71pQ9xlZWU1Y2kiX8QEbqFijPkQ+HDo0KE3hbssvsQ7y14V59Wf6FAppZRqrfr21TW6A6GNxhHCtV5pma5XqpRSSikvNHCLEClt7Vpw5fm5YS6JUkq1XJHUD0opb5ryOdXALUKkZ9nArbpYl71SSqnGiI2NbdLyTko1l9LS0prlvwLV6gM3EZkkIs/m50f2HGm6XqlSSjVNu3bt2LVrFyUlJVrzpiKOMYaKigoOHTrEzp07adu2baPy0cEJEaRA0og5cjDcxVBKqRbJtSTU7t27qaioCHNplKovJiaGhIQEunXrRkJCQuPyCHKZVBMURet6pUop1RSpqak1AZxSrVGrbyptSUrjMmhTcTjcxVBKKaVUhNLALYJUxGeSUq2Bm1JKKaU808AtgrivV6qUUkopVZcGbhFEkrOJkWoKD+vIUqWUUkrV1+oDt5YyHQi4r1e6J8wlUUoppVQkavWBmzHmQ2PMzWlpaeEuSoPi09oBUHRI1ytVSimlVH2tPnBrSRJ1vVKllFJK+aCBWwRJzeoEQHnB/jCXRCmllFKRSAO3CJLW1ta4VRdp4KaUUkqp+jRwiyDxCUkUmkRdr1QppZRSHmngFmHyo9KJOXIoJHnn7f6eA9vWhSRvpZRSSoVeq
w/cWtJ0IOBarzQ0gduel27g0CtTQpK3UkoppUKv1QduLWk6EAjdeqXVlZV0P7Ke1IqDQc9bKaWUUs2j1QduLU1FfEZI1ivd9f0q2nCENFMQ9LyVUkop1Tw0cIswVYlZpJlCTHV1UPPN/e4rABKlnCMlhUHNWymllFLNQwO3CCNtsoiVKgoOB7dJs2rHspp/5x/SCX6VUkqplkgDtwgTk+KsV3pgV1DzzTi8hiojABRp4KaUUkq1SBq4RZj4NDsJb3Fe8IKrivIjdK/YwsbYfgCU5usEv0oppVRLpIFbhHGtV1p6ODdoeW5b9w1xUsmhTqMAKNPATSmllGqRNHCLMCltOwBQURC8wO3QxkUAtD15IgBVuqSWUkop1SK1+sCtpU3Am57VEYDqouAFblG7V3CIVHoNPJ1qI1QXh2aCX6WUUkqFVqsP3FraBLyu9UqlJHijSrMK1rI9oR8xsbHkSzJSqpPwKqWUUi1Rqw/cWqL8qLSgrVdaUphHt6odlGQPAqBQUogpywtK3koppZRqXhq4RaDi6HQSyoNTK7ZtzSKixJDUYxgAJTFpxJcHf2UGpZRSSoWeBm4RqCQuk6QgrVda8P03AHQZMBKA0th0kio1cFNKKaVaIg3cIpBdrzQ4gyli9q5gD9lkte8MQHlcBslVIRyoUVUJH/0W9m8M3TmUUkqpY5QGbhGoKrEt6aYgKOuVdipex+7kE2qeVydmkmoKwZgm5+3RnpWw5DnKl78SmvyVUkqpY5gGbhEoWOuV5u3fTUeTS3n7k49uTMwkQSooLQ7NQvMH188DYN93X4ckf6WUUupYpoFbBIpJaQdAwcHdTcpnx5ovAUg57tSabdHJWTbvQ3ublLc3xZvtOTPz14WuVk8ppZQ6RmngFoHiUm3gVtzExeCLf1hCtRG6D/hRzbaYFBu4FYZioXljyDiwjFITR5vqQkze1uCfQymllDqGaeAWgZIy7HqlJYebFlwl7V/J9ugupKRl1mxLTLNBYWl+8FZmcDEHN5NSdZj3OROAfRu0uVQppZQKplYfuLW0Ja8AUtraZa+asl6pqa6mS+l37E85sdb2pLRsAMoKDjS+gF7krp0LQMGJV1NmYji8+Zugn0MppZQ6lrX6wK2lLXkFbuuVFjd+Mfh9O7+nLflUdxpca3uqs4h9VVHwA7fCjV9yyCRzVs5YNks3YvauDPo5lFJKqWNZqw/cWqL4hCSKmrhe6e51dpBARu/Ta21PTc+iygimOPjrlabkLmVNdH96ZLUhN7k/HYo36AAFpZRSKog0cItQh6PSiGnCYvBl25ZQbqLpfsKwWtujY2LIl2SigrQWqospyqV9xQ7ys05BRDAdB5NMMYW7dSJepZRSKlg0cItQxdHpxJc3PrhKO/gt22KPIz4hqd6+QgneIvYuu9fY+duSetmltdr2OQ2A7Wu/DOp5lFJKqWOZBm4RqjQuo9Frim5cPp8TyldzsFOOx/0lManEVQR3sMbh9fMpM7H0OfkMAHoPGEqZiaVk67KgnkcppZQ6lmngFqHK4zJIrQo8cDPV1ZTPvJc8Ujnx4v/zeMyRECw0n7R3Cd9F9aJruwz7PDGJbTE9STqwKqjnUUoppY5lGrhFqKqkLNIasV7pmgXvMaBsJRv63FJr/jZ35XEZpFQVBKOYAFSXFdOlbCMH2p5Sa3t+xol0K9tEZWVl0M6llFJKHcs0cItQ0iaLOKmiIN//vmjVVVUkzn+Q3dKOwRf92vtxiZmNCgq92bF2IbFUEX/ciFrbY7oOIUVK2bJhdVDOU0tZEezQeeKUUkodWzRwi1DRyXai3IKDe/xOs/zTF+hV9T27Tr7D46CEGkltiZNKSouD08/twFo7MKHn4NG1tnc6wU5Fsu+7RUE5Ty2f3gkvjIfS4Db5KqWUUpFMA7cIFZ9ml70q9jNwKy87Qoelf+eHqB6cMuFmn8dGt2kLQP7B4KxXGrd7CVukK507dq61vV3PkzhCHBU7VwTlPDXytsG3r
4GpgoObg5u3UkopFcE0cItQKe26AVC4a71fx6947590MXspGHkX0TExPo+NTbW1ecV5TV+vtLqyku6la8hNH1xvn8TEsTuhFxmH1zT5PLUsfBSqnX5zB3SeOKWUUscODdwiVI9+Q9hLNnGbPmrw2OLCw/Ra/xTrYgdy0uhLGzw+wQncgrHQ/Jb1S0mlhJiep3vcfyRrIL2rt7D3cEmTzwXA4R2w4hU45VqIitXATSml1DFFA7cIJVFRbG0/jv4lS8nP872u6Oq3HqYt+USdPQ2Javi/tE1GOwDKCxu/FqrLPqd/W7eTx3jcn9xzGMlyhA1rg9Rc+uU/7d9Rv4PM4+DApuDkq5RSSrUArT5wE5FJIvJsfn5wJ5xtDunDLiVOqtg4/02vxxQX5HHitv+yImkE/YaN9Svf1Ey70HxlYdMXmo/esZj9kkm7rn097u/Y39bE5W1e3ORzUbAblr8EJ18F6V0hq7cGbkoppY4prT5wM8Z8aIy5OS0tLdxFCVjfU3LYSxaxGz7weszamc+TQimJOb/xO9+U9LZ2ofkmLGIPUFVt6Fb8LXtSB4GIx2Ni2/ejjHii9670L9NNs+Cb56C6qv6+hf8EUw1nOK81qzcc2gJVFY18BUoppVTL4rsXuworV3PpKXvfouDwQVLT29bab6qryV7/Mpujj6fvEM/LW3kSHR3NIUlBSvOaVL6NG9fTnwMc7Dbcx8li2J/clw4F33GkooqE2GjvxxpD8bu306ZkJ1vnv8q8Ex8kMas7WSlxDEw9Qvay6TDoCsjobo/P6gPVFXaUaVavJr0WpZRSqiVosMZNROJE5FciMqA5CqRqSx9yCXFSycb5M+rt+27JZ/Ss3sahE67xq2+bu4KoVGLLmha47VnyPgBdTxnv87jqjidzomxl9Q7fkwlX7VtHm5KdzDKnkl24ngu+vpS57z7L9dOX8vG//w9TXXm0tg1s4AY6QEEppdQxo8Ffe2NMOfBnwPP6SSqk+gwZQy6ZRH/3fr19JQv/TQFJDBh/fcD5lkSnE1/etMAtY/tMdkV3Jr3HIN/H9RpGkpSxed1yn8ftWPQ2ANET/kabXy0ipcsJPBX3OEv6v8HlMotPo0aRF9/laIK2Ti3bQe3nppRS6tjgbzXNeuC4UBZEeRYVHc2WduM4oXgJhW7LXx3Yu52BBfNY124SSW1SA87XLjTf+AEbB3L3MqB8FXs7neW1f5tLSs9TAdi+5iuMMV6Pk42fsNocz48GD4TM44i6/lMY9Xuyt35IvFTy6JFJ/OzV5VRUOUt1JaZDm3Za46aUUuqY4W/gdh9wr4gMDGVhlGdpQy4hXirYsOCtmm2bPvkXcVJF57N+3qg8K+PTSalufOD2/ZdvEStVpA+5uOGDs3pTHpNC7+KlLN3muZbvSN5uupeuY1f70Uf7wUXHwpi74YZZyGUvcctF57Boy0Hu/3CdW959dGSpUkqpY4a/gdsfgGRghYhsFpEFIjLf7TEvhGU85vUdOs42l663zaWVFeUct20Gq+NPoWtv382U3lQmZJJqChu90HzCpo/YR1uOO2lkwwdHRSMnXcqEqMW896XnBec3L7RBabuhF9Xf2WUo9J/ExUO6cPOo43j562288vU2uy+rN+zfAD5q8pRSSqnWwt/ArQpYBywAdgCVzjbXo3G//sovUdHR/JA9hv5FiykuPMyq2W/QnoNUDLmh0XlKUlvipIqSosAXaS8vKaBf8RK+z8rxe1BE7Gk3ES8VpHz3BnnF5fX2V63/mN1kM2jIj3zm84fx/cjpm83UD9ay6PuDtsbtyGFo4tQmSimlVEvg16+uMWa0MSbH1yPUBT3WpQy5lASp4Lv5bxG7/D/sJYuTci5rdH5RyVlA4xaa//6r94iXCuIGXuh/ovYnUNJhGFfI57y9bHutXfkF+fQtXsqO7DOJjvb9kYyOEh67cjDd2ybx8/+3nPJ0p+ulNpcqpZQ6BrT6CXhbi75Dx3GAdFKXP83AshX80ONSY
mLjGp1fXIoN3IrzAg/cqta+z0GTyomnnR1QuqQRP6VH1D42fvVhrUEKq+e/T4JU0Haof4FgakIs9048gUPF5Swrseuu6gAFpZRSxwK/AzcR6SgifxeRJSLyvfP3ryLSIZQFVFZ0TAzfZ42hd9Vmyk00vcc3blCCS2Ja4xaaNxWlHJe3kNUpI0lMCDBw7D+JI3GZjCv+kEVbjjZtVq77iEKSOH6I/4Hg8OPakhAbxac7YiAmQQM3pZRSxwS/AjcR6QOsBG4DioBvnL+/AlaKSO+QlVDVSD7lUgBWpY4mq0PXJuWVlO4sNF8Q2ELze1d+ShJHqOozMfCTxsQTM/QaxkYv56OFSwHIzS9hQPFX7MoaicT4HwgmxEYz4vgsvthwEKOLzSullDpG+Fvj9hegAOjj9Gm70unX1gfId/arEOt36tl83eUGOpx/f5PzSmtrK0qrigLr1F+w/F0KTCJ9fzShUeeNGXY9UUDHzW+wv7CMxQs+I0sKSB98fsB55fRrx868UopSjtNJeJVSSh0T/A3ccoB7jTFb3TcaY7YBU539KsSiY2IYfuM/6NKr6auPpaS1pdJEBbbQfFUFnfbOZkncaXTJSm/ciTO6U9Ith0ujZvP2ki2Urf0flUTT4ZTAA8GcfrbWcFNVR8jbCpVljSuTUkop1UL4G7jFAYVe9hU6+1ULEhUdxWFJJarU9/qh7oo3zSfFFJLf48dNOnebEbfQXg6zZeEMTir+in0ZQyAxI+B8Oqcn0rd9Cl/lZ4KphkNbmlQupZRSKtL5G7itBH4pIrWOFxEBfubsVy1MUVRKQAvN7//mLUpNHN1Pm9S0E/c+i5LETtxc+Rp9onaRPOi8Rmc1ul82s3KdJb90gIJSSqlWzt/A7X5gHLBeRO4XkVtFZBqwFjgLmBaqAqrQKYlOI77Czwl4q6vJ2D6TL2Uwg3p2bNqJo6KJG34DvaJ2A5A2qPGB4Ji+7WxTKegABaWUUq2evxPwfgpMxDaL3g38C7gHO7J0ojHms5CVUIXMkdgMvxear9qxhLTKg+zuOI6YBibJ9UfMkGsxUbFUZPWHjB6NzueU7hlEJyRzOCZbAzellFKtXkxDB4hILHAusMoYM1REkoAMIM8YUxLqAjaViEwCJvXq1SvcRYk4FQkZpJR4Xju0rj1r5tEFyD55fHBOnpyNTPwHsSmdmpRNbHQUo/pks2FTR049sBEJTumUUkqpiNRg1YkxpgJ4E+jhPC8xxuxqCUEbgDHmQ2PMzWlpaeEuSsSpSsgkzRT4tdB86faV7DGZDD2hb/AKcMo10Htck7PJ6duO7yraU71/oy42r5RSqlXzt81rC9AulAVRzU+SMomRaooKGh5Z2iZvLVuie5KdEt8MJQvM6L7ZfG86EV1RBEWBL+GllFJKtRT+Bm5/Be4WkexQFkY1r+jktgAUNrTQfEUp7cu3k5favxlKFbis5Hho6yzeoSNLlVJKtWIN9nFzjAEygR9E5GtgD+DeJmWMMdcGu3AqtOJSbBxefDgXONHrcUU7V5FMNVEdT2qmkgWuW99BsASKd62nTc9R4S6OUkopFRL+Bm5nABXAfuB45+FOOxa1QAnOeqUNLTS/77slJAOZvYY0Q6kaZ9jAARR/E8+eLavoNTLcpVFKKaVCw6/AzRjTI8TlUGHQJr090PBC80d2rqTAJNK7T9OX2gqVgV0y2CCdYe+GcBdFKaWUCpkG+7iJSJyILBeRs5ujQKr5pGbawK2q6IDP45IOruX7qJ60TUlojmI1SlSUUJp6HKklW6msaniUrFJKKdUS+TMdSDnQE6gMfXFUc0pJzaDCRPteaL66ig5HvudQar/mK1gjpXbtT2f28/KC9eEuilJKKRUS/o4qnQVojVsrExUdRb6kEH3E+3qlRXs2kEgZpn3kDkxwOX7AcAC+/vwd1uzyb0WIWj76DXw/O8ilUkoppYLH38DtCeBKEfm7iIwUkeNF5Dj3RygLqUKnMCqNGB8Lze9e/zUA6
cdH7sAEF+lzDlWpXfh57Ifc9v+WU1wWQCVxUS4seR7Wvhu6AiqllFJN5G/gNg/oCvza+fdGYFOdh2qBimPSSPCx0HzJ9pWUmRiO639KM5aqkaJjiR7xK04yG8jOW8bUD9b6n3bvKvv38PbQlE0ppZQKAn+nA7kupKVQYVMWm0ZW6Vav++MPrGVrVDf6piY3X6GaYvBkmPcX/pz+OTnL+nNGn2zOG+THeqh7XIHbjtCWTymllGoCf6cD+W+oC6LCozwhi7bFS6mqKCM6ts5yVsbQoWQT61J+RBBXKA2tuCQYfis9Zz/AxR2v4u53VjO4azpdM5N8p3PVuOXvgOpqiPK3MloppZRqPgH9OolIlIgMEJEzRaRNqAqlmk9M3/EkU8q6eTPq7SvYv4MM8qlqPzAMJWuCYTdCXAoPZM0CgdteX0FFQ1OE7F1t/1aV63qnSimlIpbfgZuI/BzYC6wCZoOthBGR90TkttAUT4XaoNEXkUsm1ctfrrdv17rFAKT0aAH929wlpsOwG0ja/D/+MTaZFdsPM/s7H6tDlBXBwe+hyzD7XPu5KaWUilB+BW4ichPwGPAecBkgbrsXABcHv2iqOcTFxbG50yQGFC8md9fWWvsKty0HoPuJp4WhZE10+s8hOo6cg68REyWs3OF9AAb71gAG+k20z/O1eKJ6rAAAIABJREFUn5tSSqnI5G+N26+BR4wxNwN150v4DlpOFyhVX7cxNxEthk2znqu1PXb/anZIRzIzMsNUsiZIbgeDJxOz6nVGtivnW1+Bm2tggitwO7wt9OVTSimlGsHfwK0nMNPLvmIgPTjFUeHQpddAvosbQNdt71Dl1hesXfFGcpP6hLFkTfSj28BUc1Psx6zamU91tfF83N5VkNQW2h4PSVnaVKqUUipi+Ru4HQB6eNnXF9gVlNKosCkbeBXdzG6+XfQpAPl5B+ls9lGRHbkLyzcoozsMvITTDn1ATFkeWw4UeT5u7yrocBKIQHpXDdyUUkpFLH8Dt/8B99VZIcGISBZwB7bvm2rB+o+9hmISKP3azvyy3RmY0KalDUyo60e/JKaqlInRX7Niu4fm0qoKyF0PHZ0lvdK76VxuSimlIpa/gds9QBmwBvgcMMDjwHqgCrg/JKVTzSYuKYVN2edwcuEccg8coOCHZQB07d8CBya4az8Ak9GTc2JW8O1OD4Hb/u/sFCAd3AK3/B1gvDSrKqWUUmHkV+BmjDkADAUeBmKB77GT9z4JnG6MacSK3irStB91A22kjFUzpxO9bzUHSSe9fddwF6tpRJB+Exgua9m4fU/9/a7522oCt+5QecSuXaqUUkpFGL/ncTPGFBpjHjDGjDTG9DHGnG6MmWaMKQhlAVXz6ThgFLtiupK9eQZtizawtyUPTHDX91xiqaD9voUcqaiqvW/PKohNsgMTANKcQFX7uSmllIpAuq6POkqEgn5XMMh8x3HV2zjS9oRwlyg4up5GeVw6Y6KWsm5PnfuMvaug/QCIirbP07vZv/kauCmllIo8GripWo4fdwOVRBEthqTug8NdnOCIjqGq1zmMiVrBqm37j243xjaVdnBb0itda9yUUkpFLg3cVC1x6R35IWMEAJ37tfCBCW4SB0wkTUoo2rjw6Ma8rVBWcHREKUB8CiRmauCmlFIqImngpurpcfFD5A28gdROrWhBjOPHUCGxdNg7++i2vc6KCR1Oqn2szuWmlFIqQmngpuqJ7TKIjIv/AVGt6OMRn8yujNM4rexrDheX2W17VoFEQ7s6ffl0LjellFIRyt9F5meLSD8v+/qIyGxP+5SKJJW9f0zXqP1sXrvEbti7GrL7QmxC7QPTu9saN53LTSmlVITxt0plNJDqZV8KcGZQSuMnETlORP4jIm8153lVy9Z+2PkAVKz90G5wLXVVV3o3qCyF4gPNWDqllFKqYYG0hXmrfjge8LIIZH0i8oKI5IrImjrbx4vIBhHZLCJ3+iyIMVuMMTf4e06lAFKyurIuq
g8d9syGov1QuKf2wAQXnRJEKaVUhIrxtkNErgOuc54a4FkRKaxzWCIwAPgigHNOx6648JLbuaKBfwFnATuBJSLyARCNXa3B3fXGGJ3WXjXKD1lnMiH3OczGTxCoPRWIi/skvJ2HNGfxlFJKKZ+8Bm5ANXYdUgCp89zlIPA08Bd/T2iMmS8iPepsPhXYbIzZAiAirwPnG2MeBib6m7dSDans/WPIfY6qeY/YD7+nwE3nclNKKRWhvAZuxpj/Av8FEJE5wK3GmO9CVI7OgPswvp2A10nERKQt8BAwWET+zwnwPB13M3AzQLdu3YJXWtVi9eh7ClsXtKdH/lbbJJqYUf+ghDRISNfATSmlVMTxd5H5nBAGbQEzxhw0xvzUGHO8t6DNOe5ZY8xQY8zQ7Ozs5iyiilD9OqUy2wy1TzwNTHDRKUGUUkpFIF9NpbWISCpwLtANqDN/AsYY80ATyrEL6Or2vIuzTamgio+JZlPmKMj/qOHA7eD3fudrjEFEglBCpZRSyju/AjcRGQF8CKR7OcQATQnclgC9RaQnNmC7AriqCfkp5VV8z9P55/Ir+OWgq4j2dlB6N/h+jp3LzY+A7P+zd9/RcVXXAod/d0a9915sWXLvBTcMGIOppieBEEgogRAIKS8hnUACLy8hQAIhIfSaQugEiLGxjY0blnu3JEuyVawujbqmnPfHmVEdSaNmFe9vrVkj3XvnzhmD5a1zzt77j2uz2JRVxtvfXjqoYxVCCCHa87QcyB+BPGAB4KeUMnV6dPvvX2eGYfwD2ApMMgyjwDCM25RSNuAeYDVwGHhDKXWwT59ECA/NTInkjy1XkN3c3e8h6MDNWg8NlR7d8/29RRwosqCkaK8QQogh5OlS6RTgy0qpnQN9Q6XUDd0c/wj4aKD378wwjFXAqvT09MG+tRilZifrgO2p9dn87tqZ+Pu4+b2jfS23wMge73eysoHc8noAGq12Anw83oEghBBC9ImnM24nAN+hHMhQUUp9oJS6IzQ0dLiHIkaItOgg7l2Rwft7i7jqqc1kl3YuT0jHWm692JTV1mGhsr5lsIYphBBCdOFp4PYg8BNngoIQo94PLpzIy7eeRXldM6ue3MybOws6XuCacfMgcNt4rKz16+oG62AOUwghhOjA0zWdy4FYINcwjK1A540/Sin19UEdmRBD7NyJ0Xz03WXc+4/d/PDfe9maU8GDV04jyNcL/MPAN7TXwM1md7A5p5y06ECOl9XLjJsQQogh5emM29nozFELMA1Y5uYhxKgTG+LH67cv5N7z03l7dwFL/28dj6w+Qmltk0e13PYW1FDbZOOq2YkAVDVI4CaEEGLoeDTjppQaP9QDEWK4eJlN/GDlJFZMieWvG3L4y4Ycnt2YyzsRoaRbc3vc3LnxWBmGAZfNjOexNceokhk3IYQQQ8jTGbdRyzCMVYZhPFNTUzPcQxEj3KzkMJ6+aR7r/uc8vrwgiZ01QbRU5PPk2mPdvmZTVhkzk8JIjQjAMKBS9rgJIYQYQh4HboZhBBqGca9hGG8ahrHeMIwM5/HrDcOYPHRDHBjJKhV9NT4qkIeumsG15y8h2Gjk1Q179dJpJzWNVvacrOacjCi8zCZC/LyplqVSIYQQQ8ijwM0wjGRgH/AIkAGcAwQ7Ty8HfjgkoxNiGAXGpgEQ4yjl2Y3Hu5zfkl2OQ8E5E3Uf3IhAH0lOEEIIMaQ8nXF7FGgGJgLzgPY9gD5DkhPEWOQsCXL1eDuvbTtBRV1zh9Mbs8oJ8vVqLegbHuAt5UCEEEIMKU8DtwuBXyml8tHZpe0VAomDOiohRgJnEd4rxtlostl5/vPc1lNKKTYeK2PxhEi8zfqvUXiAzLgJIYQYWp4Gbj6Am/LyAIQCtsEZjhAjiH84+AQTbSvh0unxvLI1v3UPW255PYXVja3LpADhgT6yx00IIcSQ8jRw2wdc2825S4AB9zAdKpJVKvrNMCAqA4p2c8/56dQ123hxcx7Q1ubqnIyo1svDA
7yplMBNCCHEEPI0cHsEuM0wjGfRiQkAUw3DeBC4zXl+RJKsUjEgEy+Gk18wJbiJlVNjeXFzLrVNVjZllZESEUBqZGDrpeGBPjRZHTS22IdxwEIIIcYyjwI3pdTbwLeBLwFrnYdfAb4H3KOU+u/QDE+IYTblckDBkQ/5zvkZWJpsPLcpl605FZwzMarDpeEBPsAY7J7gcEDhLtj1KthlV4QQQgwnT3uVopR62jCMV4HFQAxQAWxRSnW3902I0S9mKoSPhyP/Ycb8W1g+KZo/r8/G7lAsy4jucKkrcKusbyEhzH84Rjt4GqsgZx1krYHstVBfpo8HxcLElcM7NiGEOIN5HLgBKKXqaZtxE2LsMww967btaWiq4TsrMlh/tAyzyWDxhMgOl4YHeAOM/pIg25+B1T8Fh00naKRfAEkL4OP7oObEcI9OCCHOaN0GboZhnAPsUkrVOb/ukVJq46COTIiRYvIq2PIkZK1h7ozruGBKLDaHgxA/7w6XRQQ6Z9xG81Lp54/D2gf03r5lP4TEuWAy6+XS1T+HmsLhHqEQQpzReppx2wAsAr5wft25fpuL4TxnHsyBCTFiJC3QS4SH34cZ1/G3m+Z1qEDtEu4M3EZlSRClYMNv4bPfwfRr4eq/gbldYGoyQUg8WCRwE0KI4dRT4LYcONTuayHOTCYTTLoU9r0B1kbM3u73r4X560Bn1BXhVQrW3A9bnoDZN8IVT+pZts5CkmTGTQghhlm3gZtS6jN3X482hmGsAlalp6cP91DEaDblctj5IhzfAJMucXuJbjTvNXR73BwOeOZcWHgnzPna4N3z4/tgx7Ow4Ha45BEdqLoTmggFOwbnfYUQQvSLp3XcRi2p4yYGxbhzwDcUDv+nx8vCh7LRfGUOnNoHuYO4nXTnCzpoW3wPXPqH7oM2gJBEPePmcAze+wshhOiTnpIT1vXhPkoptWIQxiPEyOTlo8tgHP1I1zIzu/+rEx7gM3R13Ap36efK44N3z6y1EDEBVj6kM2h7EpoEDqsuDRIcO3hjEEII4bGeZtxM6MQD12MycB4wDvB3Pp8HTHKeF2Jsm3w5NFbCia3dXhIe4D10gVuRM3CryBmc+zkccHIbpC7uPWgDPeMGYCkYnPcXQgjRZ90Gbkqp85RSy5VSy4E/AVZgsVIqTSm1WCmVhi7Ga3WeF2JsS78AzL5wpPvl0vBAH6rqh2iPm2vGrbFSF8gdqPJj+j4piz27PjRJP0uCghBCDBtP97j9BvilUmp7+4PO7x8AHhrkcQkx8vgGwYTz4ciHOhPTjYihWiq1W/X+tvBx+vvK3IHf0zlz+FRONDa7B/vWXIGblAQRQohh42nglgGUdXOuFJCUTXFmmHI51JyE4j1uT4cH+tDQYqfJaoeqfMh8cXDet/Qw2Jp0jTUYlH1utrytVBLKI5lWXtuW3/sL/MPByx9qZKlUCCGGi6eBWy5wZzfn7gTyBmU0Qox0Ey8Bw9xtdqmrX2l1gxXW/Qb+8z2orxj4+xbt1s/TrtHPgxC41WVtYpt9EhOig3h0zTHK65p7foFh6JIgErgJIcSw8TRwexBYZRjGAcMwHjAM4y7n8wHgMvRy6YhkGMYqwzCeqampGe6hiLEgMBLGLYUDb7kti9Har7SqHA5/oA9W5w38fYt2gV8oxE7TSQIDDNzycrMIay6iOeEs/nbTfBpb7Pz+v0d6f2FIoiyVCiHEMPIocFNK/RO4CKgBfgo85XyuBi5SSv1ryEY4QFLHTQy6OTdDVS4cX9/llKvtldfh9/TSJugl04Eq3AUJc/SsV0TagDJLlVJ88MHbAJx3wRWkxwRx29njeSOzgF0nekl6CJXuCUIIMZw8LsCrlFqrlFqKLgUSB/grpc5WSn06ZKMTYiSaegUEREHmC11OuZZKI7LfhLAUfbB6gIGbtQlKD0HCXP19RNqAZtz+s6+Y4NJMrGZ/wtPmAfCdFRnEhvjyq/cOYnd015YYPeNWd0rXshNCCHHa9blzg
lLKoZQqVUpJ+XRxZvLyhbk36WK8nfZ7hQd6k2qcIqJiF8y/FQIioSpvYO93aj84bJA4lw/3FVPtnwwN5dDU9+X/2iYrv/nPIZb5ZuOVclZrIeEgXy9+dukU9hfW8K8dJ7u/QWgSKAfUFvf30wghhBgAjwM3wzB8DMO40jCM+wzDuL/T45dDOUghRpx5t+iSIDtf7nA4zN+Ha80bcWCCmV+BsNSBL5U6C+8eMtK5+++7+LgoQB/vx6zbY2uO0VRXRZojD6NT/bYrZiVw1vgIfr/6CFXdte0KdRXhleVSIYQYDh4FboZhJABHgHeA36KTER4AfuV8PDAkoxNipApPhYyVsOtlXWPNyccE15k/53jIAghJ0HXXBrpUWrQbAmP4v80WAPbUR+jjfQzcDhbV8PKWPH4wpQZDOXTHhHYMw+DXV06jtsnGo2uOur9JiKsIr2SWCiHEcPB0xu0RdB23FHR7q4VAGvAwkO38Wogzy4LboK6kYyeFvE0kGOVsDlqpvw9PheqT4LD3/30Kd1EdPp2NWeUE+JjZWhmsj/cxcHtuUy7Bft5cH1ukS5okzu9yzeS4EG5alMrr209QWN3Y9SauGTcJ3IQQYlh4GrgtAx4FipzfO5RSeUqp+4E3gSeGYnBCjGjpF0BoCux4vu3Ynr9TZwSy0bRQfx+WqhuzW4rc36M3zbWo8mN8Up1IVJAv31yWxok6A0dQHFT0LXDbkVfJ0vRI/Iq/gPiZuhOEG9fNS0IpyMyr7HrSNxh8Q2WpVAghhomngVskUORMSKgHwtudW4duNi/EmcVkhvm3QN4mKDsKzbVw+H12BJ5HWZOzaXt4qn7u73Jp0R4MFB9VxnPP8gnMSNRlbeqDUvs041ZiaaKgqpH5SUFQsKPH/qST44Lx8zax52S1+wtCE6UkiBBCDBNPA7cCIMr5dQ6wst25s4CmwRyUEKPGnJvA5K1Lgxx6D6wN7I28lErX5n5Xb9F+JigoZ2P5kqCp3LAwhfQYPUtW5t23Iry78nV9tqWBBbq+XMqibq/1MpuYmRjG7hPdBG4hiWDpulTqcCje31vElU9t5qP9knUqhBBDwcvD69YD5wLvAn8DnjIMYzZgRRfm/dvQDE+IES4oGqZeCXv+DlEZEJmOJXIOVbnOkhqhyWCY+l0SpPToVloc0dx8wTx8vcwkRwTg42UiX8WRVl+qZ/l8g3u9T2Z+FT5eJiY0HtAHkrsP3ABmp4Tx0uY8mm12fL3MHU+GJrW14EIX9F13pJRHVh/lyKlaAP654ySXzojv24cVQgjRK09n3H4B/BVAKfVX4LtAABAP/B74nyEZ3SCQlldiyC24HZotULgTZt1AeKAP9S12mm12MHvrGap+LJUqpVCFu8jyzuC6eTqb02wySIsK5GCTcwLcw1m3nflVzEoKxatguy7gGxzb4/VzksNosTs4XFzb9WRooq4jZ21i+/EKrvnrFm57OZMmq50/XT+bbywZx/bjFTRZe0nIOLkDTmzzaPxCCCE0TwM3K9D6L49S6kln14S5SqmfKaVG7FKptLwSQy5lEcRMBQyYdX1r26vqBmeZkPBx/VoqXbfrMHGOEqInLsbb3PZXNT0miMy6MP2NB4Fbk9XOwaIa5qWEw4mtkLKk19fMTtH33+OuBZazJEjxyRy++tx2TtU08dtrZrDmB+dy5exEzp0UTbPNQWZeD+2zyrPhlSvh1asH1L5LCCHONL0GboZheAEVdNzXJoRwMQy49A9wye8gNIkIZ+BW1eDc5xaW2uel0mabnU8//S8AU+ef2+FcRkwwO2o8D9z2F9ZgtSuWRVRBY2WP+9tc4kP9iQ3xZbe7BAVnSZB9Bw9gdyj+8c1F3HBWSmtwuXB8BD5mE5uyytzf3NYMb90KXj56RvKdb0kLLSGE8FCvgZtSygaUAAMoRCXEGDduKSy8E4CwAG+AdgkKqbq/p9VNXTQ3Glvs3PHKTiJrDqIwMCfO6XA+PSaIeuWHNSDGo5IgrpmvW
Y5D+kAPGaXtzU4Oc59ZGqIDtxN5WYyPCmRcVGCH0wE+XswfF87GrHL3N/7011C8F658Ci59FAq+gC1/8mhMQghxpvN0qfQ14PahHIgQY0WEu6VS0IV4e1HXbOOWl75gY1YZ1yeWY0RlgF9Ih2tcmaU1/skezbjtzK9ifFQgQSWZEBAFkRM8+hxzUsLJr2hoC0BdnIFbfWk+502KdvvaZRnRHC62UFrbaRdF1lrY+me9L3DyZTDjOph6Faz/LRTv82hcQghxJvM0cMsDFhiGscMwjF8YhnGbYRi3tn8M4RiFGFXCA3Tg1hrwhHlWy62m0cpNz29nR14Vf/zyLBIbDkPC3C7XjYsKwGTAKXNCr4GbUopdJ6qYlxquEwFSFumlXQ/MTnbuczvZaa+atx8tvhHEqHLOmxTj9rXLMnTyxOftZ93qSuHdb+n9gCsf0scMAy57DAIi4J079TKqEEKIbnkauD0FJALzgF8DzwLPtXs8OySjE2IUci2VVje0WyqFHve5Vda38NVnt3GgsIanvjqXK8fZdTutxK6Bm6+XmXGRgWTbYvQSbHNdt/fNc86YLY2zQ1WuR/vbXGYkhmIyYI+bem7l5hgSzZUsHB/h9rVT40OIDPRhkytwczj0XrbmWrjuBfD2b7s4MBKueBJKD8H6hz0enxBCnIk8DdzG9/KQXqVCOPl6mQn0MVNZ71wqDYoFL79uA7e6ZhvXP7OV7NI6nr15PhdPj4Pstfpk2nK3r5kQE8S+xkj9TVVut2PZ6Sy8u9B8TB/opX5be4G+XkyKC3GboJDTHMoEn2r8vM1uXgkmk8HZGVFsyirH4VCw7S+Q8ylc9L8QM6XrCyZeBPO+AZufgPytHo9RCCHONB4Fbkqp/N4eQz1QIUaT8ECftqxSw9DLpd0slW44WsqxkjqevGFO29Jj9loIS9FFfd1Ijwki09J7ZunO/EpC/LyIr9mrg8f4WX36HK4EBYdDtR7LLa8nuzmMGNVN8oHTsoxoyuuaOVJUCZ/9HjJWwvwedlWsfFjPTn70oz6NUQghziSezri1MgzD1Onh2YYZIc4g4QHtAjfQAUk3M26ZeVX4e5tZPtkZtNma4fhnkH5ht/vRMmKCyLE7i+j2UAdtZ34Vc1PDMQq26/1yXj59+hxzksOobbJxvLy+9dj6I6UUqwh8bHXQZOn2ta59bsd3rIbmGj2j1tOPC98gmHY1lB0GhySxCyGEO90GboZhxBmG8aFhGDe3O2ZGF+Nt/6g2DKPnMuxCnGHCA32oap+NGZYKVSfcXrszv4rZyWFtRXbzt4C1Xs9QdSM9Joh6/Gn2jep2xq2m0cqxkjoWJvrr8ht92N/mMsdViLfdcumGY2U4ghP0N5bum83HhvgxKTYYn+yPwcu/22XfDkKTwWHT+/uEEEJ00dOM27eBucC/Ox030AkJvwZ+AxQB3xqS0QkxSkUEeFPlKgcCuiRIcw00dszQrG+2cajYwvxx4W0Hs9eC2QfGL+v2/hOidUmQCt9EqHS/x223s+vBsoB8HQz1I3CbEB1EsK9Xa2ZpQ4uNbccrSEx1LuHWdB+4ASxLj2RG3efY05aDT0Dvbxia7Lxv1yb2Qggheg7cLgaeVUp1rhqqgL8ppR5USj0A/Bm4dIjGJ8SoFBbQacatNbO04z63PSersTuULtfhkrUGUpeCT8fCtu0F+nqRGOZPAXFQ6X6pdGd+FWaTwcSWg/pA0oI+fw6TyWBmcii7nZmlW3MqaLE5mDZlmr7A0nOAdUl0KfFGJVnh53j2hqG6nRbV7mcnhRDiTNdT4DYJ2OLmeOdNKsec1wohnCICfahttmG1O/SBMPclQTLzqjAMmOsK3Kryofxoj8ukLhNigjhijYHaYmip73J+Z34VU+KD8Sn6AqKn6Fpp/TA7OYwjp2ppbLGz4WgZAT5mZk2dBIap1xm3mbWfY1cGHzXN9OzNXIGbzLgJIYRbPQVufkCHAlFKKTsQD+xtd7jJee2IZBjGKsMwn
qmpqRnuoYgzSLizlltV51punTJLM/MrmRQbTIifvp7sNfo548Je3yM9Oojd9a6AL6/DOZvdwZ6T1cxPDoWTOyBlYb8+B8Cc5HDsDsX+whrWHy1lyYQofH18ITi+xz1uAN7Z/yXLdxqr8zxMNvALAb9QCdyEEKIbPQVupbipz6aUKnEGcC7jgW66SQ8/pdQHSqk7QkNDh3so4gwS7mo076rl5hcK/uEdlkrtDsXuE9WdlknX6v1wkem9vkd6TBBZVmcmaqcEhSOnamlosXNOeIXeW9eH+m2dzXYmKLy1s4CCqkaWT3a2uQpJhJoe2nhV5UHJAaqSL+RoSS0llqbur20vNLnn+wohxBmsp8Dtc+AmD+5xM7B5cIYjxNjganvVoSRIWMeSIEdP1VLXbGtLTLA2QW7PZUDay4gNIl/F6W86lQTZ5UxMmMMRfWAAM25RQb4khfvz1i49C9Zaay40seel0iMf6dfPvxqgrYtCb0KTZcZNCDEoduZXcarGw18aR4meArcngPMNw/iDYRhenU8ahuFlGMZjwHnAn4ZofEKMSq2BW+cEhXZLpTvzKwGYn+rce3ZiC1gbPFomBb1UWksAjd7hul1UO5uzy4kN8SWsfJfu3BA+fgCfRjectzkUE2ODSAxztqsKSdRLpUq5f9GRDyF6ChMmzSQqyIct2Z4Gbkky4yaEGBR3vJLJ42uODfcwBlW3gZtSaitwH/B9oMAwjFcNw3jY+XgVKADuBX7qvFYI4RThWirtXBKk+oTu2wnsyKsiNkTPZgF6mdTsC+O6LwPSXnigD5GBPuwLOhsOvAVl+ofT6oOnWH2whKvnJGGc3AbJCz1uLN8dV8P5Dk3lQ5PA1gQNlV1f0FCpA9HJl2EyGcxKCuNAkYf7TEOToKmmx+K+QgjRG4dDUdnQwtGS2uEeyqDqsXOCUupR4AJgD3At8FPn41rnsZVKqUeGepBCjDZhnZMTQC+V2lt0Y3j0FP781Aham49kfQLjzvas3plTekwQT5tvAO8A+O+PKapq4L439zE9MYTvLwzUgWI/6rd1dnZ6FD5mE5dMj2s76MoAdVcS5NhqUA6YrCsFTUsIIbu0jsYWD5IUwqSWmxBi4OpbbCgFOaV1qO5WBkahXlteKaXWK6UuBoKBOOcjWCl1sVJq3VAPUIjRyM/bTICPuZtabnkU1zRSWN3YlphQlQcVWR4vk7qkxwSxs9wLdd5PIGcdr77yNDa7gydvmItv0Q590SAEbpPigtn/4ErmpLRLpAhJdI7dTQ/WI//RWafxcwCYmhCKQ8GRUx7MokkRXiHEILA02QCobbZRWts8zKMZPB73KlVK2ZVSpc6HNBIUohfhAT5UduhX6txnVpVPZp5OHmhNTMhylQHpvX5be+kxQViabJRNuZkK//FcX/FXHl6VwfioQDi5Xc/ExXlYQ60Xvl7mjgci08E3FN7/jl6qdbE2Qs46mHQpmPSPmGkJIQAcLPIkcHPVcpMivEKI/rM0tm1VyS6t6+HK0aXPTeaFEJ4JD/Smuv0et9AkwIDqfHbm68byU+J1QEPWGh3YRU7o03tkxAQD8I/MU3yv5npSTaVc1fiOPnliGyTOA7P3IHwaN/xC4JtClCuaAAAgAElEQVTr9JjfvBXevE239Dq+QSdZTG5rqJIU7k+ov7dngVtQLJi8ZMZNCDEgtc4ZN5DATQjhgfAAHyrbL5V6+UJIgp5xy69sayxvbYLcjX1eJgU94wbw+NpjnAxfiG3S5bDpUZ2ocGq/TkwYSlHpcOsnsPzncOhd+MsS2PwE+AR3SLIwDIOp8SEc8iRBwWR21oiTwE0I0X8y4yaE6JPwAJ+OyQkAYanYK3M5VGRhwbhwKMiE5y8AWyNMvrzP7xEb4kuQrxfeZoMnbpiD18X/q5MC/nkDKDukLB6kT9MDsxecex/ctgZ8g3Q2acaFOlBtZ1pCCEdO1WJztQHridRyE0IMUG2zDtzCA7zHVODWpT6bEGJwhAd4d0xOA
Agfhy1rHUGqjutLH4ct/4TgOPjSS5B2bp/fwzAM7l6eTlyoLzOTwoAwWPpd+Ox3gAHJfW8s32+Jc+HOjZD5gtu9etMSQ2i2Ocgpq2dSXHDP9wpNgrzPh2igQogzgaVRL5XOSQlnf+HYaXspM25CDJHIIF8sTTY+OXiq7WB4Kj4NJazz/SHxOf+CRXfB3V/AtKv7/T53nTeBq+cktR1Y+j09YxU3XbfaOp28/WHx3RCV0eXUtAQ9loOeLJeGJUNtEdhtvV8rhBBuuJZK5ySHUVbbTE27pdPRTAI3IYbIl+cnMzU+hDte3ckP/72X2iYrRE3EQFHmFY9xx2dw8W/1Jv/B5BMAX/8AvvTy4N53gNKiAvH1MnmeWaocUFs89AMTQoxJtc02fL1MrUlgY2W5VAI3IYZIXKgf7969lHuWp/P2rgIu/uMmNvss5Ub1EH+f/izED06ZDrci+p6hOtS8zCYmx4d4NuPWWhJEWl8JIfrH0mglxN+7NYkrRwI3IURvfLxM/PCiSbx51xK8zQY3vpDJ5uY05o2LGu6hDYtpCSEcKrL0XsU8NEU/S4KCEKKfaptshPh5kRwRgI+XiewyCdyEEB6amxLOR99dxtcWpRAV5MuSCZHDPaRhMS0hBEuTjYKqxp4vDHV2Zag5SVF1I//acWJMtawRQgw9S5OVYD9vzCaDtKjAMbNUKlmlQpwmAT5ePHTVDB66asZwD2XYtE9QSI7ooSerTyD4R0D1SR7+6DAf7iumoKqR/1k56TSNVAgx2lkarYQG+AAwISaI/QVjI7NUZtyEEKfN5LhgzCbDswSFsGSslSdYc7CEiEAfnlyXzRs7ZM+bEMIzrqVSgPToIE5WNdBkHf0dOyVwE0KcNn7eZiZEB3qYWZpMXWkeLXYHr9x6FssyovjZO/vZlFU29AMVQox6rqVS0F1mlILc8vphHtXAjfnAzTCMVYZhPFNTMzamSIUY7aYlhHqUWapCEvGtL2JWYgjTE0P5y41zSY8J4q7XdnHklAeBnxDijGZpshHi75xxc2aWjoV9bmM+cFNKfaCUuiM09DQXIhVCuDUtIYQSSzPldc09XldENAE08rU5YQAE+3nz4i0LCPQ1c8uLOyixNJ2O4QohRqEmq50Wm4MQ54zb+KhADEMCNyGE6LOpCboYZm/LpZ+V6F6nlyS3dU+ID/XnhW8swNJo5ZYXd2D1pO+pEOKMY2nSXRJce9z8vM0khweMiZIgErgJIU6rafG9t75qbLHzQZ4ZgKDGjt0TpiWE8qsrpnGo2MIhT/bKCSHOOLVN+he+EH/v1mPpMUFjogivBG5CiNMqNMCbpHD/HmfcPtpfTHazXiJ1V4R3fmo4AFn9/SFcfQKsstQqxFjl6lMa7NdW9Sw9Jojj5fXYHaO7JqQEbkKI087VQaE7/8o8SXBkPMrs67btVUpEAD5mE1mltX1/87zN8MRc2Prnvr9WCDEqtM64+bWbcYsOosXm4GRlw3ANa1BI4CaEOO2mJYSSW15PXbOty7nc8nq+yK3kugUpGKGJbgM3L7OJtOhAskv6OONWmQv/+ho4rFB2pL/DF0KMcK49bsHtArcJYySzVAI3IcRpN82ZoHC4uOus2xuZJzGbDK6bm6SbzXfTrzQ9JqhvS6VNFvjH9aAcEDURqvL6M3QhxChgaXTtceu4VAqM+gQFCdyEEKdda+urwo4JCja7g7d2FrB8UjQxIX662Xw3gVtGTLDnldAddnjrdijPgi+/AimLoCp/wJ9DCDEy1bZmlbbNuIX6exMd7DvqZ9ykV6kQ4rSLDfElKsiHjw+cwt/HjLfZhI+XifyKBkprm/ny/GR9YWgS1J4CWwt4+XS4h6sSek5ZXWsg2K21v4Ks1XDZo5B2LhRmQn0ptNTrvqhCiDHF0mTFbDII8DF3OJ4eHSSBmxBC9JVhGCxKi+Q/+4rZnlvZ4VxciB/LJ8fob0KTAAWWQogY3+G6jNi2/So9Bm67X4ctT8KCb8KC2/Wx8
HH6uSofYqcOwicSQowklkYbwX5eGIbR4Xh6TBDv7i5EKdXl3GghgZsQYlj86fo5/GrVNKx2By02B1a7g2abg5gQX7zNzl0cYc6Zt5qCLoHbuMhAzCaDrJ4SFE7ugA++C2nnwcX/13a8NXDLk8BNiDGotsnaoRSIS3pMELXNNkprm4kN8RuGkQ2cBG5CiGFhNhlEB/v2fFGoK3Drmlnq42ViXGRA9yVB6krhjZsgNBGuexHM7X7chTuDQElQEGJMsjTZOuxvc2nfs3S0Bm6SnCCEGLlCEvVzDwkKbjNL7Vb49zegsRq+8hoERHQ87x8OviESuAkxRtU2WXsN3EYrCdyEECOXtx8ExridcQO9zy2/ooFmW6fM0jX3Q/5muOJJiJvR9YWGAWGpErgJMUa59rh1FhPsS7CvlwRuQggxZEKToNp94JYeE4Tdocgrb1cJfd+/YdtfYNG3YeaXur9vuARuQoxVliZrhz6lLoZhkBYTxPFyCdyEEGJo9FCENyMmGKBtn9up/fD+dyB1KVz4657vGz4OqvNBje6+hUKIrmqb3M+4AYyLDCC/YvS2vZLATQgxsoWl6KXS5q5JCGnRgRiGc79KQ6VuZ+UfBl96Ccxdf9vuIHwc2JqgrmRIhi2EGB52h6Ku2X1yAkBqRABF1Y202ByneWSDQwI3IcTINvkyZ7LBLWDv2NvUz9tMSkQA+afK4O9fBksxfPlVCIrp/b6SWSrEmFTnajDvZqkUIDUyEIeCgqrROesmgZsQYmRLXaI7HmSvgQ9/0GVpc3K0H9fn3Q+FO+G65yF5gWf3bV/LTQgxZrQ1mHe/VJoaGQBAfuXoDNykjpsQYuSbf4teLt30qE4qWPY/+rhS3F33JDNtmdgvfQzzlFWe3zMsGTAkcBNijKlp7NqntL0UV+BWXg+TTtuwBo3MuAkhRofzfwkzvgSf/hr2vaGPffogM8s/5HHrteSN/4rHt6qsb8FmeOs6cRK4iYGqL4cvnpVElxGi1rVU2s2MW3SQLwE+5lE74yaBmxBidDAMuPIpSD0b3v02vH8vfP44FZNv5E/2a3pufdVOk9XO8j9s4Kn1OVISRAyOfW/ARz+EsiPDPRJB21Jpd3vcDMMgJSKAE6M0s1QCNyHE6OHlC9e/BhFpsOtlmLIKvysfBwyyu2t91cmBwhpqGq28v7dQ73OTwE0MlKVQP5ceHt5xCKD9jFv3meWpkQHkVdSfriENKtnjJoQYXfzD4aa34cBbcNadBHr7khjm73El9Mz8KgByyuqpmBJPZG0xWBvB27//Y7I2Qs56qDwOVblQmaufHXa48zM9ZjF2WYr0swRuI4KlsefkBIBxkYGsP1qGw6EwmYzTNbRBITNuQojRJzQJln5Xt8RCd1Bw27PUjcy8KiIDfQDYVRuqD1afGNh4Nj0G/7wBPvm57tzQUA6hybrAb0HmwO4tRj5X4FYmgdtI0FtWKegEhRabg1OWptM1rEEjgZsQYtTLiAkiu7QOu6PnzeFKKXadqGL55BhmJYWyplhnlw14ufT4BkiYA/flwk/y4c6Nurk9BhTtHti9xchXKzNuI0ltk40AHzNe5u5DnNSIQIBRuVwqgZsQYtTLiA2i2eagsKqxx+uOl9dTWd/C/NRwVk6LY90pV+CW3/83b66Dol2QthwCInQSBYBfCERlSOA21jkcuvCzYdZL5dbRN4Mz1lgarT3ub4O2Wm6jMUFBAjchxKiX3rlnaTd25un9bfPHhXPRtDjKCcFm9hvYjNvJ7eCwwbizu55LmCOB21jXUA4OKySfBcoB5ceGe0RnvNomGyH+PW/hTwjzx9tskCeB2+lhGMZVhmE8axjGvwzDWDnc4xFCDK/0mCCAXve5ZeZXEhbgTVpUEOkxQaRFB1FsxA0scMv7HExekLyw67mEOVBbrGdkxNjk2t82YYV+lpIgw87SZCW4lxk3s8kgKTyAE5WyVNorwzBeMAyj1DCMA52OX2wYxlHDMLINw/hJT/dQS
r2rlPom8C3A86qbQogxKdTfm9gQ315ruWXmVzEvJbw1i+yiaXEcbYnEXpnb/zfP+xwS5oJvUNdzCXP0c/Ge/t9fjGyuwG38MjB5Q+mh4R2PwNJk7bb4bnupkQHky4ybR14CLm5/wDAMM/AUcAkwFbjBMIyphmHMMAzjP50e7btH/8L5OiHEGS4jJpjssu4Dt8r6Fo6X1TNvXFtpjoumxXHCEY2qzO1f1XvX/jZ3y6QAcTPAMMly6VjmquEWlgqR6VAqM27DTS+V9jzjBpAaoQM3Nco6Xpz2wE0ptRGo7HT4LCBbKXVcKdUC/BO4Uim1Xyl1eadHqaH9DvhYKbXrdH8GIcTIkx4TRHZJbbc/hHc667fNT41oPTYzMZQq3wS87I26bVFf9bS/DcAnEKInD07gphS8/mX47PcDv5cYPLXOxISgGIiZIjNuI4Cl0dpjKRCX1MhA6pptVNa3nIZRDZ6RssctETjZ7vsC57HufAe4ALjOMIxvdXeRYRh3GIaRaRhGZllZ2eCMVAgxIqXHBFHfYqe4xn1WX2Z+Jd5mg5lJoa3HTCaDmJTJADSX5/T9TXva3+biSlAY6G/1J7dD1mrY+Ieh3TNXsBP+vAAq+vHncSayFEFwPJjMOnCrzoeW0bdvaqxQSukZt172uEFbZulo61k6UgK3PlFKPaGUmqeU+pZS6ukerntGKTVfKTU/Ojr6dA5RCHGaTYkPAeDzLPczZzvzqpieGIqft7nD8anTZgJw9PD+vr9pT/vbXBLmQH1Z25Jaf+14HnyC9Azf548P7F7dsTbBu9/SmZF7Xh+a9xhrLIUQkqC/jpminyVBYdg0Wu3YHKrX5ARoF7iNslpuIyVwKwSS232f5DwmhBAemZsSxvTEEJ7akI3N7uhwrtlmZ19hDfNTu7aemjl9BgCFuX0sntrb/jaX+Nn6eSDLpfXlcOhdmP1VmH0D7HypbVP8YNrwvzpoC03RLcVG2d6fYWEphpB4/XW0M3CTfW7DxtLo7FPaSzkQgKTwAAyDUZegMFICtx1AhmEY4w3D8AGuB94f5jEJIUYRwzC49/wM8isaeG9Px6DmQGENLTYH89rtb3Px9guixiuS5tLjXQK+HvW2v80lbrreAzWQwG33q2Bvgfm3wTk/AmUf/Fm3gkzY8iTMuQnOvU+XSJGkip4ppQPoEOfOnojxYPaVfW7DqNbZ7sqTpVI/bzPxIX6jrgjvcJQD+QewFZhkGEaBYRi3KaVswD3AauAw8IZS6uDpHpsQYnS7cGosU+JD+PP67A7trzKdhXfnuZlxA3CEphJrL+GLvM55Uz1w7W9LWdTzdd7+EDO1/0GQwwGZL0Lq2RAzGcLH6Zm3nS9BzSAtTFib4N279F6tix6GKZfr0hYH3hqc+49VTTVgrW9bKjWZIXqiLJUOI0/6lLaXEhkw6tpeDUdW6Q1KqXillLdSKkkp9bzz+EdKqYlKqQlKqYcH6/0Mw1hlGMYzNTU1g3VLIcQIZRgG312RTm55PR/sbZt1y8yvYlxkANHBvm5fFxyfToqplL99dpzGFrtnb5b3OSTO05mjvUmYDUV7+rf0mPOp3vC+4La2Y8t+qKv0f/5Y3+/nzobf6iXSK54Av1DwD4f0FXDwXR04CvdqnUkiwfFtx2KmSs/SYWRpci2V9j7jBjAuMpATkpwwsiilPlBK3REaGtr7xUKIUW/l1DgmxQbzxLos7A6lG8vnV7ldJnXxikwjwahga1Yx1/51CwVVvfwg93R/m0vCHGishOoT7s87eggWdzwHgTEw+fK2Y+GpMOdrsOsVqCnwbAz73oC/LIZPfqmXRV1BZEEmbHlCL5GmX9B2/fRrwVIABV94dv8zkSvhJKRdEYToyfp4k0wWDAdLo2up1PMZt/K6FuqabUM5rEE15gM3IcSZxWQyuHdFBsfL6vlwfzG55fVU1Lcwf5z7ZVIAwlMxULx8TRwnKxu44s+b2ZpT0f31nu5vc3F1UHC3XFqeBb9Pg09+0
XV2q/oEHFsNc28GL5+O55b9UAdfmx7t/f1tzbDmV1BXAtv+Cs+tgMenwcc/6bhE2t6kS8DLT5ZLe+JKEHEtlYKecQNJUBgmrTNuHuxxA0iN0DPmoymzVAI3IcSYc8n0ODJignjy0yx2OPetucsobRU+DoDF4bW8e89SwgO8+drz23l5S577gr6e1G9rL3aa3jPmLnBb+wA0W3RiwDt3gK1dMdCdL4FhwLxvdH1dWDLMvQl2vQrVJ7ueb2/P61BbBNc+Bz/KgquehvhZkPlCxyXS9nyDIWOlc7nUw+Vj4GRlAzXOWY8xz+JuqVTXBaRMlkuHQ2tygodLpa6SIKMpQUECNyHEmGMyGXxnRQZZpXU88Wk2of7eTIjuodaaM3CjKo8J0UG8e/dSlk+K5lfvH+TRT451vb4v+9sAvHx18NY5cDuxDY78B877KVzwAOz/N7x+HTRZdAC36xWYeLEO0tw5+wf6+bPfdf/edqvOQE2cD2nL9f612TfADf+A+3Lg29s7LJE2We28sjWPJqsdpl8D9aX683rA4VBc+9ct/OydftTEG40shRAY3XE2NDQFvANkn9swsTTa8DYb+Hp5Ft6MxiK8ErgJIcaky2bEMyE6kMLqRualtjWWdysoTpdxyFkPFTkE+3rxzE3zuXhaHC9vyaPF1m4Js6/721wS5nRMUFBK7zcLioPFd8PZ34er/wb5m+HFS2HHs7pw7/zbur9nWDIs+pYuF5K91v01+/6ll1zPvU/P3rXnG9w2Q+T0jy9OcP97B3lyXRZkXATegXDw7a73rciB5y6Aox+3Hjp8ykJpbTNrDpW0ZveNaZaijsukACaT3ucmgduwqG2yEuLnjdH5//VuBPt5ExHoI0ulI4lklQpxZjKbDL5zfgbQfRmQViYTpJ2rZ7+enAuPT8f0/t3cHbWT2JY8dh7Oatt/1tf9bS4Jc6C5BiqP6+8Pv683/i//WdvM3azr4av/0tes/pmeCZxwfs/3Xf4LHSi8ezc0dCpnYrfpPXDxs/SyZy+UUry2LR+AZzfmkmtReq/boff0zJ1L2TF46TIo2AEH32k9vCVb7wtssTlYfeBUr+836tUWd0xMcImZIiVBhomlyeZxKRCX1MiAUVWEd8wHbpJVKsSZa9WsBB5YNZXrF3Sz1NjeV9+Ae3bCZY9C4lw4+hEztv+Itb73sfits+A3kfBIOrxzZ9/2t7m0T1CwW/XetugpMPvGjtelXwC3fAgRaToBwdTLj2lvP7jmWWiogP98r2PJkYNv6yDwnB91nW1zY+vxCnLK6vnxxZPx8TLx4AcHUdOvgcYqOP6ZvqjkkA7aHDaImwnFe1tfvzmnnLSoQFIiAnh/7xB0dhhpLIUd97e5xEzRiSCdA2kx5CyNVo/3t7mkRoyuwK1vYakQQowiZpPBN5aO9+xiw4CodP1YcLueYTu1j2ff/YTailN8f0k4RkO5bj8VN8Pz/W0uMVP0cmzRbh0IVR7XwaLZzY/hhDlwbx8K9sbPhPN/roPBff/SM3cOh25IHzMVJl3m0W1e25ZPWIA3tywdh7fZ4KEPD7Nu/kxW+Ibq7NLgWHjlSp1o8fUP9LGNv4eWelpM/nyRW8m1c5MI9ffmLxuyKatt7rZ23qjX0qD/O3ZeKoV2ra8Ow7ilp3dcZzjXUmlfpEQG8t7eIpptdny9zL2/YJiN+Rk3IYToF5MJEmYTdtYNPFG/ggOT7oVVf4LrX4fzftL3+5m9dcCXtwk2/B+MW+bR8mVvGlvsumjwknshZQl89CO9p+3we1B+FM7xYNYOKLE0sfpgCV+Zn4yft5mvLxlHRkwQD3ychW3SpXpp96XLwcsfbvlIdwiIn6ULAZ86wN6Cahpa7CxNj+TK2Qk4FHy4zznrtvrn8PGPB/xZ+6PZ5nlGbJ+4iu92t1QK0vpqGPRnqXRcZABKQUFV4xCNanBJ4CaEED1YMSUWkwGfHBqEPVsJc/TSYkM5XPigR8uXvfnWazu549VM3W7p6
qf1Uuk7d+nZtsgMmHqVR/f5xxcnsDsUX12YAoC32cSDV0zjZGUjH9oXQ0sd+IXoZdzICfpF8bP0c/FeNmeXYxiwOC2KjNhgJscF897eIj3zt/tV2P70ae99ml1ay1kPf8pfNmQP/s3d1XBzCUkA3xDZ5zYM+jPjNtpKgkjgJoQQPYgI9GH+uAjWHCoZ8L1UwmwASlIu47nj4fz07X18+emtLPv9OjL70ifVqaHFxpaccjZnl1Ne16w7KlzyO8j/HEoOOGfbel/6sdod/OOLE5w7MZrUyLYl4CXpUVw2M54f742i8vxH4Jb/tpVOAR2gBERB8V62ZFcwIzGU0AD9j+aVsxPZfaKa4uy9bV0E1j7Y4zhe3JzLeY+sp6Fl4FXsG1vsfPv1XdQ0WvnL+hyq6lt6f1Ff9BS4GYaedZPM0tPO0tj3GbcUZxHe0dKzdMwHbpJVKoQYqJVTYzlyqrbPv5FX1rew7kgJj605xs0vfMH57/mw2j6fq49dxEMfHmb1wRIcSlFW28w7u/veMP6L3EqsdoVDwbrDpfrg7K/CzK/oZdnp13l0n7WHSiixNHPTotQu535+6RQMw8zP8udBaKdlQcOA+Fk4ivaw+2QVSyZEtZ5aNUtv2j+64xN9YP5tcHw9HN/gdgxNVjtPrc8hr6KBv2/vpjVYH9z/3gGySuv4xWVTqG+x8cym4wO+ZweudlfukhOgrSRIf/rTin6x2h00Wu19Tk6ICvIh0Mc8ahIUxnzgJlmlQoiBunBqLNC35dL/Hihm/kNruPWlTP68LotSSxOLZk6mctWLPHnXKnb/8kJ2/fJC3rxrCWenR/PZsTL3XRp6sCWnAh+zibgQv7axGQZc8wx8c4P7xAc3XtueT2KYP8snx3Q5lxDmzz3np/Pfg6fYnF3e9cXxs6DsCCZ7M0vTI1sPJ4UHsGBcONa8rajAaLjofyEkSc+6ufmc7+8toryumbgQP57ZeFwXAO6nN3cW8O+dBdyzPJ3bl6WxamYCL23O07OSg6W2WHeb8O2msHPMVN2f9tMH4T/fh3/eCM9dqPvF7vn74I1DtKptbXfVtxk3wzBIGUXN5sd84CaEEAOVGhnIpNhgj5dLm212HvrwMBNjg/nnHYvY/8BF/Pd75/Dba2Zyw1kpzEuNIDywrdr+uROjKKhqJLe8b0s1W3LKmZMSxsXT49iUVd5xibFT0FbbZGXb8Qrsjo5BU3ZpHZuzK/jqwhTM3RQpvn3ZeBLD/PnDJ0e7BpfxszApG9O8CpmfGtHh1BWzEpjYfJDa6Hm6bMnyn+nixYff73CdUornN+UyOS6YP3xpFqW1zby5s6BPfxYux0pq+cW7+1mUFsH3LpgIwHcvyKDZZufpDTl9uteOvEre2NFNOzFLEQS7WSZ1ST5LP3/+uK6DV3kcfAJ0ksq7d8EH3wVrU5/GI3rmajAf3Mc9bgBT4oP5PKucd3b37/+700kCNyGE8MDKabHsyKv0aK/U69tOUFDVyM8vm8KitEgCfXueAThnYjQAG4+VeTye6oYWDhZZWJoexcppsTTbHGw85mZGzOnXHxzi+me2sfwPG3j+89zWno6vb8/H22zwlR5q3fl6mbl7eTq7T1TzWecxOhMULokswd+n4366y9JMpJpKyXToAIpZ1+slxE9/o4sDO23MKudoSS3fXJbG0vRIZieH8fRnOVjtDvqiocXG3a/vIsjXmyeun9MaiE6IDuKqOYm8ui2fEotnwZJSip+9vZ/73trH69vzu15gKXS/v80lcS785CT8shzuOw7f3go3vwe3r9NdMna+BC9cBFWd7u2w6+XkD/8HsrrphiHcap1x6+NSKcAvLpvK3NQwvv+vvfzfx0e6/IIzkkjgJoQQHrhwaiwOBZ8eKe3xOkuTlSfXZXF2ehTLMqI9undqZCDjIgPYmNV94NXZ1pwKlIIlEyI5a1wEof7e3S7lltY28d6eIs6ZGE1MsC+/+c8hFv92HQ+8f5A3dxZwyfR4ooJ6r
rd23bwkEsP8eXzNsQ6zbpU+CdSoABb5d52piKjQWaT/PJWAw6F0osSK+6EiSze+d3pu03Fign1ZNSsBwzC4Z3k6BVWNvLfH8yK+Sil+8e4Bssvq+NP1s4kJ8etw/rsrMrA5FH/1cNbtYJGFrNI6ooJ8uf+9g3ze+b+Npbg1cFtzqITv/2uPLsvSnl+InmFrz+yl+9Je/3c9C/fMuZC1Bgoy4eOfwGNTdK28Hc/pmbnmWo//DM50rjZrfU1OAJ2E9OptC7lxYQpPf5bDHa9ktv5yM9JI4CaEEB6YkRhKXIgfa3rZ5/bsxuNUNVj58cWTe7yus3MmRrM1p8LjumObc8oJ9DEzKzkML7OJFZNj+PRwKTY3s1SvbTuB1eHgwSum8eZdS3j/nqVcODWW17fnU9tk46bFXZMSOvPxMnHvinT2FtSw/mhb8Lr1eCUHHONJs7kJiE5sx27yZb0lgV0nqvSxSZdC0lm6lp21kSOnLGzKKufrS8bhgxVa6lkxJYbJccH8ZUO2xzMff9t4nLd3FXLv+RksTY/qcj41MpAvzclPYnwAACAASURBVEvi79tPUFTde72ud3cX4m02eOfbS0iPDuKu13eSXVqnT9qtUFeCCo7nqfXZfPOVTN7ZXcgH+/rQLWLyZXDHBl0H7vXr4LkVkPmCXmL90svwjQ+hvhQ2Peb5Pc9wrqXSvpYDcfE2m3j46hn85sppbDhWxrV/3TIiS4RI4CaEEB4wDIMLp8ay8Vh5txvnSy1NPLcpl8tnxjMjqW8JUedkRNNotZOZV+XR9VuyKzhrfATeZv1jfOW0WGoarXzRqaxIk9XO69vyWTE5hvFRuuzBzKQwHv/KbDb/+Hxeu20hC8ZFdLm/O9fMTSIlIoDH12S1zrptzinnmCmNgKojHfuZApzcBolzMXn58Nq2fP0aw9AzTrVF8OEPKf73fbzg+yjf2nsdPBwHf5yJYW/hnvPTOV5Wz3896Hn678yTPP/xVvYF3s13k7K6ve6e89NRKP68vue6bnaH4r29RZw3KYbkiACe+/p8fMwmbnt5h14qrz0FKP551MEjq49y5ewE0mOCeL2v2bCRE+C2NXDBg3DlX+BHWfCV12DaVboX7owvw9anui6nCrfalkoH1hTqpsXjeOXWsyixNHP1XzZzvKxuMIY3aCRwE0IID104NZZGq73rspnTE+uysNod/HDlpD7fe/GESLzNhkf73IprGjleXt9hZumcidH4epm6JFC8t6eQivoWbj27a+uvmBA/zs7oOjvVHW+ziXvOT2d/YQ1rneVHtmSXY42ZgWFvhrKjbRe3NEDxXsypi7h5cSrv7inirtd2Udds022gJl4Me15jafkbTPOvxBzvLF/SUA4ntnHJ9HjSogP58/rsHrNtPz1cwk/e3s+d8dmE2Kswrb2/w/659pLCA/jKgmTe2HGSkz1kEG7OLqestpmr5+jyJ8kRATxz8zyKa5q487WdlBTmArC6wMR9F0/ij1+ZzY0LU9h7spoDhX0sPeUTAGd/D+bcqLNU27vgV2CYdCsz0au2pdL+zbi1tzQ9ire/vQSAm57/glM1IyeRZMwHblLHTQgxWBalRRLs68XHB051CSaOl9Xxjy9O8tWFKYyL6mMfUyDQ14v5qRFdN/+7sSW7AqBD3bQAHy+WZUTxycGS1rEppXj+81ymxIewOC3S7b366po5iaRGBvD4mmMUVDWQV9FARLozg7Jdw3kKd+pG9MmL+NmlU/jFZVP45NAprn5qs86eve5Fnpv9FlNbXqLx9s3wlVfh8sd1H9TstZhNBnedO4HDxRbWdbOvcGd+JXf/fRdT40P4Rky2fm1Fdof9c53dszwDk8ngsTXHur3m3d2FBPt5cX678ijzUiN45LqZfJFbyUP/0EkDd162jG+fl45hGFwzNwk/b5P7RIb+Ck2CJd+Bg2/Die2Dd98xytJkwzAguJdkIE9NiA7i5VvPoqbRys0vbKe6YZCLOPfTmA/cpI6bE
GKw+HiZWDktjrd2FXDBY5/xx7XHWvc9PfrJMXy9THzn/Ix+3/+cidEcOVXba+bj5pxyIgJ9mBwX3OH4yqlxFFY3cqjYAsDn2eUcK6njtrPHYwxCey0AL7OJe8/P4FCxhQfe1704Z8ycCz5BHQO3k9v0c/JZGIbB7cvSeO22hZTXNXPFnz/nwyM1/HmfgxVTE9oCXd8gSFkE2Z8CcNWcRBLD/Pnj2iy+yK2ktLapNSjNKqnl1pcyiQvx48WbZ+OV95nOWk2cD5/9rttSG3Ghftx+9nje2V3Izvyu3SoaWmz89+ApLpsRj593xyzZK2cnct/Fk5gerMu2LJ49vfVcqL83V8xK4L09Ra0zP4Ni6XchKA5W/1S3DxPdsjRaCfLxwtRNWZv+mJ4YyjM3zSOvvIFbX9rRNQFlGIz5wE0IIQbTb66axsNXTyc62Jc/fZrFBY99xsrHP+PD/cV8c1ka0cE9Z2f25FwPyoIopdiSXcHitMgu/0CdPyUGw4BPDurl0uc/zyUqyLe1i8FguXJ2AuOjAll7uISoIF8mxoXoTg3tA7cT23Xpj4C2/XNL0qP44DtnkxIRwN1/30V1g5VvLkvrePP0C6D0IFiK8Tab+P6FE9lfWMOX/7aVsx7+lOm/Ws2lf9rEDc9uw9ts4pVbFxJVvR+aLZBxoc5atRTqrMxu3L08nbgQP+5/72CX5Ic1h0poaLFz1Rw3zeOBb5+Xzp2z/cDLH/zDO5y7cWEqDS123u1HF4xu+Qbpz1S4Ew68OXj3HYMsTdZ+lQLpzZL0KP50/Wz2nKzmrtd39rlMzWCTwE0IIfogwMeLGxem8s87FrPtpyv41aqpBPl6MSE6kG+ek9b7DXowJT6Y6GDfHsuC5JbXc8rSxJL0rkufUUG+zE8N55NDJWSX1rLhaBk3LUrF16v3fqV94WXWGaagy5EYztZXnNqv65A5HFDwBSQv7PLapPAA3rprCTeclcwVsxKYl9ox+CF9hX7O0bNu181LYtN9y3nplgU8sGoqX5qfTEyIL2lRQbx86wJSIgMgey0YZhh/LqSdC2nLYdOj0GRxO/5AXy9+dtkUDhZZ+OeOjgkF7+wuJDHMn7N6StiwFOlSIJ1mMWclhzEjMZTXt53ocxeMHs26Qf/5rn1A7x0EnQhSeVx/9pJDg/deo1htU9/7lHrqkhnxPHTVDDYcLeO+N/fp8jbDZGg+oRBCnAFiQ/y4Zel4blnadeN/fxiGwbKMKNYdKcXuUG47GWzO0fvblk5wn1SwcmocD390mIc+PIyPl4kbF6UMytg6u2JWIjvyqrjGNTMVPwusT0NFjt7b1lSjlz3d8PM289trZrq/cex0CIrVy6Vzvgbo5IDkiADoLucje60uo+Efpr9fcT88u1xnZC7/qduXrJoZz+vb8vnD6qNcNiOesAAfymqb2ZRVzp3npPW83OYK3Ny4cWEKP3l7Pzvzq5jvYbZur0wmuOi38NKl8PyF0FIP1SdAOZft/MPh+4d0osMZzNJo7XcpELdyN+n+uivuB+CrC1OorG/mb58dJ7+yoTVL+3STGTchhBhBzp0YTXWDlf3dZCduyS4nIdSP1Ej3/0i7+qpuOFrG1bMTey2s219mk8H/Xj2jLThxdlCgeG+7/W1dZ9x6ZRh6uTRnnZ69601dGRTvaZupA921YMoV/H979x1fZXn3cfzzS0ICYSQgewhoGLJBRJaoyFOlMoRiBdGK2mpFi9T6WNtax2Md1VqLiri3jEqrOBEREBFRICAyZchGhozICgnnev64Tsg6J4QMTk74vl+v80ru6x7nOrlzkx/X+vHlU3AgdOulmXHvgNakHc7ksWl+osJ732zlaMAdm00a1k/hA7cBHepTNSGON+aV8BIeTXrAuTdBTBzU7+izLwx8Gvr9Cw7tgSWTSvb9StCOtMMFzuItKT8dziz2UiC5zH7Et9xuXnCs6OYLU5h2W6+IBW2gwE1EpEzpmVITs9Dj3AIBx5frfqR7Ss2wk
w2a1PR5VYGQS4CUmpotIK6iD6I2zoPKtaBGEbuOz+wNh/fCltTjH7t2hv+a0id3ee+7IONggQvYnlWvGld3bcybX21g+dY03lm8hdb1q9GsTtWw5xAI5MqakFdifByDOzXgw29/YHch0qOdkL4Pw42fweUvw0V/9UuInD0C6raDeeOgJLtnS0Ag4Hj9y/Vc+I9ZDHhqTqHSxRVH2uESbHE7sAvWf+G//+rZY8VmRr2kSiXzHkWkwE1EpAw5rUoCbRskhVwWZPm2NPYezKBHiPFtOY26qBmjeqfQom4BAUhJi42DOq19i9vGeb61ragzWc/sDdixcW4FWjMdEmtC3fa5y2u1gPZX+kkK+8InDv99n+YkJ8YzauIilmzed/zWtoO7IJDhMx6EMbxrY44cDfDWgjAJ6gth5qodPPHp6nyvYxkosphB15Gwa1V2EFsGrN91gKHPz+OvU5bRukESaYczeeTjVcc/sRjSDmXkH+O26iPYVfCCyyGt/MB3RTc5D5a9DfsLTnV3MilwExEpY3o1q8XiTXvZdyj3shJz1/puv+5hxrdlubRdPW4rwiLAxVavPWyeD3s3hB3fViiJNaDB2T4oK0gg4IO7lIv8OLC8LrgTcDD7H2EvkZRYgTsubsGaHfuJMRjQvoDE8eDHtwFUDT9Tt3mdqnRpUoPxX29kxbY0vt91gG37DrH34JFCpTRL3biH616Zzz8/+S7f63fjF+VPa9ZmMFSu7VvdIuxowPHC5+u4ZMxsVmxL45Eh7Zh0Q1eu7d6EifM3snjT3lJ530DAsT89M/es0u3LYOKV8P7oE7/ginehehO/tmAgAxa+UlJVLbZyPznBzPoD/VNSUiJdFRGRQjm/RS2emrmGe6YspUH17G6Zacu2k1K7CnXyJFAvM+q19/k2ARoVI3ADH4zNfhQO7s61pEgu2xbDwR/zd5NmSW4EHa+G1Neg1+1+QdsQftm5EZMXbua0KvH5ktPnkxW4hekqzTK86+ncOnExfcd8nqu8QqwfG3h550Yhz0vPPModk5dQr1pFPhrdiyo5FpOdtuwHbnozlekrdnBJm7rZJ8UlwDm/hlkPws7voFbzgj9DKdn5Uzo3v5nK1+t30+es2jwwqO2x39Vb+zTj3W+2cveUpbw9skfIiTfF8d6SrQQcubsxP7kbXADWf35iP5dDe2DdLN+SWbOZbwFe8JIfVxhb8suNnKhy3+KmBXhFJNp0bJRMSu0qvL9kG89+tu7Y6/tdBxh4vBahSKrXwX+Nq5g9WaGoUvr4P7rrZoU/Zs2ngAW7VsPoORpw8MWYsIfExBgTb+jK08PPDn8d5/xM2R3BpTcK6CoF33I3/tfnMm54Jx6/oj0PDmrL3f1a0aFRMn95e2n+Ls+gp2asYc2O/Tw4uC1JlSoQG2PHXv/Tqg71kyry2pfr85/Y+TqITYCvnimwXqXl2837GPDUHJZs2ctjl7fn+V91zvUfjKoVK/CXS89iyeZ9+ZZgyakoy2ws27qPP/5nCV2a1GDI2cHgfO0M32LbfZTPqJH1H4rCWDXVz4xudZnf7nIj/LQNVrx3wnUrDeW+xU1EJNrExcYw/bbzI12NE1f7LP9HssHZEBdfvGvV7wQVk31w1mZw6GPWTIf6HaByAV3HyadDhyth4avQ8zaoFrqLM27OY/DdR/l3ZBz249oO/uj/mANUSPSTLwpgZnRPyV+vQR0bMGDsHG56YyHv/a4ntatmBzfLt6YxbtZaBndqwAUtauc7Ny42hqu6NeaRqav4bvtPNM85iaJKLWh3OXwzwU/MCNdKWQqmLN7CHZOXULNKApN/2502DUI3lAxoX58JX2/kkamr6NumHjUqZ/+O7PwpnfvfX86sVTt4+Bft+Hnbwi0avfvAEW54bSHVE+MZO7wT8XExfjbytLv9ve99lx/j+M14v6xHYZZMWT4FqjX0s5PBL+yc3Bi+fj787+JJVO5b3ERE5CSJS/Bdkl1HFv9asXFwxgV+DFuo2
ZKH9vhFfsN1k+bU8zYfdM19IvT+JW/BzL/596lUPferehNofrFvufnZA3DZM3Ddx6HH1BVC9crxPHtVZ9IOZTLyjVSOZPrxaplHA9zxn29ITozn7n6twp4/9JzTiY+L4dW56/PvPPcmP5M29bUi1e1EHQ04HvpwBbdOXEz7RslMuaVH2KANfDD7fwPbcCA9k79/tBLwmUD+PX8Tff75GVOX/kDtahUZ+WYqf5+6Ml9Wi7wyjwa4ZXwqO/en88xVZ2dnLVkyCbZ/CxfdE+xGvt63li79z/E/1OE031rXakD25JqYWOjyG9g41y8yHWFqcRMRkZJzwZ0ld62UPrD8Hd89Wad17n3rPvNdqYUJ3Go09XlMF7zsxylVydGa9eNaP3j99G5wzfs+YCxlrepX49HL23HL+EXc994yHhjUluc//56lW9IYN7wTyYnhWytrVI5nQPv6/Dd1C3dc0pKknIPx67aBpr3g6+eg280lNh5r0cY9jJ25hl37cy/nkXYog3W7DnB118bc3b8VFWKPH8w2r1OV63o25bnZ6+h25mlMnL+Reet2c06T6jw0uC2NaiRy33vLGTdrLUu37OPJYR3D/jwe/mglc9f+yKND2tG+UXDx5SMH4dP7fatvm1/4ssY9fPq1BS9Cp6sLruDqaXA03a8DmFOH4TDjAf+zHfDkcT9naVKLm4iIlE1Zi+qGml26ZjokJPmk8oVx3h/8H+S5Of7oZhyGt67xAc4vXjgpQVuWfu3qc+P5Z/DmVxt5ZOpKHp/+HX3b1KVvIboIR3RvwqGMo6GXG+k60udqLYHxWJt2H+SW8akMenouizfto1qlCrlejWok8siQdtx/WZtCBW1ZRl3UjDrVEhg9aTHLtqbx4KC2TLqhGym1q5IQF8uDg9ry8OC2fLVuN/2fmsPyrflTl01ZvIUX5nzPNd0a557oMW+sXyD5Z3/LbjEz82MAty46/tqAy6dAlbr5F49OrAHtfulbZw/uLvRnLQ1qcRMRkbKpWn2o3cqPc+txa3a5c77szAsKH2yddia0GeLXdetxqx8X98lffdfXsIlhZ5yWpjsubsnyrWk8PWstSZUqcN/A1sc/CWjTIImzG1fn9XkbuK5H09zpuZpdDNWbwvR7YNWHuU+sUgcu/Mtxx3mlHc5g7Mw1vPzFemIMRvVO4cbzz6RyQsmEDFUS4nj8ig5MXfoDt1yYEnIm79Aup9M6KZ2Z/36C7ePuZTrNmeE6sYImgHHkaIAuTWtwV85u5f07YM6/oGU/aNw99wXbD/W5Xhe8mD12La8jB2D1J35h41Bd4V1ugNRXYdEb0GNUUT9+sSlwExGRsivlIvjyaXimZ3ZZIOBbVQrTTZpTr9vh27d8DtMGnXy3V9eboUXfkq1zIcXGGE8O68joSYu5ssvpuSYqHM813ZswasIiPvtuJxe2zNH1GxPjsyrMeAC2LDxWHHBge9djO1bAsAl+7FcI89fv5revL2T3wSMM7tiQ2y9uXiqZArqfWTP0eoSBAKybAamv0Xblh7QNZLC3ckMuODyZUbxFWnxt1iT3YNNpPbmwYxIVtszPPnfBS5B5GPrcl/+6FZOg7RDfYvazv/nxi3mtmQ6Zh6DVwNCVrtvGd7vOf953RcfEFu3DF5MCNxERKbs6jYA9G7JndGap0wrO6n9i16rVAlpf5gO2mFif87PPvSVU0aJJToznlWu7nPB5fdvUpXbVBF6Zuz534AZ+bFdwfFcg4Hhn8RYe/XgVPY9M5dG1z3F08vXEXv5KvtbKqUu3MWriYhomV+LV67oUONGgyDKPwIY5sPJD2L40//69myBtM1SqAefeCB2vJrl2S5+TdvXHVFv1EZ3WTqPTjrdhRYjrd7kRaoZZt7Xz9X7ixjcToetN+fcvnwKJp8Hp3fPvO3b938D0+2DfJj9xJQLMlbHcZqWlc+fObsGCBcc/UEREyq/ty2Bcd0ioBjfO9hMXotSY6at5fPp3zPjD+ZxRq0q+/XPX7
uKBD1awbGsa7Rsm0bZhEvHzn+XuCq9z4KwrqHz5M8e6BF/7cj33vLuMDo2SefGac3It1RFW+n6/TEouzndlu4B/BY761FE7Vviu29WfQHoaxFXygXPeru6Eaj7obHlp2FZBMtN9a2LGodzlsfG+i7SglrDne0P6T3Dz17lTsmUchkfP9O89IMzsY/CfByvyrOITYWYLnXP5BnGqxU1ERE4ddVpD/yf8ivhRHLQBDDu3EU/NXM3fPlhBzzxrxn2xZhefrtxBg+RKjBnagf7t6hMTY0xrdhdj3zrMzSsmsWliAg2HjuHRad/x9Ky19DmrDk8O60il+EJ0Aa78EN75rV9mo7ASa/puyJaX+qVeKhSxCzYuIf8YtsLqfD1MGemzKTTtlV2+dgYc2e+XASlIhLpHcyr3LW45Ul79ZvXq1ZGujoiISIn54+QlTAoxu7RqQhw3905hRPcmVKyQO9jYsGs/qc+PZFD6FCYlDuOPu/sxrEtj7h/YmrjjzQ49mgGf/p9fE69eez9gnzzpqyzGv2Jis79WreeX6Ih04JNxCB5rCYdD5EytmAS3ryn+4tElJFyLW7kP3LKoq1RERMob5xxphzLzlVeKj/VZBMI4fCSTb8f9inP2fMCeSqeT3P1arMOVULVu2HNI2waTr4WNX/rlNS5+CCqU0by5BVk703+GvBp2gWYnOOGlFClwU+AmIiKSLXCUgwsmkLhsPGz4AiwWmv0MOgyDqnly4qZthg//1y9w23+MT68lpUpj3ERERCRbTCyJXa6CLlf5DBKLXofFE0LnbAWffWDEB352rkSMWtxERETEO5rpc8AeOZi7PCYGGnUtXJJ2KRFqcRMREZGCxcYVfcamnBTKVSoiIiISJRS4iYiIiEQJBW4iIiIiUUKBm4iIiEiUUOAmIiIiEiUUuImIiIhECQVuIiIiIlGi3AduZtbfzJ7bt29fpKsiIiIiUizlPnBzzr3nnLshKSkp0lURERERKZZyH7iJiIiIlBcK3ERERESihAI3ERERkSihwE1EREQkSphzLtJ1OCnMbCewoZTfpiawq5TfQ4pG96Zs0n0pm3Rfyi7dm7KpNO5LY+dcrbyFp0zgdjKY2QLnXOdI10Py070pm3Rfyibdl7JL96ZsOpn3RV2lIiIiIlFCgZuIiIhIlFDgVrKei3QFJCzdm7JJ96Vs0n0pu3RvyqaTdl80xk1EREQkSqjFTURERCRKKHArIWZ2iZmtMrM1ZnZnpOtzqjKzRmY208yWm9kyM7s1WF7DzD4xs9XBr9UjXddTkZnFmtkiM3s/uN3UzL4KPjeTzCw+0nU8FZlZsplNNrOVZrbCzLrpmYk8M/t98N+xpWY2wcwq6pmJDDN7ycx2mNnSHGUhnxHzngjeoyVm1qkk66LArQSYWSwwFugLtAKGmVmryNbqlJUJ/ME51wroCtwcvBd3Ap8655oBnwa35eS7FViRY/vvwOPOuRRgD3B9RGolY4CpzrmWQHv8PdIzE0Fm1gAYBXR2zrUBYoGh6JmJlFeAS/KUhXtG+gLNgq8bgHElWREFbiWjC7DGObfOOXcEmAgMjHCdTknOuW3OudTg9z/h/wA1wN+PV4OHvQpcFpkanrrMrCFwKfBCcNuA3sDk4CG6LxFgZklAL+BFAOfcEefcXvTMlAVxQCUziwMSgW3omYkI59xsYHee4nDPyEDgNefNA5LNrF5J1UWBW8loAGzKsb05WCYRZGZNgI7AV0Ad59y24K4fgDoRqtap7F/AHUAguH0asNc5lxnc1nMTGU2BncDLwW7sF8ysMnpmIso5twX4B7ARH7DtAxaiZ6YsCfeMlGpMoMBNyiUzqwL8BxjtnEvLuc/5qdSaTn0SmVk/YIdzbmGk6yL5xAGdgHHOuY7AAfJ0i+qZOfmC46UG4gPr+kBl8nfVSRlxMp8RBW4lYwvQKMd2w2CZRICZVcAHbW865/4bLN6e1VQd/LojUvU7RfUABpjZevxQgt74cVXJwW4g0
HMTKZuBzc65r4Lbk/GBnJ6ZyOoDfO+c2+mcywD+i3+O9MyUHeGekVKNCRS4lYz5QLPgbJ94/ADSdyNcp1NScNzUi8AK59w/c+x6F7gm+P01wJSTXbdTmXPuT865hs65JvjnY4ZzbjgwExgSPEz3JQKccz8Am8ysRbDoImA5emYibSPQ1cwSg/+uZd0XPTNlR7hn5F3gV8HZpV2BfTm6VItNC/CWEDP7OX4MTyzwknPugQhX6ZRkZj2Bz4FvyR5L9Wf8OLd/A6cDG4BfOufyDjSVk8DMLgBud871M7Mz8C1wNYBFwFXOufRI1u9UZGYd8JNG4oF1wLX4/9jrmYkgM7sPuAI/W34R8Gv8WCk9MyeZmU0ALgBqAtuBe4B3CPGMBAPtp/Bd2weBa51zC0qsLgrcRERERKKDukpFREREooQCNxEREZEoocBNREREJEoocBMRERGJEgrcRERERKKEAjcRiRpmNsLMnJmlBLdHm9ngCNYn2czuNbNOIfbNMrNZEaiWiJRjccc/RESkzBoNzMGvKh8Jyfj1nDYDqXn2jTz51RGR8k6Bm4hIDmaWUBILmjrnlpdEfUREclJXqYhEpWDe08bA8GD3qTOzV3Lsb29m75rZHjM7ZGZfmNl5ea7xipltNrNuZjbXzA4BjwT3DTWzGWa208z2m9kiM7smx7lNgO+Dm8/nqMOI4P58XaVm1sLM3jazvcE6zTOzS/Icc2/wOs3M7IPge28ws7vNLCbHcVXM7Ekz22hm6Wa2w8ymm1nLYv5oRaQMU+AmItFqEPAD8DHQLfi6HyA45mwuPi3Qb4BfAD8C083s7DzXScKnEJoA9AXGB8vPwCdcHw5cBrwHvGBmvw3u3wZkja97KEcdPghVWTOrj+/WbQ/cAvwS2At8YGZ9Q5zyNjAj+N7vAPeRnRcR4PHgNe4D/ge4EViM774VkXJKXaUiEpWcc4vMLB3Y5Zybl2f3o/gk3b2dc0cAzOxjYCnwV3wwlKUKPt9jrmTdzrkHs74PtnTNAuoBNwHPOOfSzWxR8JB1IeqQ121AdaCbc25N8Lof4hOHPwB8lOf4x5xzLwe/n25mvYFhQFZZN+BN59yLOc55+zh1EJEopxY3ESlXzKwScD7wFhAwszgziwMMmA70ynNKBvB+iOs0M7MJZrYleEwGPsl3iyJWrRcwLytoA3DOHcW39HUws2p5js/bcrcUn8w6y3xghJn92cw6m1lsEeslIlFEgZuIlDc1gFh8y1pGntctQPWcY8WAncEA6hgzqwJ8gu/WvBM4DzgHeAlIKEa9toUo/wEfVFbPU747z3Y6UDHH9u+AZ4Hr8EHcDjN73MwSi1g/EYkC6ioVkfJmLxAAxgKvhTrAORfIuRnikG74iQ/nOefmZBUGW+6KajdQN0R53WAd9pzIxZxz+4E/AX8ys8bAEOBh4Ajwx2LUU0TKMAVuIhLN0oFKOQuccwfM7HN8a1lqniCtsLJarTKyCsysOjAwDOPwVwAAAV5JREFUxPuTtw5hfAaMNrMmzrn1wWvGAlcAi5xzaUWoJwDOuQ3AY2Y2HGhT1OuISNmnwE1Eotly4Dwz64fvctwVDIpuA2YDH5vZi/guyppAJyDWOXfnca47F0gDxprZPUBl4C5gF34Wapbt+NmqQ81sCXAA+N4592OIaz4OjAA+CV4zDb9Ib3Pg0hP83JjZl8C7wLfAfvy4vvbAqyd6LRGJHhrjJiLR7E/AKuDf+HFe9wI451LxY9J+BJ4ApgFjgLb4gK5Azrmd+OVGYvFLgjwEvAC8kee4AH7CQnX8xIf5QP8w19wK9ASWAeOC160BXOqcm1roT5xtNn45kDfxExmGAL93zo0pwrVEJEqYc6GGd4iIiIhIWaMWNxEREZEoocBNREREJEoocBMRERGJEgrcRERERKKEAjcRERGRKKHATURERCRKKHATERERiRIK3ERERESihAI3ERERkSjx/2rTx6xIjg0bAAAAAElFTkSuQmCC", 
"text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmYAAAGYCAYAAADoXC5+AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdeZycVZX/8c+pqk5IJwQSEkIIS3QAQQYRTRB+QMgCEiGgOCiILEEZkdFRHEcHQWMC4s4wLAOICJFlUCKrshlIwiIBDIQdBWZkC0tiEkhCZ+mqOr8/7vN0qqurqqueqq7uDt/369Wv7nrWW9W1nDr33PuYuyMiIiIivS/V2w0QERERkUCBmYiIiEgfocBMREREpI9QYCYiIiLSRygwExEREekjFJiJiIiI9BEKzAAze8nMPPqZVmG7p6NtJhYtXxAtn97NeWZH280sWj6z4PxLzCxd4RgHFWzrZja2irtYqU3bmdnPzewpM1tjZuvN7DUz+7OZ/beZHVXFMf65oD3f6Gbb6QXbrjWzLStsu3PRfZ1YtH5m0Xo3s5yZ/d3M5pnZSWZmRftMjLZ7qbv7VbBP8TnK/Ywtse/Hzex30WO63sxWmdn/mtldZjbDzHavth3NUvQ/KvezTYX9p5rZH81shZm1Ra+bM81sYDfn/ZiZ3WRmS81snZm9YGY/NbMtutnvA2Z2jZm9Hj3GL5vZJWY2upv9to22ezna73Uzu9rMdqn8CNWm4DnnZvaqmW1WZrtxpZ6bZjY23r+Kc5V8Llrn97gfd3OMawq2XdDdOSscZ6xV97rZsmCfUq/pUj+zuzn3tQXb7lW0rtpzlHx9V3q9F50nfswnFi1fUOLYG6Ln3+/N7JMVjpn48TGzYdH+iyy8D20wszfM7HEzu9zC677sZ09vqOH+/l+Z/fcxs99aeP/dYOEzbrGF997Nq2xDyszuLzjXuMbey64yPX2CfuiHZna7u+d76fzbAgcDd5ZZf1KjTmRmE4A/AJsDy4EHgWXAMODDwL8ARwO/6+ZQXyhq33lVNmEz4Bjg0jLrp1d5nP8FHoj+Hgh8AJgU/XzSzP7J3XNVHquSG4A1FdZ3Wmdm/wV8Pbr5DLAIWAvsABwAfBwYCvx7A9rWEwof12JrSy00s28DPwFywAJgJXAg8ANgmplNcfe2Evt9DrgaSAN/ApYA+wDfAo40s/3cfWmJ/Q4E7gAGAY8B9wF7Al8G/snM9nf350vstxtwP7AV8BfgJmAX4Djg02b2cXf/U5n7Xo/tgK8CP++BY1freDM7s9RrwsyGAp/ugXP+usK6DSWWVXruUWmdhUD+yIJFXwD+teD242XaMxUYRXj+vVhifaXXfhKF59kc2AOYRnidXODuXy+7Z42Pj5l9ELgbGE24H48AbwFDovN+Mfr5HY2/n/Uo97+KHUH4vJpfvMLMvgBcDhjwFOHzbSiwHzAL+Hz0/rCsmzacBuwPeHSsnufu7/kf4KXoQX83+n18me2ejtZPLFq+IFo+vZvzzI62m1m0fGa0/M/R79+U2X8LoA14lvDicWBswvs8kPDh54QPic1KbPNR4EfdHGe36BhrgLejv8dX2H56tM1jQBZ4qMx2KeBVYAXwQpnHPX7cZpfY/0hCcODAyQXLJ0bLXqrhsfJaH2vg8GifVcCkEutbgc8Cn+/t53+F/1GXx7Wb/cYB+eh19LGC5UOAe6Njnldiv+2i53UO+GTB8gzwm2i/m0rsNxh4I1r/1aJ1P4+WPwpYiefWE9H6nxWt+9do+RKgtUGPZ/yca4t+Lwe2KPP4dXluAmPj52DS5yob3+Pi95ipZfb/UrT+kej3gjrud9XtLtin7Gu6hmOcGh3jtYLHe2AV+y2guvfxq
t4PCh7zidWch/Chf2bB8fdu1OMTvQ4cuBYYWmL9rsBPgUGNeM4344eQxMhG9+v/Fa3bko2f58WP81bA4mjd+d2cY5fodfuHgv/nuJ6+b+rK7OyC6PcsMxvQC+d/GHiOkOUp1cV3NCEzMLsB5zqA8MR+3d3/3d3XFW/g7o+6+3e6OU6cLZtD+BAtXFbJ68Bc4GNmtmuJ9QcRPrB/A6yv4niduPtNhDchgM/Uun8DHB39vsjdu3ybc/c2d7/e3a8tXtePnU74cPmJuz8cL3T3NYRMah74lxLP7dMIz+tfu/stBftlCYHCKuBT0bf+QicB2wDz3f2ionX/QcgqfAT4RNG6Q4EPEbIVpxeucPcLCR+c21J9xrZa/0d4gx8OfLvBx67W7Oj39DLrpxMC5Kub0JaeEr//nEnIhg4HPtV7zamOh0jgh0Cc4S1bVlMLM9uJ8DrIAl9y91Ulzv0Xd/+2u5fMhPdRJxIy7H9x9weL1n2M8OX3eXefXbjC3ZcTglCAfcsd3MxShNdLOyED3zQKzDq7gfBN8X00+R9RYDYbu/iKnUTj3jS3jn53l8Yty8wywPHRzSuBK6K/j7EydTRFrox+Ty+x7qSibZJYFP3esY5jJBU/vl263zZF0ReZOADqEmy6+/8BC4EBhMCoUPyhWWq/VcDvi7arZr8cG78olNvvN166i/vaou0a6QxCgHqaVajT60Flv/yZ2QcIH1R3ETKR/Y6Z/SMh87iG0C0Xv39U82Wx10XB2VPRzVENOmz8XrTG3d9t0DH7gvgz4lcl1lX7Zf7vFdZ9k/B6+Hd3f62WhtVLgVlXcYboTDMb0gvnv5oQfHWqJYveNPcB7nT3RrxpvhL9/kczm5LwGIcR3jz+D7jf3R8hdLNuCfxTFfvfQqhBOr6w6DT6wPgU8Iy7/zlh2yDUE0CCjFsDxI/vdOumeL0P28nMfmBml1kYIHJshdfEBwjfUFe4+/+W2Sb+X3YUY0c1Tf9QtL7b/YpuN2u/urn7U4TArxWY0ejjV+lKwpe/zxUtn16wvr/6YvR7ThSExO+nB5nZ9r3XrJrE7xdvNeh48XvRltbNALX+wswOAHYmZLOuKrHJo4Sga5fi+2xmW7ExY/3LMsffDTgLmOfuJbfpSQrMirj7POCPhG8Z3+yF879B+Ma6d/TkiMWB2uwGnepBQmFlGphrZvPN7LtmdqiZjazyGPG30NnRNz2o4Ruqu68HriN0G328YNUxhA+O2VW2o4soDR2Pbno86XHqcBnhA2FP4GULo/1OtTDyMHE3uXUeXVftz0sJT7cfoTvonwmvhWuBV6z0SN33Rb9fKbGOonXvK1g2Nvr9dqkulnL7RQHd8OjmyzWcr/B2d/uN6KEvZzMIBe8nm9k/dLdxD4iDlenxguiL0QmEms5be6FNdTOzFsLgDYjeh6L30zsJn3XTe6dl1TOzrYG9o5t/aMQxo2xPnHW+0sweib5wfcrMtkt6XKtu9Hapn+kNuFvx58sfvMSgIHdfTXg+ryHc5yej0Zl3EGrFtiHUHt9U4n6lCZ89ecJ7X9NpVGZp3yGMjPymmV3s3Y/aaLQrCd0904H/iJ4oxxOKWBvypunueTM7jPAEPJhQoDwxXm9mjwO/AH5ZqrvHzEZFbczTedTM1cCPgElmNtbdX+qmKVcSRn9OJ4yugxCEZknQZRsFPR8gfPiNI3wAFdcfJfU3s7KDcp5w9w/HN9z9ETP7J+BiQuB5HBs/NNaZ2e3AjxNkBH8HjKhxn0rp+lLeIIyivJWQDc0SBnl8mzCo4rdmdqi731WwTxzAVOoqiUd7FQ5Tr3e/SvuW2q+acxaOStucBo9Sc/eXzOxS4GuEx7k4c9Wj3P1NM7sTOMzMdnP35whfjLYl1ERuqPA8T8TKT/VxUnENUOREMzuxwiGPdPebi5YdQXhtvOju9xcsv4KQ3Z9uZj8o+BLZZ1iYuuHDhEErQwnvDY9U2KXWx+d4wuj3o4Hx0U987ucJj9EFNdaYv
UjlEZOV9ksseqziuuFS3ZgAuPsdFmYe+C1h5OkeBavvIgwAKOVbhOD4tKgEo+kUmJXg7o+Z2fWEJ/GZhOLkZrqV8M31ODM7g6I3zUadxN1fBz5uZh8ivKntSxiJOYrwJnEJYcqBw0qc9wTC8+dud3+l4JhvRUHHEYQA6/vdtGGRmT1NqHkZRvgmszfhm1C1qfxyb1KrgVPq7A4tVGm6jC6ZIne/JfoA/AQwhfBmuCchG/hpwn3+srtfXm0D3L3Hp9aIAq67ihY/RJhG4lzg34BzS2wj1fsB4fVxtJn91N3LfUj0lNlEwQphoMT0guU9odwHeLkP6e6mgyiVme3I4Bct/z3hy8n7CV8+uwzG6SVXmllxt7EDX3T3K0rtUKCmx8fd3wE+Z2bfI5SJ7EcYELADYeThj6P1E9397Woa7+4PdNOGnnI0YUT265SfVgoz+zJhQN8CQmD6DGFqjSMJr7/DzOyT7v7Hgn12J4x8XQhc2DPN754Cs/K+S6iT+rKZnefu5bo9ILyYoPs5TuL1Fb+xRd9Y/4cw39Eh9HDth7s/CTwZ3zazPQnZkWMJoyO/DvysaLdKxflXEgKzE81slnc/J9xswjfFz7Gxa6uW+1r4JpUjTNvxBHBrtW8yVfr3KjKAnUTdtTdHP5hZK2G+pB8R3hD/28zubHZxaR1+QHg+7G5mOxQE5XHAOrjCvnGmanXBsnr3i/d9p8r94n2HVThnYTaueN+GcPdlZvafhC8uPyI8J8puHv9hZlYu42Od01zdZYVuJWTgjzeznxG6/Z9y90eraX+t3H16jbs8UMs+ZrYt4b0yT1HNkbu3m9m1hOftF2hcYFbv+33hPGYjCCPlhwIXmNlz7r6wwrFrenxi7v4i4b3259BRu/wVwmfNnsA50e2+LA7Af11m8A5mth8hsfAUcJi7t0er1hAe378TSjMuNbOd3T1X0IUJ8IUqPrd6jAKzMtz9RTO7nDA68yzC0Nxy4gkzK324wMY3/Gq6Rq4kvFhOAyYAT7r7Y1XsVzd3f4Iw+d4QQoD1KQoCMzPbl9C1BfANM/uXokPEz6sdCZmiud2c8hrCN7YvEDKDf2djTUQ1Er1J9QYPk6veaGYPE4bFtxIyalUVmJrZz0nQldmoTJu7rzSzpYSJKsew8Zv5S9HvHSrsHhdfv1SwLP7Cs6WZDS1TZ9ZlP3dfZWYrCQHWjhR8sejmfPHteL8nKuy3PJrqo6ecS+jGP8TC7PDlzlU4Ie/gCtsVBpQV213w5e9fCe81A+nfRf/x1AnrgetKdMVuFf3+JzP7SoV6xmq0EV639b7fX17YjRvVTd4ETCbchw96icmYG8nd/wp8Lepq/hrhvb6qwMzM9gdOTnDay6NsW80sTK0UT3FRKas4Pfo9pyAoKzSHEIS9j5BJfYHwuh9HGJB2aYnnUDyK+hdmthr4XYlpehpCgVllZxG67I6LvlWWE3847dTN8XaOfr/a3Ymj7tSnCPVf0HNdDJX8kRCYFQ8GKCzs7+7yFF+gm8As6v68gzApK4Rah1Ivpk2Guy8xs2cJj1+1gy0AjqL26T9epkFXF4i+Vcajxgo/cP5CuBrAcDP7hzIjM+Oi5o5uO3d/x8z+lzAyczxwTzX7RR5jYxdxqcCs0n57RfuVqtkst19DuftqMzsH+C/CF5Ovltl0BaEebjDhPabcYJb4/WUN4cOlO7MJgdk0Qh1hf55TL87gDyR005UziJCZ/0Ud53qFMCHrToRJx7sws+FsHJzS7fs9dHzZOIbwWtqRUDLwgzraWYs/EgKzWt6LdqJywqKcBSTvAo0/e+6Nsn/lxF8QS2XS4yzqu4QZBIYXrR5GuFpJOR+JfvfYoDKNyqwgGtFzPuFx+mGFTePU+BEW5vbqIkoZ70Hoaru/1DYl/ILQ3bCUkFVqGCvxdaCE+Mnd0c0WdcXFk6dOcHcr9QPEk4F+Kqod687lhPu6nP79zR3o/vGNApxto5tVd2O6+9hyj3mFn7F13JVi0
wjZgtWED5C4XRvYOHjj88U7mdn7Cd90NwC3Fa2OJ5Uttd9QNgbsxSOoKu2XZuNcgOX2O8ZKXxswPl6XEVs94BJCBu9jlLkUUtRdc190s9I0NPFo2fuq6YaJMvB/Irzm5pQa3dYf2MapE9YAgyu8J8WZ/XrnNIvf76v5Xzxby+MaDTQ7O7r571bhWsLVSvpe3x13n53gvcjKDPao5n5kCIkSqFD0H3k9+r1PmWPtSgjKIMqou/tLldrNxuz++GhZj9WeKzDr3k8J3z4Pp+uw+9hNhDqn9xFqhjpNrhrVP8SB1TVe5Txk7v7f7j7C3Ud540eGHm5mN5rZZAtTSxS218zsU2z8Bv/bgtWfIYxUe4kK33o8jPR6lFDofmx3jXH3W6P7OsLde2N6i0b7lZmdZaUvbL4FYaTotoQA5/bmNq08M2u1MK1Hl2kiolG8cZfrf5fIav6YUE/zH2a2d8F+QwjdDing4hJ1f/9FyLadaGZHFOyXIXw5GQrc7O7PFu13JfAmYQRwcffLjwlZuMVsDBhjtxEybDsR6rsK7+NXCQXir1MiS20bL0A9s3hdElFAGw+QqXRtxPgSU9+M/g/F7Toc+AYbL7FW7fn3j15z3b5G+7A40Lqhm66/3xK+GOwdFXkndQGhy/TzZvbF4pVRqUf8Rb5ST0s5lxCCgC0I/9N6fcjM5pnZ4RamFOkkCmzj5+Bvi9f3IfG8me/Q/fWb4/XHmNlxhSssTOwcd4PO9+oHmTWNujK74e5vm9mPCAFaa5lt2s3s04QRIl8i1DEsJDyBRhNS6wMJc4d9rSkN716KMDrlSGCFmS0mZOaGErJdcRB6HZ3rn+I3wWvKFSEXuJowyvMLwH83qN2NMtrMHqqw/jF3L66d+7mZVarduaCgDnA4oXvlexaGoz9HqE3ZhtCFNgRYB5zo7rVOZ9GTBhCm+PhPM3uM0A0zgFBTGF8660ZKTI7q7n82s9MJFzF/0MzmEQZiHEiYF/Bhwijn4v1ejT7grgZuNrMHCIHRPoQunReBU0rstybq+rkDuMjMTiLUiuwZtffvwOeKn6cepor5HCEL9S0zm0aoNduZ8HxdCxxd5kM+/hLTyK72awhD9P+x3AbuPs/M/oPw2P7BzJ4hjDID2D36ceA/vMQlwPqh/c1sdoX1r7j7jCjoj6dOqDi9jruvsDBi/FOE96RE81S6+18sXCD7SuByCyPn42v/7kR4DhlhFH2l+1Du+OvNbAZhJOvXzey/3L24a7qqxyf624BJ0c+a6HX9OqFbd2c29m7Mp3ldp0nEnz3XeTfTerj7bVGN+MnA1Wb2HcLk58MI2ekhhMegV+Yp65b3gYuR9vYP3VyclJD1eZWNF5edWGa7rQhDbf9MCMraCZc8uofwBGgps9/M6LgX1dDmei9ivhlhJNjPCQHjy4RAYS1h7qrrCaNZCvf5B8KoJwd2qeIcW0ePgQMfipZNj27/oYa2lrt4fPy4za7hWBML/o+VfhYU7FPN9g58qmCfMYTA7FpCdmYp4Y37HcKb+M+B9/X2c7/E4zOA0JVyV/S6WEPIMiwhdAF+uopjTCXUFa6Mnk/PEAKyiheSJrxh3hy9ZtYTArKfUuKC30X7fSB6nN+M9nuFMGfT6G722zba7pVovzcIQVLJ5zahuPzt6HWyY4Ln3NMVtplW8Dx6qcJ2exNGHf5f9NjGr9erKHHR64L9XqKGCzATuuI6vQ4SPJfGxvephn1mUt1r7fFo+y9Gt18DUlUc/9PR9m9R9H5MlRcxL9h+t+j581dCDeB6wufE74CPV9iv2/MQvgA8GW13dtLHJ9onQ/hydDZwb/R8aYuex68S6iyPrebx660fQqYs/iwZX8N+nyZkyN+M9l9D+BJ2DrBVjW2o6TVUz49FJxQRkQqiLqoHgfPc/d96uz0ismlSjZmISHUOBlYRvm2LiPQIZcxERERE+ggV/28iuikELZZ4gj8Ree+pdTJR7ycTPov0RcqYbSKs/EWCSyl34
WARkS7MbDo1zC/oYd4nEUlAgZmIiIhIH7FJdGWOGDHCx44d29vNEBEREenWo48++nd3L3kJrE0iMBs7diyLFi3q7WaIiIiIdMvMXi63TtNliIiIiPQRCsxERERE+oimBmZmtpmZPWJmT5jZM2Y2q8Q2081smZk9Hv1UPURbREREpD9rdo3ZemCyh4sPtwAPmNkd7l58MenfuvtXm9w2ERERkV7V1MDMw9wca6KbLdGP5usQERERoRdqzMwsbWaPA0uBue7+cInN/snMnjSz35nZ9mWO8yUzW2Rmi5YtW9ajbRYRERFphl6bYNbMtgRuAv7V3Z8uWL4VsMbd15vZKcDR7j650rHGjRvnmi5DRGTTtmrVKpYuXUp7e3tvN0WkpJaWFrbeemuGDh1acTsze9Tdx5Va12vzmLn722Y2H5gKPF2wfHnBZpcDP21220REpG9ZtWoVb731FmPGjGHQoEGY6apP0re4O2vXrmXJkiUA3QZn5TR7VObIKFOGmQ0CDgb+UrTN6IKbRwDPNa+FIiLSFy1dupQxY8bQ2tqqoEz6JDOjtbWVMWPGsHTp0sTHaXbGbDTwazNLE4LC6939D2Z2FrDI3W8FvmZmRwBZYAUwvcltFBGRPqa9vZ1Bgwb1djNEujVo0KC6utubPSrzSWCvEstnFPz9HeA7zWyXiIj0fcqUSX9Q7/NUM/+LiIiI9BEKzOrx4EXwzmu93QoRERHZRCgwS6ptBfzxTHj2lt5uiYiIvAfMnj0bM+Oll17qWDZz5kzmzZvXZdvp06ez3XbbNbF10igKzJLKZ8PvnObTERGRnnfYYYexcOFCRo/eOHnBrFmzSgZm0n/12jxm/V4ckMUBmoiISA8aOXIkI0eO7O1mSA9TxiypOCBTYCYiIjV69NFHMTMeeOCBjmUXXnghZsZ3v/vdjmUvvPACZsZtt93WpSszHv13zjnnYGaYGTNnzux0nsWLF3PAAQfQ2trKzjvvzKWXXlpTOxcsWICZccMNNzB9+nSGDRvG0KFD+fznP8/y5cs7bRu3/dxzz2XHHXektbWVww47jKVLl7J06VI++9nPssUWW7D99tvzk5/8pNO+y5Yt45RTTmGXXXahtbWV7bffnmOPPbZjstbYzJkzMTOeeuopJk2aRGtrK6NHj2bGjBnk8/ma7ltfpcAsKQVmIiKS0F577cWWW27ZqRty3rx5DBo0qMuyTCbDhAkTuhxj4cKFQKgnW7hwIQsXLuTkk0/uWL9q1SqOPfZYjjvuOG655RbGjx/Pqaeeyvz582tu72mnnYaZcd1113HOOedw6623ctRRR3XZ7uqrr2bevHlcfPHFXHTRRdx///2ccMIJHHnkkXzoQx/ihhtu4NBDD+X000/n9ttv79hvxYoVbLbZZvzoRz/izjvv5Gc/+xkvvPAC++23H+vWretynk996lMcdNBB3HzzzRx77LGcffbZnHXWWTXfr75IXZlJKTATEelVs37/DM++vqpX2/DBbYfy/cN3r3m/VCrFhAkTmD9/fke259577+XUU0/lggsuYM2aNQwZMoT58+fz0Y9+lM0337zLMfbZZx8AxowZ0/F3odWrV3PxxRczadIkACZMmMBdd93Fdddd17GsWrvvvjtXXnklAFOnTmX48OEcd9xx3HPPPUyZMqVju4EDB3LLLbeQyYTw4umnn+a8887j7LPP7sgETpw4kZtuuok5c+Zw6KGHAvCBD3yA888/v+M4uVyO/fbbjx122IE77riDI488slN7/vmf/5nTTz8dgI9//OOsWrWKc889l9NOO40tt9yypvvW1yhjllRcY6bifxERSWDy5MksXLiQdevW8fjjj/P222/z7W9/m4EDB3L//fcDMH/+/JqDqFhra2unfQcOHMguu+zCK6+8UvOxPvvZz3a6/ZnPfIZUKtWRtYsdfPDBHUEZwK677grAIYcc0rEsk8mw00478eqrr3ba95JLLmHPPfdkyJAhZDIZdthhBwD++te/d
tueY445hjVr1vD000932ba/UcYsqY6MWa532yEi8h6VJFPVl0yaNIn169fz4IMPsnjxYvbcc09GjRrF/vvvz/z589lhhx1YunQpkydPTnT8YcOGdVk2cODAkl2D3Rk1alSn2wMGDGDYsGFdasCKzzlgwICyywvbceGFF/K1r32Nf/u3f+NnP/sZw4YNI5/Ps88++5Rsb3F74tvF7emPFJgl1RGYKWMmIiK122OPPRgxYgTz5s1j8eLFHQHY5MmTuf7669l+++0ZMGAA++23Xy+3FN56661Otzds2MDKlSsZM2ZMQ47/m9/8hilTpnDuued2LPvb3/5WsT3vf//7u7SvUe3pTerKTEo1ZiIiUgczY+LEicydO5f777+/U2C2ePFibrrpJvbee29aW1vLHmPAgAGsXbu2x9t6/fXXd7o9Z84c8vk8++67b0OO39bWRktLS6dlcU1bNe35zW9+w5AhQ9hjjz0a0p7epIxZUprHTERE6jRp0iS+8pWvkE6nOeCAA4AwYnPzzTfvGBhQyQc/+EFuu+02pk6dyrBhw9h2223ZdtttG97OZ555hpNOOoljjjmG559/njPPPJOJEyd2Kvyvx9SpU/nJT37CD3/4Q/bee2/mzZvH7373u7Lb//KXvySfzzN+/HjuuusuLr/8cmbOnMkWW2zRkPb0JmXMkuqY+V+BmYiIJBMX548bN46hQ4cCkE6nOfDAAzutL+eiiy5i8ODBHH744YwfP57LLrusR9p5/vnn4+4cffTRnHHGGUybNo05c+Y07PgzZszglFNO4bzzzuPII4/kySef5K677iq7/S233MLcuXM54ogjuOaaa/jud7/L9773vYa1pzeZu/d2G+o2btw4X7RoUXNP+sJcuPYo+Mej4KhfNffcIiLvMc899xy77bZbbzfjPWfBggVMmjSJuXPnctBBB/V2c5g5cyazZs2ivb290+jPvqa756uZPeru40qtU8YsKRX/i4iISIP13XCzr+uoMdN0GSIi0v9ks5VLcdLpdJNaIoWUMUtKozJFRKQfa2lpqfjz61//mokTJ+LufaIbE0JXprv36W7Mem2696yndRT/qytTRET6nz//+c8V17/vfe9rUkukkAKzpJQxExGRfmzcuJK159LL1JWZlOYxExERkQZTYJaUMmYiIiLSYArMklJgJiIiIg2mwCwpFf+LiIhIgykwS0rzmImIiEiDKTBLSl2ZIiIi0tE8F/4AACAASURBVGAKzJLSJZlERKSJZs+ejZnx0ksvdSybOXMm8+bN67Lt9OnT2W677ZrYumDBggXMnDmTfD7f9HNvKhSYJaXpMkREpIkOO+wwFi5cyOjRozuWzZo1q2Rg1lsWLFjArFmzFJjVQRPMJtVR/K/ATEREet7IkSMZOXJkbzdDepgyZkmpxkxERBJ69NFHMTMeeOCBjmUXXnghZsZ3v/vdjmUvvPACZsZtt93WpSvTzAA455xzMDPMjJkzZ3Y6z+LFiznggANobW1l55135tJLL+3SlkceeYSDDjqIIUOGMHjwYKZMmcIjjzzSaZuJEycyceLELvuOHTuW6dOnA6FbddasWUC4DmfcpmqZGWeeeSbnnHMO2223HYMGDWLChAk8/vjjXdqy//77c+edd/LhD3+YQYMGsddee/Hwww+TzWY544wzGD16NMOHD2f69Om8++67nfb//ve/z0c+8hGGDh3KiBEjmDx5Mg899FCnbRYsWICZccMNNzB9+nSGDRvG0KFD+fznP8/y5curvk9JKDBLSoGZiIgktNdee7Hlllt26oacN28egwYN6rIsk8kwYcKELsdYuHAhEOrJFi5cyMKFCzn55JM71q9atYpjjz2W4447jltuuYXx48dz6qmnMn/+/I5tnnzySQ488EBWrlzJ7Nmzueqqq1i1ahUHHnggTzzxRE336eSTT+aLX/wiAA888EBHm2px1VVXcfvtt3PRRRcxe/Zs3nrrLaZMmcKKFSs6bffiiy/yr
W99i9NPP505c+awfv16jjjiCE499VTeeOMNZs+ezYwZM7j22ms7gsXYkiVL+MY3vsEtt9zC7Nmz2XrrrZkwYQJPPfVUl/acdtppmBnXXXcd55xzDrfeeitHHXVUTfepVurKTEo1ZiIiveuO0+HNrh+mTbXNHvCJH9e8WyqVYsKECcyfP58ZM2aQz+e59957OfXUU7ngggtYs2YNQ4YMYf78+Xz0ox9l880373KMffbZB4AxY8Z0/F1o9erVXHzxxUyaNAmACRMmcNddd3Hdddd1LDvrrLMYOHAg99xzD1tuuSUABx98MGPHjmXWrFnceOONVd+n7bbbrmPAwcc+9jEymdpDjLVr1/LHP/6RwYMHdxxn55135rzzzuPss8/u2G758uU8+OCDvP/97wcgn8/zyU9+kr/97W/cfffdABxyyCHcd999zJkzh5/+9Kcd+15++eUdf+dyOaZOncruu+/O5Zdfzvnnn9+pPbvvvjtXXnklAFOnTmX48OEcd9xx3HPPPUyZMqXm+1cNZcySUsZMRETqMHnyZBYuXMi6det4/PHHefvtt/n2t7/NwIEDuf/++wGYP39+RxBVq9bW1k77Dhw4kF122YVXXnmlY9l9993HtGnTOoIygKFDh3LEEUdw7733JrxnyR166KEdQRmErtJ99tmnS+Ztl1126QjKAHbddVcgBGOFdt11V1577TXcvWPZ3XffzaRJk9hqq63IZDK0tLTw/PPP89e//rVLez772c92uv2Zz3yGVCpVcyawFsqYJZTPtZMq+C0iIk2WIFPVl0yaNIn169fz4IMPsnjxYvbcc09GjRrF/vvvz/z589lhhx1YunQpkydPTnT8YcOGdVk2cOBA1q1b13F7xYoVnUZ5xrbZZhtWrlyZ6Lz1GDVqVMllzzzzTKdlxfdtwIABZZdns1lyuRyZTIbHHnuMQw89lEMOOYRf/epXjB49mnQ6zcknn9zpcSnXngEDBjBs2DCWLFmS6P5VQ4FZQu3t7QwETBkzERFJYI899mDEiBHMmzePxYsXdwRgkydP5vrrr2f77bdnwIAB7Lfffj3WhuHDh/Pmm292Wf7mm292CnI222wzVq1a1WW74tqver311lsll40ZM6Yhx7/hhhvIZDLceOONtLS0dCxfuXJlp6xhufZs2LCBlStXNqw9pSjZk5DnNgBgOGi+FhERqZGZMXHiRObOncv999/fKTBbvHgxN910E3vvvTetra1ljzFgwADWrl2buA0HHnggt99+O6tXr+5Ytnr1an7/+993GoW544478vzzz7Nhw4aOZffdd1+n/SBk5IDEbbr99ts7jaJ86aWXeOihh9h3330THa9YW1sb6XS602jRefPmdereLXT99dd3uj1nzhzy+XzD2lNKUwMzM9vMzB4xsyfM7Bkzm1Vim4Fm9lsze9HMHjazsc1sY7W8cP4yzf4vIiIJTJo0iUceeYS2tjYOOOAAIIzY3HzzzZk/f3633Zgf/OAHue2225g7dy6LFi3i9ddfr+n83/ve92hra2PKlCnccMMN3HjjjRx00EG0tbUxY8aMju2OOeYYli9fzhe+8AXuvvtufvnLX3LKKaewxRZbdGkPwLnnnsvDDz/MokWLamrPoEGD+PjHP87NN9/Mb3/7W6ZOncrQoUP5xje+UdNxypk6dSpr1qxh+vTp3HPPPVxyySUcd9xxZTNgzzzzDCeddBJ33XUXF154IaeeeioTJ07sscJ/aH7GbD0w2d33BD4MTDWz4qEkXwRWuvtOwHnAT5rcxqp4riAYU3emiIgkEBfnjxs3jqFDhwKQTqc58MADO60v56KLLmLw4MEcfvjhjB8/nssuu6ym83/oQx9iwYIFDB06lBNPPJHjjz+eIUOGcO+997Lnnnt2auell17Kww8/zOGHH86VV17JNddc06X7b9q0afzLv/wLF198Mfvuuy/jx4+vqT0nnHAChx12GF/96lc58cQTGTlyJPfccw/Dhw+v6TjlH
HLIIVxwwQX86U9/Ytq0aVxxxRVcddVV7LTTTiW3P//883F3jj76aM444wymTZvGnDlzGtKWcqxwpEIzmVkr8ABwqrs/XLD8LmCmuy80swzwJjDSKzR03LhxXmtUXq93rziSwa9Ec838x8swqGvftIiINMZzzz3Hbrvt1tvNkB4UTzD7gx/8oLebwoIFC5g0aRJz587loIMOqnn/7p6vZvaou48rta7pNWZmljazx4GlwNzCoCwyBngVwN2zwDvAViWO8yUzW2Rmi5YtW9bTze6qMEuWzzX//CIiIrLJaXpg5u45d/8wsB2wt5n9Y8LjXObu49x9XG9cO6xzjZm6MkVERErJ5XJks9myP7rgeWe9Nl2Gu79tZvOBqcDTBauWANsDr0VdmVsAPXthqiQKC/5V/C8iIlLSlClTKk5We+KJJzJ79mx6q7SqlIkTJ/Zae5oamJnZSKA9CsoGAQfTtbj/VuBEYCFwFDCvUn1Zr8krYyYiItKdX/ziF12m1Sg0YsSIJram72t2xmw08GszSxO6Ua939z+Y2VnAIne/FfgVcLWZvQisAI5pchurUzgqM6fATEREpJQPfOADvd2EfqWpgZm7PwnsVWL5jIK/1wGfaWa7EvGCgn9lzERERKQBNPN/QpZrJ+/RzMEKzEREelxfrGoRKVbv81SBWVKeYx3hoqkq/hcR6VktLS11XXpIpFnWrl3b6TqctVJglpDl2llH9MBrHjMRkR619dZbs2TJEtra2pQ5kz7J3Wlra2PJkiVsvfXWiY/Ta9Nl9HuFGbOcMmYiIj0pvlzR66+/Tnu73nOlb2ppaWHUqFEdz9ckFJgllMq3s84HgKEaMxGRJhg6dGhdH3gi/YG6MhOyfJb1HTVmCsxERESkfgrMErJOxf8KzERERKR+CswSMo+6MkGBmYiIiDSEArOEUvncxlGZKv4XERGRBlBgllDKs+rKFBERkYZSYJZEPo/hBYGZ5jETERGR+ikwSyKa6X+9t3S6LSIiIlIPBWZJRF2X6soUERGRRlJglkRU7K+Z/0VERKSRFJglEdWUrVeNmYiIiDSQArMkopoyzWMmIiIijaTALImOGjMV/4uIiEjjKDBLorjGTBkzERERaQAFZklENWUbuzJVYyYiIiL1U2CWRNR1uSHqysxnN/Rma0RERGQTocAsiajrMkuaDZ4mn1NXpoiIiNQv09sN6JeiGrN20uRIax4zERERaQgFZklENWU50mRJY1kFZiIiIlI/dWUmkd+YMcuSxjUqU0RERBpAgVkSUSCW8zRZUrgyZiIiItIACsySyBVmzDIq/hcREZGGUGCWREeNWYocKVzF/yIiItIACsySiGrMsmTIehrXJZlERESkARSYJdExj1mKLGlQV6aIiIg0gAKzJKKuS0u1hFGZCsxERESkARSYJRFlzFItLZouQ0RERBpGgVkSUSCWybSQJdVRcyYiIiJSDwVmSURdmS0tA6JLMiljJiIiIvVTYJZEnDFrGRCK/9WVKSIiIg2gwCyJwsDM0+rKFBERkYZQYJZEpxqzdMeEsyIiIiL1aGpgZmbbm9l8M3vWzJ4xs6+X2Gaimb1jZo9HPzOa2caqRDVmAwaErkxTxkxEREQaINPk82WBb7r7Y2a2OfComc1192eLtrvf3ac1uW3VizJkAwYMJEdKGTMRERFpiKZmzNz9DXd/LPp7NfAcMKaZbWgEz20AYGBLC+2kMRX/i4iISAP0Wo2ZmY0F9gIeLrF6XzN7wszuMLPdy+z/JTNbZGaLli1b1oMt7Sqfy9LuaQYOyJAjjbkCMxEREalfrwRmZjYEuAE4zd1XFa1+DNjR3fcELgRuLnUMd7/M3ce5+7iRI0f2bIOL5HPt5EixWUuKdk2XISIiIg3S9MDMzFoIQdm17n5j8Xp3X+Xua6K/bwdazGxEk5tZUT7XTjsZBrWkyXmalAIzERERaYBmj8o04FfAc+7+n2W22SbaDjPbm9DG5c1rZfc8GzJmg1rSZ
ElhruJ/ERERqV+zR2XuBxwPPGVmj0fLzgB2AHD3S4GjgFPNLAusBY5xd29yOyvK57O0k2bQgDRZMqoxExERkYZoamDm7g8A1s02FwEXNadFyYSMWZqBUcZMXZkiIiLSCJr5PwHPZcmSZrNMinYy6soUERGRhlBgloDn2mn3NC3pFHlLk1JXpoiIiDSAArMk8llypMmkDbc0KWXMREREpAEUmCXguXbaSZNJpchbhhR5yOd7u1kiIiLSzykwS8CjjFlL2vBUOizUAAARERGpkwKzJHLtZEmTSYeMGaDATEREROqmwCwBz4dRmS0pwzsCs/bebZSIiIj0ewrMErBouox0qrArUwMAREREpD4KzJLIt5P1FJl0Ck+1RMvUlSkiIiL1UWCWRD5Llkwo/rcoY5ZTV6aIiIjUR4FZEvksWVJkUqmCGjNlzERERKQ+CswSsIKMGWkFZiIiItIYCswSsChjlk4ZpBSYiYiISGMoMEsini4jnQIV/4uIiEiDKDBLwLzztTIBFf+LiIhI3RSYJWD5LO0erpVJOs6YaR4zERERqY8CswRSvvFamabifxEREWkQBWYJWD5HezTz/8bif3VlioiISH0UmCWwMWOWwtIq/hcREZHGUGCWgHmWdtJkCjNmOQVmIiIiUh8FZgmEjFk0j5lqzERERKRBFJjVyp2058hbBjMjpXnMREREpEEUmNUqmhYjH81ftnFUpor/RUREpD4KzGoVZcbcQqYsldE8ZiIiItIYCsxqFWXG8qkoY9ZR/K+MmYiIiNRHgVmtOjJmISCzjGrMREREpDEUmNUqmhbDo0xZSvOYiYiISIMoMKtVvjgwC79dXZkiIiJSp6oDMzPby8xuNLO/m1nWzD4SLf+hmU3tuSb2MVGNWdyVmc4MACCnCWZFRESkTlUFZma2P7AQ2BX4n6L98sCXG9+0PqooYxZPl5HPbui1JomIiMimodqM2Y+Bu4DdgX8rWvcY8JFGNqpPK1NjllfGTEREROqUqXK7jwCfdnc3My9a93dgZGOb1YdFGbN4mox0Jg7MVGMmIiIi9ak2Y7YOaC2zbjTwTmOa0w/ENWbRpZjSmaj4P6vATEREROpTbWD2AHCaWXQdoiDOnH0RmNfQVvVlHV2Z4aFoSadp97S6MkVERKRu1XZlfg/4E/AE8DtCUHaimf0n8FFgfM80rw+KuzKj2rJ0ysiRIp9T8b+IiIjUp6qMmbs/AUwA3gLOBAz4arT6QHf/a880rw+KL1Ye1Zhl0kY7GVwZMxEREalT1fOYuftj7j4F2BzYDhjq7pPcfXG1xzCz7c1svpk9a2bPmNnXS2xjZnaBmb1oZk/G86X1GUUZs5Z0ihwpBWYiIiJSt2q7Mju4+zrg9YTnywLfdPfHzGxz4FEzm+vuzxZs8wlg5+jnY8Al0e++IQ7A4oxZysiSxjQqU0REROpUdWBmZrsBRwHbA5sVrXZ3P7G7Y7j7G8Ab0d+rzew5YAxQGJh9ErjK3R14yMy2NLPR0b69L74mZpQxy6RDYNaijJmIiIjUqarAzMxOAK4gFP0vBYor3YvnNqvmmGOBvYCHi1aNAV4tuP1atKxTYGZmXwK+BLDDDjvUevrkohqzVDQqM5NKkSVNRhkzERERqVMtozJvAb7o7m/Xe1IzGwLcAJzm7quSHMPdLwMuAxg3blzNgWFiHRmzcI3MTNrIehrPK2MmIiIi9ak2MNsG+HKDgrIWQlB2rbvfWGKTJYTu0th20bK+IeqyTEXXyIyL/1FgJiIiInWqdlTmn4Dd6j2ZmRnwK+A5d//PMpvdCpwQjc7cB3inz9SXQZdRmXHxv6srU0REROpUbcbsq8CNZrYc+COwsngDd89XcZz9gOOBp8zs8WjZGcAO0TEuBW4HDgVeBNqAk6psY3NENWaWiUdlhhoz8rnebJWIiIhsAqoNzF4DFgPXlFnv1RzL3R8gTE5baRsHvlJlu5ovypilUhtHZW4gDcqYiYiISJ2qDcx+CRwN3
Az8ha6jMt87ohozy8QTzBptpMFVYyYiIiL1qTYw+yTwLXc/vycb0x94vh0D0h01ZqEr01T8LyIiInWqtvj/XTpPAvuelc9G85hlNnZl5lyjMkVERKR+1QZmVwLH9mRD+ot8UVemMmYiIiLSKNV2Zb4MfM7M5gJ3UnpU5hWNbFhflc+G8rp0NI9ZfEmmeLSmiIiISFLVBmaXRL93BKaUWO+ESzZt8jyXJe/WUWPW0pExa+vllomIiEh/V21g9r4ebUU/ks+1kyVFSzrM+hFnzMw1j5mIiIjUp6rAzN1f7umG9BchMMuQSYfyvEzayJFSjZmIiIjUrdrif4l4rp0saTKpKGOWStFOBtM8ZiIiIlKnshkzM/s/4Eh3f8LM/kaoIyvH3f0fGt66Piify5IjRSbqykynQsYspUsyiYiISJ0qdWXeC6wq+LtSYPbeEXdlpjYmG92UMRMREZH6lQ3M3P2kgr+nN6U1/UA+27n4HyBnaVIq/hcREZE6VVVjZmYzzGzbMutGm9mMxjarD8tnyXq6U8Ysb2lSypiJiIhInaot/v8+sF2ZddtG698TOor/CzJmeWtRYCYiIiJ1qzYwswrrhgHrG9CWfsHzWbKkaUkX1pipK1NERETqV2lU5kRgcsGiU8xsWtFmg4DDgGca37Q+KtdOjjTp1MZY1S2tUZkiIiJSt0qjMg8Evhv97cBJJbbZADwLfK3B7eq78jnaSXcq/s+nMqRzOXAHq5RcFBERESmvbFemu89y95S7pwhdmfvEtwt+NnP3j7j7wuY1uZflQ8aseLqMsE51ZiIiIpJctZdk0hUCYrks7aQZXJAx81RBYBZd3FxERESkVgq4apVvJ+ddi//DOmXMREREJDkFZrWKasw6Ff/HGbNcey81SkRERDYFCsxqZPl2cqRoKagxIxV1X2pkpoiIiNRBgVmNzHPhWpmdaszirkxlzERERCQ5BWY1sny4VmZhYLYxY6YaMxEREUmuqlGZMTObBOwLjAGWAAvdfX5PNKzPyufCzP+dujJV/C8iIiL1qyowM7PhwBxgImGy2ZWESzGZmc0HPuvuK3qqkX1JKrokU7pUxiynwExERESSq7Yr8wJgPHA8MMjdRxIux3RCtPz8nmle32OeJetFGbO0JpgVERGR+lXblXk48B13/594gbu3A9dG2bQf9ETj+iLzkDHrXGMWB2Yq/hcREZHkqs2Y5YAXyqz7a7T+PSHlocYsUzCPman4X0RERBqg2sDsFuDoMuuOAW5uTHP6vlQ+S97SWOHFyju6Mt8z8amIiIj0gGq7Mn8PnGdmtxEGAbwFjAI+C+wOfN3MJscbu/u8Rje0rzDPkbfOD5tp5n8RERFpgGoDs99Fv7cHPlFi/Q3RbyOM2kzX2a4+K+0hY1bIMurKFBERkfpVG5hN6tFW9COlMmYq/hcREZFGqCowc/d7e7oh/YI7aXLk42L/SCqta2WKiIhI/Wqd+X8EsA+wFfB7d19hZpsBG9w93xMN7FOirkov7srUPGYiIiLSAFWNyrTgZ8BrwK3AFcDYaPUtwJk90rq+Jiru91TneDbOmLmK/0VERKQO1U6X8R3gq8BZwMcIRf6x3wPTqjmImV1hZkvN7Oky6yea2Ttm9nj0M6PK9jVHVEPmVjowy2cVmImIiEhy1XZlngyc5e4/MrPiEZcvAv9Q5XFmAxcBV1XY5n53ryrQa7qohixflDGL5zHL5do33eGoIiIi0uOqzZiNAR4qs24DMLiag7j7fUD/vdh53FVZFJilo4xZThkzERERqUO1gdkS4B/LrNsT+FtjmgPAvmb2hJndYWa7l9vIzL5kZovMbNGyZcsaePoK4uL/4hqzTLjtORX/i4iISHLVBmZzgBlmtl/BMjezXYBvAr9pUHseA3Z09z2BC6lwqSd3v8zdx7n7uJEjRzbo9N2I5ykr6s1NZQYAkMtuaE47REREZJNUbWA2E/gLcB8bL2Y+B3gquv3jRjTG3Ve5+5ro79uBl
miKjr4hqjHzonnM4q7MvDJmIiIiUoeqAjN3XwtMBKYDDwJ3A38GvgQc7O4NSRWZ2TYWXR3czPaO2re8EcduiI4as9LzmGlUpoiIiNSj6glm3T0HXB39JGJm1xECvBFm9hrwfaAlOv6lwFHAqWaWBdYCx7i7Jz1fw8UTyKaLMmYtoSszr3nMREREpA5VBWZmlgP2dfdHSqz7KPCIu3c7U4S7f66b9RcRptPom/JlRmWq+F9EREQaoNoaM6uwLg30naxWT4qvhdlluoyQMXN1ZYqIiEgdKmbMzCzFxqAsFd0uNAj4BPD3Hmhb3xN1VVpRV2ZLJkXWU+TzCsxEREQkubKBmZl9H4gvieTAnyoc5+JGNqrPimvMijNmKSNHWl2ZIiIiUpdKGbMF0W8jBGi/IlzEvNB64FngDw1vWV+UL5MxS6doJ62LmIuIiEhdygZm7n4vcC+AmTnwS3d/vVkN65OiGrN4eoxYJmXkSCljJiIiInWpalSmu88qvG1mWwA7A2+6e3EWbdPVUWNWFJilU2RJ46oxExERkTqUHZVpZoeYWZcZ/c3sTGAp8DDwspn9j5lVPR9avxbVmFlqQKfFLWkjSxqUMRMREZE6VAqovkzRNBhmdjBwNuFSTJcDuwGnAI8C5/ZQG/uOfOmMWTplUcZMgZmIiIgkVykw24sQhBU6CVgHHOLubwJEV1A6lvdEYBbVmGW6Fv9nPU1axf8iIiJSh0oTzG4N/G/RsoOBB+KgLHIbsEujG9YnRYFXqkTxf5b0xgloRURERBKoFJitBgbHN8xsZ2Ar4KGi7VYRZv/f5OWjGrJUurjGLEWO1MZLNomIiIgkUCkw+wvwyYLbnyTUnP2xaLv3AW81uF19Ui67AYBUplSNWWbjBLQiIiIiCVSqMTsPuNHMhhMCr+mEov/iKwAcCjzRI63rY/JlLsmUSRtZUurKFBERkbqUzZi5+83AacB44ARCF+Zn3L1jpKaZbQMcBNzew+3sE/LRRcpTmaKuzFSYx8zUlSkiIiJ1qDj/mLtfAFxQYf2bwIhGN6qvykddmelM8QSzKv4XERGR+lWqMZMiFYv/XRkzERERqY8CsxrEXZnponnM0imjnTSmjJmIiIjUQYFZDTyfJespWjKdH7b4IubmGpUpIiIiySkwq4HnNpAjTTrV+WEzM3KWwTRdhoiIiNRBgVkN8rks7aRpSVnXdZbGXF2ZIiIikpwCsxp4tp0cKTLprg9bnjQpFf+LiIhIHRSY1cDzWdrJkEl3zZjlLKOMmYiIiNSl4jxm0pnn2smToiVVImNmaVIKzERERKQOypjVIsqYpUvWmGVIaVSmiIiI1EGBWQ08107OU7SU6Mr0lDJmIiIiUh8FZrXIZ8mSLln878qYiYiISJ0UmNUiFwVmJboyc5ZRxkxERETqosCsBh5lzFpKZszSpFBgJiIiIskpMKuB5dvJki5Z/O+pDGl1ZYqIiEgdFJjVIpeLMmYlAjPLkCEH7r3QMBEREdkUKDCrRZQxK1X8TyodbaPuTBEREUlGgVkNzLNkvcy1MlMt0R/qzhQREZFkFJjVwPLZstfKJBVdREGBmYiIiCSkwKwWFa6VuTEw04XMRUREJBkFZjXoyJiV6MpUjZmIiIjUS4FZDcxztJMmU+Ii5urKFBERkXo1NTAzsyvMbKmZPV1mvZnZBWb2opk9aWYfaWb7upPKZ8mVmy4jLv7PqStTREREkml2xmw2MLXC+k8AO0c/XwIuaUKbqmbeTo4MZpVqzJQxExERkWSaGpi5+33AigqbfBK4yoOHgC3NbHRzWte9lOfIW7rkOktrugwRERGpT1+rMRsDvFpw+7VoWRdm9iUzW2Rmi5YtW9aUxqXyWXJlAjPScfG/AjMRERFJpq8FZlVz98vcfZy7jxs5cmRTzmmeI2+Z0us0wayIiIjUqa8FZkuA7Qtubxct6xPSni0bmHXUmKn4X0RERBLqa4HZrcAJ0ejMfYB33P2N3m5ULFUpY9ZRY
6Z5zERERCSZMumfnmFm1wETgRFm9hrwfaAFwN0vBW4HDgVeBNqAk5rZvu6kPIunS9eYpdKa+V9ERETq09TAzN0/1816B77SpObULEX3GbN8tr3PpSFFRESkf1AMUa18jhSOp8oFZmF5VjVmIiIikpACs2pFoy29u4xZuwIzcfN9SwAAIABJREFUERERSUaBWbWiTFi5jFkqCsxyypiJiIhIQgrMqhVnzMoFZpm4xmxD05okIiIimxYFZtWKJ44tM/P/xoyZJpgVERGRZBSYVasjY9ZScnU8XUYuq65MEZG63PczuP/c3m6FSK9QYFatuHYsVSZjFnVlugIzEZH6PHtL+BF5D2rqPGb9WtyVmS6TMYtrzNSVKSJSn7aVvd0CkV6jwKxacWBWpvg/HU+XkVPxv4hIXdauAM+DO5j1dmtEmkpdmdXqCMxKZ8zSHaMylTETEUmsfR20t0F2HWxY09ut6T/u/0948e7ebkXf1L4Obv4KvLOkt1tSFQVm1eqoMas8j1le85iJiCS3tqAb892/9147+psHzoPH/6e3W9E3vfUMPH4NvDi3t1tSFQVm1eqoMSvTldkSFf8rMBMRSW7tio1/ty3vvXb0J9kNsH4VrH6rt1vSN7VFAf6qN3q3HVVSYFatKDCzcl2Z6TgwU1emiEhibQWB2bvLeq8d/UkczK5RYFZS/Dxa/XrvtqNKCsyq1U3GLKOMmYhI/QozZurKrE6cWezrgdljV8Nrjzb/vPHzaPWbzT93AgrMqhUFXKky02VkMhlybsqYiYjUozBj1qbArCpxYLZ+FWxo6922VHLnd+Ch/27+edWVuYnqLmOWMrKkyecVmImIJBYX/6cyyphVq7AWr69mzdrXwobV8ParzT93R8ZMXZmblrjGrFxglg6BGerKFBFJbu0KyAyCzbdVYFat/hCYrVkafr/9SvPPHT+P2pZDdn3zz18jBWbV6ujKHFBydSaVIkcaV8ZMRCS5tpXQOhwGb9WzXZnZDXDr13onUGi0wu7fJHVUbz0DD/+i++02tIVrmGYTTKQeF+CveTPMK9ZMhc+jflBnpsCsSvGllsplzFrSRrsyZiIi9Vm7AgYNh9YRPTsqc9lf4LFfw/N39dw5mqVtOVj0cR5npmrx6K/hjm+H7sZKXvgj3HMWvPxA7ecobNc7r9W+fz3e/Xt4PgGs7vt1ZgrMqpSLviHE18QslkmHjBn5XDObJSKyaWlbAYO2hMEj4d0enMesYwqFvv9B3a225bDF9mDpkJGqVVx71V3AFGcXk8yg/25BYPb2y7Xvn5R7CMxGfyjc7gf/bwVmVcpnQybMyo3KTBlZUurKFBGpx9oVnbsy3XvmPHFg1k9G6lXUtjwEskO2TjbJbPwYdBcwxYHZqgSB2ZqC7Gczu483vAvZtbDNHuF2P/h/KzCrUi6uMStzSaZMysh6GlNXpohIcm0FXZk9eb3MuGutn4zUq6htObRuBUNGJcyYRft0FzB1ZMwSdEW+uxQGbB5G2zYzMIvry7baGdID+8X/u3SUIV3EGbNUpkzxfzoVRmUqYyYikox7mC6jdXjIAEHohhq4eePPFXet9YNi8G61rYCtdwez2rNZ+fzGYK7awCxRxmwpbD4qfEY2sysz7g4fPBKGjlbGbFMSX5w8lSlf/K/ATESkDutXgedCxmxwVKzdU9fLXLOJdWW2Dg8Zs1q7Mt9dtvFzq1Jg5l5njdkyGLw1bLlDczNmcZf14BGw+eh+EYgrMKtSnDFLl8mYpVNGjjSmwExEJJl42ofW4RtH0fXUyMw4Y7b+nVCH1F9taIP2to1dme8ug1quQNNRDG+VA6a2FdD+bugOXLWk9tq/NUthyMjmB2ZxV2ZHYNb3uzIVmFVpY2BWuvi/JZUiS0oZMxGRpOLrZA4atjFj1lOTzBYGfP0gi1JW/Ji1bhW6CvHa5n+LA7NRu1cOmOLuxzEfDXV/696prZ0dGbMdwyS43U3NUaxtBby+uLZ9YOPzp3UED
N02ZEh7akBJgygwq1K+Y+b/0oFZKmVkyShjJiKSVFt0OaZOXZk9FJitWRaCBIBVfT+LUlbc1du6FQzZJvxdS6AZ3/ft964cMMVB2477ht+1DADIboB1b4dRo/FjXusAgj+dD1dMrX1y2neXhStJDBgMm28TRmjWGlQ2mQKzKnkuzGOWKZMxA8hbGnPNYyYiksjagq7MAYOhpbVnMmb5fPjA7pjbqh9nzDoFZqPC37Vclmn1G2Fy2jHjwu1yAVMcmO3w/8LvWgYAdNR5RV2ZUPsAgJV/C6N033yqtv3alocg3yx0ZUKfn8tMgVmVPOqzT5eZ+R9QjZmISD3iGrNBw8Pv1hE9E5itXRkGGWyzZ7jdD+qOymor7sqk9sBsyCgY/v5wu1zA9PYrsNkWsPVu4XYtGa+4nm/I1gWBWY11ZvH5Xn+stv3eXRYeGwhdmdDnM6SaLqNK+Ww77Z6mJZMuu03O0pgrMBMRSWTtSsDCzP/Qc9fLjAOF4e+DAUP698jM/9/emcfJUZV7//tUdc++ZCcbEJZAjAiIYXsRUHxVVvEqIly8ioheF67L9aq4XperXPfX7bqACigqXlQWiSAKCIgigbCEHUIgCQkzWSaZfbq7zvvHqdNd01PdXT3pmekhz/fz6U91V1dXna6l+1e/85zniTpmjW32eTUjM3dusl18lQRTz7N2mfb5tsJANY6ZGwHbOs9+3kuPQ5iF29t4T3Wf699SSL2Sd8zq2yFVxywhJsiSxSflS8llAlSYKYqijJvBbdaV8cIb4Na5EzMq062zbV44Um+6C7NQzKYa7cCJapLM9m6C9oWVBVPPMzY+zPOt81RNyoy8YzbXfr5zcXXCLDtScAE3VumYua5MsN8R6t4hVWGWEJPNkMXH90oLs5yk8LQrU1EUZXwMhOWYHC1zJqZepsv67xyc6S7MmmcWxGzbHtU5Qr2bbOLVcoLJ5TBzgfsdi6p0zNz+Dp2ralNm9G4CDMxcAlufgMGeZJ8zZnRXZrrZ7qs6d0hVmCUldMzSfuldZjT4X1EUZfwMbrN/nI6JqpcZdcw6Fk5/YeaEB4RlmbpKLx8lM2i7j52TVEowDWy1udJcd2fnoipjzLoh3WoHdJTbTimcCHzRaXa66b5knxvptwMGnCAE6w7W+fFWYZYQk8uQxSNVyTFTYaYoijI+XJ1MR+vciamX2ddlazY2zShkg6/z3FYlKRZm7fOTd2U6gdIeBsWXEkxuQIATZh2LbAB9ECTbjksu66g2l5nrNn3R6+w0aZxZNLmsYxo4pCrMEmJjzFJlHbNAfDyNMVMURRkfgzFdmVD7kZn9XXbdnmeFWW5k4ko/TTT9W0cLj7Z5Nvg/idB0XZ4dYVB8KcHkxFreMVsMueHkAzP6u2y3scOtJ6nrtmO9ne7xYjt6NGmcmTtvoo7ZNKiXqcIsKbkMObyywf9GHTNFUZTxM7C9yDGboHqZfd0FB6djeuS2Komrk+lom29FU5Ikqi5thButWEow5YXZnnbasSh+uVL0dVvB6Kg2l9nOjdbdbGi1lQeqFWYtUcdsoRWK1ZStmmQmXZiJyIki8piIPCkiF8a8f66IdIvIfeHj/MluYyxBjozxy3ZlBl5KHTNFUZTxkMvASO9okdE6QfUyXXkgKHTj1bmLEosx8V2ZkCyXWb4rs0iYFQumnmetMGrqtK87Q2GWdABAf/do18ptZ3tCYbZjo3XpwAqz3ueSHa98Ytui/WOCwkjROmRShZmI+MD3gJOA5cDZIrI8ZtErjTGHho9LJrONJQky5PBJeeWD/9UxUxRFGQeDrhxTJPh/wroyIw5OPoXCNBRmw70QZIqC/8PvlWRkZu9mW13BCa5SucxcDjNHRyiSkqTMyGWteIw6ZtXmMtu5oeDSLTzMTpMkmh2I68qsfyE+2Y7ZEcCTxpi1xpgR4FfA6ZPchvGRy5KpkMfMeCl8VJgpiqJUzUCkHJNjIuplGmOD0fNJR6exMIsml3W4eplJRmbuf
M66ZRL+r5USTMXCrHUO+I1WMCVqoxktjqrNZbZjY8GlW3CwTXCbpDuzf0uhTqYjn2S2fnOZTbYwWwSsj7zeEM4r5o0i8oCIXCUie8atSETeJSKrRGRVd/cEJCAsJnTMyqfLSOFrV6aiKEr1uDqZUcdsIuplDu+0MVjOwfHTVjTUeZmeWKLlmBz5skxJHLNNBaECVjDN2HO0YCrOYQZWyHUuSuaYRcsxRZm5dzJhNjJgzw3XlZluhj2WJxuZGc3673COWR1n/6/H4P/rgCXGmIOBm4DL4hYyxvzIGLPCGLNi7ty5cYvUFHGZ/8vEmBlPuzIVRVHGRXGdTEet62X2RQpqO1zKjOlGnGPW2GFdokRdmZsKgx8cxSkzinOYOZImmY0m8y23nVI4wey6T8HGmT13b+WRpwNbRseXQTgaN1XXQnyyhdlGIOqALQ7n5THGbDXGDIcvLwFeNkltK0++JFMFx0y7MhVFUapnMKYrE2y3WS27MvuLstBDmGS2fv+oS5IXZpF9JmLdqUpdmcaEdTIrCLPiHGaOzsXJRmX2xwhht77+rsq5zFyqjM5I59rCw+yo021rK2x7y+gRmWBTpLTVdy6zyRZmdwNLRWQfEWkAzgKujS4gItGz5HXAI5PYvpIkc8xUmCmKooyLUo5Z65zajsqMZv13tM+v62DwksQ5ZpAsyezgdtulGyfMornMinOYOToWWXFTKe1EX6RO5qjthF2jPespi3PlOiLCbFHo11SKM4vryoQwl1n9CvFJFWbGmCxwAXAjVnD92hjzkIh8XkTClL68X0QeEpH7gfcD505mG0shJkfWlM9jhpfGwyTPhqwoiqJYBreD3zA6UBtqXy8zrmutfaF15bIjtdvOZDCw1XbLNXaMnt+2h00yWw7nGI3pygwFk3PDXEqLGUXh3p2LbNqJSgKwv8sOFChuY6kRoMW4ODYXGwYwd5ntri0XZ2ZMfFcmhNn/67frOjXZGzTGrARWFs37TOT5x4GPT3a7KiFBhixp0uXSZXjh7gwy4DVOUssURVFeAAyG5Zik6ObXOWbGjH1vPPR3A1Ii99fmsc5QPeNymBXvl7Y94Om/lP/szqJyTI5oLrM5S8fmMHNEU2Z0LqYkLrlscRuTJpnducGK6FTkP9VPwcJDy6fMGOkbWyfT0b4Q1lbYP1NIPQb/1yVicmTx8cp0ZYrn2yeBjsxUFEWpioGiAuaO1jm2y61W9TL7uqyY8SO+xDQYqRdLcXJZR/seNgYrM1T6sy6mzolSR7GTVZwqw5FPMlshzqy/K14ctSXMZVZK+C08DDbdbxMTx243Juu/o2OBHZ07XOMarDVChVlCJMgSiF92mbxjVupEURRFUeIZ3D428B9qn2S2v3ts6gYXZ1XHcUexDGyLF2ZtCbL/7yzK+h/9bFQwlRRmCZPMFpdjcnje2NQcse3cODrw37HoMOuIdT0c/zkXf9caI8zyuczqU4irMEuIZ7LkpELPr5+200AHACiKolRFSccsdFtqJcz6usb+WbdP03qZxXUyHUnKMvVusqI31TB6flQwxeUwczR1QkN75ZQZpRwzqJwywxgb69YR45hVGgCQHw1aTpjVpxBXYZYQz1R2zLQrU1EUZZwMbosXGS54u1YpM6J1Mh0ts+zAg2kpzOIcs/D7VRJmxW6Zwwmm/i2QHSwdd9e5qHzKjCCw64hzzKLbKcXQDtuFHeeYzVxiYxJLDQAo25VZ32WZVJglRIIcQQXHzHjOMdOuTEVRlMQYEzpmccKsxo5ZXFemyPRLmREEoZgt05VZrqtu53NjR2Q6nGAqlSrD0VFBmA1uB5MbK4Sj2+nvstn9Y9sYkyrDIWK7M59bHf/ZfJ3MOMesvstwqTBLSBLHzMuPylTHTFEUJTEj/WEx7nIxZjXIZTYyYB2YUiP16vSPOpahHpuuIk6Ytc4B8So4ZpvLO2Z9z8OWxwqv4+iskP2/v0QOs/x2XGqOErnMXPxaqVGfCw+zMWYj/THb3
mLLeRWnXwFobLfdsHV6vFWYJcQzlR2zfIxZpYR7iqIoSoHBEsllARpa7B/sQA1ymZWq2wjWParTP+pY4upkOjzfis9SwiyXsUK3pDALBdMzfw1fl3LMFtv1ZIfj3y9Vjim/nQq5zNyIz1LCbK8jrThd/4+x78Vl/Y9Sx0lmVZglxDNZTKUYs3D4tdGuTEVRlOQMxBQwj9Jao3qZ+TqZMUKhfYHtyqxUf7FeiCvHFKVcktnezYAp35UJsO4Oe0yaOuKXy6fMKOGalSrHlN+Oy/5fIpfZjo02gW7bHvHv73kUiG/bGbftuG5MR3v9CnEVZgnxqOyYOWGWzaowUxRFSUypOpmOlhqVZao0Ui/TD8O9u76dyaBUOSZHubJMLvasOLmswwmz7evKJ9x1sV+lUmb0lXEowQouv6GMY7bRHhevhCnS2GbjzNbdPva9gS0JhJmmy5jW+CZbyFNWijD4P5dRYaYoipKYUnUyHbUqZF62K9Mlma1PF2UMlYRZWcesRHLZ/Gfn5//PygqzzrBMU0nHrMs6Xk0z4t/3PLuO7aUcsw3xgf9RlrzcjswsjjPr31q5K7N3U12WUFRhlhDP5AgqCDMvdMxymmBWURQlOYPb7bSUY9Y6tzb1MvvKdK3V+Ui9MSQRZv3d8Xk13ejTjhKOmctlBvE5zBzu86VGZvZ1231dppQhCw6BZ+6MF0g7NsSnyoiy5OV2wN36uwrzjEnQlbnQfq5WaVhqiAqzJBhDihymYlemc8ymWSFcRVGUqaRSjFnL7EK9zF2hv8smRk3F1DLOZ/+fRsIs1WQHRsTRPt+mqogbNNG7yTpipUQdFJyyco5ZQ4t1OUsJs3LJZR3LTrHLbVw1er4xYUqPCsIsLs5spM+W8SorzOpXiKswS0J4x1GpK9NzwkwdM0VRlOQMbofGjsLI9mJqVS+zr6v0CME6zwY/BleOqVRhdxcwHxdH5ZLLlisKn0SYQfmUGX1dpePLHEtfbUXiI9eNnt+/xR7zcgXSIRJnFhFmbqBIOVFYx0lmVZglIRxlmXRUZi6r6TIURVESM1iiHJOjVklm45LLOhparJtWpwHhYyhVjsnhhJkLwI9SLrmsI6kw61hcOvg/rspCMU2dsM9x8OjvRzuilVJlRCmOMyuX9d9Rx0JchVkSXMJYr8TdXIgX1hwLtCtTURQlOaXqZDpqVci8v7u8i9K+sG5zW41hYEv5rsh2J8ziHLMyyWUdS18Ly06F2fuXX65zUUFERXFxXqWSy0ZZdgpsWwvdjxXm7SiT9b+Y4jizfNb/cvtnvj2vnvhT5fVPMirMkuC6JksN2Q3xUtYxC7QrU1EUJTml6mQ6XKzQrgZq91WIeWqfX5cxR7GUqpPpqNSVWSrw37HgYDjrivh4vCgdi2xNy+GibuahHZAbqeyYARx4sp0++vvCvB1VOGbFcWaV8qeB/T8/7K3w+B/K1+ucAlSYJSEfY1bBMfNDx0yFmaIoSnJK1cl0tNbAMcuO2DJG5WKeOhZOs67MMsIs3QyNnWO7Mod22li9UqkyqqVUygwnjirFmIHtVl20Ah69vjBv5wY7uKHcd3QUx5kl6coEWHGena76SeVtTCIqzJLgMvn7FYL/Uy7GTIWZoihKYio5ZrWol5nERWmfb4VZNMVEEMA9l5Uulj0V5DLWkaokWtr3sF2E0ditSsllq8Wlsyiud5kvx5SgKxNsd+Zz9xa6MHdstEK53ACFKEuOLcSZDWyFdKuNGyzHjD2tW3fv5ZAZSradSUCFWRLCGLPKjpnrytTgf0VRlEQEOSsyyjlmDS32j3ZX6mWWSy7raF9gU0w4x2VoJ1x5Dlz3frj23+qnXFM+71sFYbbfCfDkTXDNBYV6li7YvVLwf1LmHAipZlj109Hz+6sVZqfa6WMr7XTnxmTxZY5onFl/d/n4siiHn2/Pq4evTr6tCUaFWRLCrkmpEGPmu+B/dcwURVGSMdhjp+UcM7B/tLvSlVmuTqYjn/3/OdjyJ
FzyKnj8Rlj6Gtj8YHyx7KmgUp1Mx2svguM/Bvf9HH56sk0N4dJDVAr+T0rrbHjFx2x8WLQrsq+KrkyAuQfA7KWFOLMdGwvdpEnY80hbZWDdHZULmEfZ9xV2u/+4OPm2JhgVZklIOCrTd0XMNcZMURQlGYMVkss6drVeZj7mqUJXJsDqK+DiV1oB9NZr4Iyf2jxrd9fJn3demFUQH54Hr/wEnPkz6HoEfvQKePwG+16thBnA0RfAvOWw8iOFWqP9XSBeshgxx7JTQmG11Q5QqJT1P0pjGywM48wGtiR36kSsa7ZxVd10V6swS0IozKRCjJmftqNXtCtTURQlIZXqZDpa5+7aqMwkXWsu7urui2HmEnjXrbDPsfZP/9B/hoeuLjhBU0mlckzFLH8dnH+THWH58NU2d1il+Ktq8NNw2rdsqpFbvmTn9XXZ9lXoaRrFslPt/+3qy22XcjVdmVDIZ9azvnzW/2IOPdt2lf/jkuq2N0GoMEtC4jxm6pgpiqJUhXPMWio4ZrP3h+cfhu7Hx7edvu4wILy19DJt82DuMjj4LDjvxtHJVVe8ww4Eu/ey8W2/llQrzAD2eLEVmgecaLvvas2eR9hRjnf9wDpPSZLLFrPoZbaAuhNISVJlRHFxZoPbqts3TZ1w8Jmw5qrCjcIUosIsCQsO4SBzJU/OPK7sYr4G/yuKolRHUsfs5R+youq6D8QXvK5Ef1flZKeeD++7C97ww7GO0twDYJ/jbZD7VP/GJ40xK6ZlFvzzlXDm5bVvE8CrPmMdyes+YLsikySXjeJ5sOzkQsLaah0zF2cGybsyHUe8E7JDsPrn1X1uAlBhlpBMzpBOld9dfpiITx0zRVGUhOQdswoio20uvOa/4Nk7YfXPqt9OuTqZSTn8fCsanrhx19azqwxsg4b2yslfJ5vmGXDSl2HT/dY1G8/+XnZK4Xk1MWZQiDOD6roywTqKex8Dq348Ol3KFKDCLCHZwJDyy+dTSafDrsxAHTNFUZRRGANP3DS2aPTgdutyNHZUXsdL32LzVd30aeh9vrrtl6uTmZQDT7YuzlSP4KtUJ3MqWf56O4oVxre/lxxnz4XGDtvFWPXnX26n1TpmYIX39nXw5NSWaVJhlgBjDLnAkPIqOGZh8P+U29yKoij1RHbY5tK64gz44bHwzJ2F91ydzCSJREXg1G9CZhBuuLC6NlSqk5kEPwUvezusvcWm05gqKmX9n0pE4OSvWVE1d1n1n081wEFvgPkHj2/7B54MfkPlGp9xvOg0G+P20NTmNFNhloBMziYVTFdwzFK+T2BEuzIVRVEcvZvh0lNsLq2j3mv/sC87zbpOxtiuzEqpMqLMWQrHfQQe+i08/sdknwlyVszsqjADW1/RS9sur6minoUZwMy94T+egMP+ZXyfP+Ub8LZrx/fZPQ+HT2yCWftU/1k/DW9fCad/d3zbrhEqzBKQC6wwS/nld1fKF7J42pWpKMq4cL81Lxg23GNzZz3/ELzpMjjxInjnzbDfq2Dlf8C1F1jhVinwv5hjPmgzzl//4bHFs+MY2Aom2PWuTLBljpa/zuY6G+nf9fWNA1Pvwgx2Kf7t5se38LnrH+V3qzfwVHcfQbXXRYXUVmWZvV91KT4mgF1o/e5DJhwBlPIqOGaeR5YUrf3PwlO32B8CjL0rHOm3BXQHewpTjL1TbJ5lpy2z7KgjY0Z/tlwZEPFAACTsCpBwXjhF7HogXE/xuiKfw9g7S5Ozo55cG8QvrM/zC8vm21XqeXQzbhvRZczo5hR/j+j3d23Jty9n55mgsB8838aqiB9+95h2mcj3MoGdF91X0X3n1uNe5/dLzg7JLhUgmj8ORVPXjrj9n9++FNoVPQei6xrVVsa2PUr0+496zdjzJLrt4v3lHkEu5ju47+FFHtH2SOn9Ej1Ho8fU5IrOkaJzeNSUyPbcdymxuyPNzRNEtuna4M4n8cNpeD4VLxu7H90+iB4bGbv/o98D+MtjXfzq7vV85LUHs
u+cttFtjK5j1PPo6oqvwXLXOqPPM/edcpnw3M7a5yawLoLfYKde2k7z+7j42EaOefejcOMnrZB5x00w/yC7bFMnnP0ruPVLcNtX7bwDTy5xoEqQarB5s356ohVnB7y2aP94kWPoFQps18IxAzj8nbDmN3Db12Cvoxl7Tsaco2MwRcfAYExgK8yIFH5/3G9uuG837RhiRs/zrGkSDq/NtymJMQZJWquyRnT1DvGBX95H73DB4GhvSnHI4hm8ctk8zjtmyaS3abJRYZaAbNiVWUmYpX1hB60s2fxH+FkZi91LQdMMe/ENbi/kSVMUZbfl+PDBFA/4qylLjrVOWXHdQs+DEz5l44iufk91pXccex9tBdLdF8MDv0r2mZl7V7+dOPY6yrb9jm/UZn0hSeSGy9e/cn2KdavW86YV49h3Cbj4trV895Yn+eI/HcSpB9eo4HkCvnj9IwxnA/7078eTCwz3r+/h/g093PPMdr7w+4fZ0jfMx04cR+zaNEKFWQKyudAxq9iV6XHG8H9y4dGtnHbootF38OlmO5S4aYZ1xaJ30CN9NgB2cBuMDBQ5Ds5RiKPITSp2OtwdWam77eI7u/ydvl9wC6DIxQiK1ll0x1z8PNpO9zzWTYq504xz/7yi9kl4TPJuR+hmGfd5RrdrlKsTcWvy37HU84iLEnVSxhybMnfOo/ZPzP537mD+O0tMG4uWHeX+lcrtVLTdUW5ksXsYcTyKnUMvut+Kv3Ywuk2jXK8K+8Xt21EuZWQ7o84VRh/XuH0ZdUKj3z96jPJPTcFViToszh2MuqTI2GXdOor3Y+y+jWmPCIGBz173ME919/OOY/flOzc/wWuW78F7jt9v9Hcr/p6xrkGcO1n0ObcPRu1r7Gvninnhc/EKLlpuJHxkx64reizdcy8Niw8v3620/HU2RUFjW+llynHyV+HId9vEr6PaE3V5w2m62aZEqAUitlTT1qci+xpif9vG/B5G12P3//qeId5zxWoGMwEehs+cuoxj95tV+F3DkM0FfOG6NTz03A4+c9pBrH2gkSt+t4Z95rSyYkltR2iufHATX1z5CB1NKS74xWoe2LCDj772wIr/gbvKHU9s4ZpRTB6pAAAWHElEQVT7nuMDr1rK/vPsOXHg/HbOPHxPjDF88uo1fP/Wp5jZkuZdx+03oW2ZSlSYJSAT9m9XDP73hI3MZX3ngbB3whEhItDYbh+1uptTFGXacMXf1nH5hq1c9IaXcMIRe3FX5hG+/Je17PfSvXjNi+dPdfMmnmI3rRpEYM44Rt/VgpZZNUlZ0Tec5byf/5WtqcX88j1HceFvH+Dtv9/B/5yz/6jj/6XrHuayZ+fw5Te+koMP34tvvyTD6//nr/zrz+7hmguOYfHM2pRYum99Dx+68j4O22sGl513BF+98TF+dNtaHnpuB985+zBmtTbUZDvFDGVyfPqaNSyZ3cJ7XjFWdIkIXzj9IHYMZPjSykeZ0dLAmRPkFk41GvyfgLxjViFdhuvqdF2fiqIo5Vi/bYCL/vAoxy6dw1mH2z+ZD7/6QF68sIMLf/sgXTuHpriFykRijOGjV93PU919fPfsl3Lg/HYuO+8IXryok/f94l7+9LDN1fbru9fzk78+zduPWcKbD7dlojpb0lz81hWM5ALOv2wV/cO7HhKzYfsA51+2inkdjVz81hW0N6X5/OkH8dUzDubudds57Tt38OCGHbu8nTh++Je1PL2ln8+ffhBN6fjge98TvvHmQzh26Rwu/M0D/PGhzRPSlqlGhVkCsvlRmeUdMz8vzMZRLkRRlN2KIDB85Kr78UT47zcenA9obkh5fOusQxkYyfLh/72/+hFpyrThh7etZeWDm7nwpGX8n/1tpvqOpjSXn3cEL1rQwXuvuJfv/PkJPnn1gxy7dA6fPPlFoz6//7w2vvfPh/H487186Mr7dulc2TmU4bxL72Y4m+On5x7O7LbCqMo3rdiTq959NMYY3viDO7lo5SNs7x8Z97aKW
beln+/d+iSnHryA4w4oP0CjMeXzg7e8jJcsnsEFv1zN39durVk76oVJF2YicqKIPCYiT4rImAyBItIoIleG798lIksmu43FFIL/y+8uESHtS17IKYqilOLndz3D39du41OnvIhFM5pHvbf/vHY+dcpybn9iC5feuW5qGjjBbNg+wFX3bOCjV93Phb95gB/d9hR/fuR51m3p3y1ubm9/opuv3PAopxy8gHceu++o9zqb0/zsvCM5YH4bX7/pcRbPbOG7Zx8WG+N13AFz+fSpy/njw8/zT//zV/7r9w9z3f3PsX7bAKbciP4ImVzA+664l7Xd/fzgLS9j/3ntY5Y5ePEMrvu3l3PqwQv40e1rOe4rt/DtPz9B3y46dcYYPn3NGhp9j0+fujzRZ1obU1x67uHsNauFt/7kH7zn5/dwzX0b6R16YeQQndQYMxHxge8BrwY2AHeLyLXGmIcji70D2G6M2V9EzgK+DLx5MttZTCYf/F95zEzK83j8+V6uf2ATnoDnCZ4IvmeFmydi54vkE1kYAyYM7rXP7clqsAsYjI1fBQJj7HvGrc+uy/NAEDK5gJFcYKdZ+8gGtnJBLjDkjCGXM3ie4HtCKjL1PMEX+9o9wArTXGDIBoZsEGCM3Rdp3yMdTgVhYCTLwEiO/pEsA8M5hrM5mtM+rY0pWhpTtDX6NKdTBMaQyQVkc3aaCQy5XKGdbuqJHVCR8uw2Ur6Q9jzbXr/QbhCMMQQGcuH+CcJ95Papwy/6vp4II9mA4WwunBb2WWAi+y0wbB8YYUvfMFv67HRb3witjSkWzWxm8cxmFs1oZtHMZnwRegYz9Axk2DE4Qs9ABhGY29bI3PZG5rU3Mbe9kbamlP2+Obtfc4EhEz7P5AL7PGfIBUH+PHEYY/fTSLgfs+F+dPvKHZuGlIdg82MFxp4/gTEIo4+zHx57cedm0dSdtxLOM+7kjLRL3PsUPpuNfD93HiHgh+esW7c9J+x3dZ9x204VtTPte+FU8MObJXeuj+TscczkzKhrzfNsu1wsvIi9XqLtdsu6a34ok2M4G04zASnfoznt0dzg05T2890thW3bqTt33ToRGRP6nc0FfOXGxzjugLm8+fD4OJlzjtyLWx/r4r//8CgrH9zEnLZG5rQ3MKetkdltjTT6XmzbDaZilh33G1M4ivY6HxjJ0j9sr+H+4SzD2SC/34TwfPCEppRPU9qjOe3T3ODTmPLCfRxpixtjgj33jIFsEPDghh38be1WNmwfBGBGSxpfhK0RBybtCwtnNDOrtYHZrY3MaWtgdlsD7U1pUu53K/xtyAaGbX0jbO0fZmt4bQ6M5Jjb3sj8ziYWdDSxR6e95gTC69pdEyb/W5nJGUayOTI5W37PHmMv/K4+ucAwmMkxFD4GMzlyAaPOsehz3x0XT8DAUNZ91p5Tv/jHsyyd185XIm5plM6WND9/x5F89+YnOeeovelsSZc8nuf+nyXkAsPKBzdx+d+fYeSOpwGY1drA/nPbmNvRyLzwt2deeyMG61I9vbWfdVvso38kx5ff+BKO2b90jcnZbY1848xDeffx+/H1Pz7GN256nEvvXMe/Hrcve89uzV/n7nepdyjD9sjvYM9Ahsa0Z49pewNzWhvZMZjh9ie28NnTlrNHR1Ppk7aIma0N/OL8I/nOzU9yw0Ob+cOazTT4HscuncMJL5pHe1M6f9258zd/bUBZ0bpkdiuH7DkjcVtqjSRV1DXZmMjRwGeNMa8NX38cwBhzUWSZG8Nl/iYiKWAzMNeUaeiKFSvMqlWrJqzdI9mA53cOMbutgZaG8lr2uK/cwrPbBiasLdMFT2yXzFDmhXHnKwIzmtPMbW+0f5BtjcxqbaB3KMvGngE29gyyqWdolFua9oUZLQ3MaE6TM4bu3mF6hyYmNYoI+T+pSbyklV1gTlsj115wDAuL3LIo2/pH+MoNj/LM1oHwpmCY7QMT7wo0pT1aG1I0pqzwDUxB8AXGMJQJQmFS/ck2oyXNk
fvM4uh9Z3PUfrM5YF47nif0DIzwVHc/T3X3sba7n007BvNCa2v/CNv6R8pub0ZLmtmtDcxua6Slwae7d5jNO4ZGCb56wfeExTObufTtR7DPnNaarjuTC3hscy/3b+jh/vU9rNs6QHfvMF07h+gfKeRedG1YMruVfea0cuQ+szjpJQvKrHks96/v4Wt/fIzbn9hScpno72Bnc5rhbMDW8AZ3JDQ9Dl7cye/ee0zeDKiWIDDc++x2/rBmMzes2czGnsFxrcfxlqP24r9e/5JdWkclROQeY8yK2PcmWZidAZxojDk/fP0vwJHGmAsiy6wJl9kQvn4qXKbkkZ9oYVYNO4cydO0csu5NeEcWBAWnwt45FtwLezcquBtMcHf1hXnRO2KJ3hkb8s6QW3/a92hIeTSEU+eceM4lCh0xE7Yv6tQUO0S58NxIex6+X3AuBOuEFFydgMBAS0PojoV30CJCEN5l9od34gMj2dCxsm1M+VJwwiLb8EUw2Lv4TFBwhZyblskFeXctMAbfK3IjpcghCe+Wot8tm7P7riHl0ZjyaEwX9lvK8/A88g5ikoSGucDQ1TuEMfZPojntj/ncUCZHd+8w3X1WpKU9ITXKASo4Xs798sM7ccJzwZHyxe5Db3Qb3f4ZzgZ5t9eT0BELXSoTnpfZILAuarhfnNMShI6t/SOMuG3huRx1UNw5Gz0f3fOUb491KnJso+87h9M6Y/YcSEfc2lzROVnswGXD5M8Nvm/P+/BYpjzJO8zuewXOITIFl8g5086FdgIk7Xs0pe153JT2afC9UY7JYCbHYPgn1xhu11137ju6fRi4DUbORbBJM0sFOZcjkwvY3j9CJjAE7piF57Y7B6IuVymip6YV9h6tjT4tDanEf5DOWRzKBPnfNeeQBUHB9XRt8cQKUm8cf8Dut8S5/s7B90WY2dpAukQqh+Fsjq6dw3T1DgPkf188j/z1Vvx7mc2Z0OEK8g5ZyvOsgxa6hE1pn5Qn+XPM/ba73xf3G+vEZN6BS/sl2zrR9A9n6eodxhjD4pktNKRq046nuvsYyuRG/W6lfY/2Jvt/EPf7aYyhdzjLlt5h5nU00dZYmw48YwzPbhvI//YVHDIi/6kFdz+O9iZ7Ez6RlBNm0zZdhoi8C3gXwF577TXFrSnQ0ZSmo6m05Vx/TGzpCc8TWhtTtDamYGzYQkXSPjRPcBtrhe8JCzpLux9gf5z3nNXCnrNqM7S9VDt8zx/Xn75SnuaGqd+nad9jXhVdPhNJ4Q944rflfkuqpTE18dfcdKG1McU+NRJAUfabW30eOhGZkP9LEWHv2bV1ISebyZbtG4FoQMXicF7sMmFXZicwZtiFMeZHxpgVxpgVc+fWqMyGoiiKoijKFDLZwuxuYKmI7CMiDcBZQHEJ+WuBt4XPzwBuLhdfpiiKoiiK8kJhUrsyjTFZEbkAWw3OB35ijHlIRD4PrDLGXAv8GPiZiDwJbMOKN0VRFEVRlBc8kx5jZoxZCawsmveZyPMh4E2T3S5FURRFUZSpRjP/K4qiKIqi1AkqzBRFURRFUeoEFWaKoiiKoih1ggozRVEURVGUOkGFmaIoiqIoSp2gwkxRFEVRFKVOUGGmKIqiKIpSJ6gwUxRFURRFqRNUmCmKoiiKotQJ8kIoQyki3cAzk7CpOcCWSdiOUh16XOoXPTb1iR6X+kSPS/1S62OztzFmbtwbLwhhNlmIyCpjzIqpbocyGj0u9Ysem/pEj0t9oselfpnMY6NdmYqiKIqiKHWCCjNFURRFUZQ6QYVZdfxoqhugxKLHpX7RY1Of6HGpT/S41C+Tdmw0xkxRFEVRFKVOUMdMURRFURSlTlBhlgAROVFEHhORJ0Xkwqluz+6KiOwpIreIyMMi8pCIfCCcP0tEbhKRJ8LpzKlu6+6KiPgislpEfh++3kdE7gqvnStFpGGq27i7ISIzROQqEXlURB4RkaP1mqkPRORD4W/ZGhH5pYg06TUzNYjIT0SkS0TWRObFX
idi+XZ4jB4QkcNq2RYVZhUQER/4HnASsBw4W0SWT22rdluywIeNMcuBo4D3hcfiQuDPxpilwJ/D18rU8AHgkcjrLwPfNMbsD2wH3jElrdq9+RZwgzFmGXAI9vjoNTPFiMgi4P3ACmPMQYAPnIVeM1PFpcCJRfNKXScnAUvDx7uA79eyISrMKnME8KQxZq0xZgT4FXD6FLdpt8QYs8kYc2/4vBf7B7MIezwuCxe7DHj91LRw90ZEFgOnAJeErwU4AbgqXESPzSQjIp3AccCPAYwxI8aYHvSaqRdSQLOIpIAWYBN6zUwJxpjbgG1Fs0tdJ6cDlxvL34EZIrKgVm1RYVaZRcD6yOsN4TxlChGRJcBLgbuAPYwxm8K3NgN7TFGzdnf+H/BRIAhfzwZ6jDHZ8LVeO5PPPkA38NOwi/kSEWlFr5kpxxizEfga8CxWkO0A7kGvmXqi1HUyobpAhZky7RCRNuA3wAeNMTuj7xk7zFiHGk8yInIq0GWMuWeq26KMIgUcBnzfGPNSoJ+ibku9ZqaGMF7pdKx4Xgi0MrYrTakTJvM6UWFWmY3AnpHXi8N5yhQgImmsKLvCGPPbcPbzzkYOp11T1b7dmGOA14nIOmx3/wnY2KYZYTcN6LUzFWwANhhj7gpfX4UVanrNTD3/F3jaGNNtjMkAv8VeR3rN1A+lrpMJ1QUqzCpzN7A0HCnTgA3OvHaK27RbEsYs/Rh4xBjzjchb1wJvC5+/Dbhmstu2u2OM+bgxZrExZgn2GrnZGHMOcAtwRriYHptJxhizGVgvIgeGs14FPIxeM/XAs8BRItIS/ra5Y6PXTP1Q6jq5FnhrODrzKGBHpMtzl9EEswkQkZOx8TM+8BNjzBenuEm7JSLycuB24EEKcUyfwMaZ/RrYC3gGONMYUxzEqUwSIvIK4D+MMaeKyL5YB20WsBp4izFmeCrbt7shIodiB2Q0AGuBt2NvyvWamWJE5HPAm7EjzlcD52NjlfSamWRE5JfAK4A5wPPAfwJXE3OdhEL6u9iu5wHg7caYVTVriwozRVEURVGU+kC7MhVFURRFUeoEFWaKoiiKoih1ggozRVEURVGUOkGFmaIoiqIoSp2gwkxRFEVRFKVOUGGmKErdICLniogRkf3D1x8UkTdMYXtmiMhnReSwmPduFZFbp6BZiqK8gElVXkRRFGXK+CBwBzYr+lQwA5vPaANwb9F775385iiK8kJHhZmiKLsVItJYi4SdxpiHa9EeRVGUKNqVqShKXRLW3dwbOCfs3jQicmnk/UNE5FoR2S4igyLyVxE5tmgdl4rIBhE5WkTuFJFB4Cvhe2eJyM0i0i0ifSKyWkTeFvnsEuDp8OXFkTacG74/pitTRA4Ukd+JSE/Ypr+LyIlFy3w2XM9SEbk+3PYzIvIZEfEiy7WJyHdE5FkRGRaRLhH5k4gs28VdqyhKHaPCTFGUeuWfgM3AjcDR4eMLAGHM153YsjXvBN4IbAX+JCIvK1pPJ7bEzS+Bk4BfhPP3xRb1Pgd4PXAdcImIvDt8fxPg4tsuirTh+rjGishCbLfrIcAFwJlAD3C9iJwU85HfATeH274a+ByFunwA3wzX8Tng1cC/Avdhu1cVRXmBol2ZiqLUJcaY1SIyDGwxxvy96O2vYotAn2CMGQEQkRuBNcCnsWLH0YatNziqGLQx5kvueehU3QosAN4D/MAYMywiq8NF1sa0oZh/B2YCRxtjngzXuxJbmPqLwB+Klv+6Mean4fM/icgJwNmAm3c0cIUx5seRz/yuQhsURZnmqGOmKMq0QkSageOB/wUCEUmJSAoQ4E/AcUUfyQC/j1nPUhH5pYhsDJfJYItIHzjOph0H/N2JMgBjTA7r1B0qIh1Fyxc7b2uwxZIddwPnisgnRGSFiPjjbJeiKNMIFWaKokw3ZgE+1hnLFD0uAGZGY7WA7lAg5RGRNuAmbLfjhcCxwOHAT4DGXWjXppj5m7GicWbR/G1Fr4eBpsjrfwN+CJyHF
WldIvJNEWkZZ/sURZkGaFemoijTjR4gAL4HXB63gDEmiL6MWeRo7MCCY40xd7iZofM2XrYB82Pmzw/bsL2alRlj+oCPAx8Xkb2BM4D/BkaAj+1COxVFqWNUmCmKUs8MA83RGcaYfhG5Het23VskwpLiXKeMmyEiM4HTY7ZPcRtK8BfggyKyxBizLlynD7wZWG2M2TmOdgJgjHkG+LqInAMcNN71KIpS/6gwUxSlnnkYOFZETsV2CW4JRc+/A7cBN4rIj7FdiHOAwwDfGHNhhfXeCewEvici/wm0Ap8CtmBHcTqex472PEtEHgD6gaeNMVtj1vlN4FzgpnCdO7FJaA8ATqnyeyMifwOuBR4E+rBxdYcAl1W7LkVRpg8aY6YoSj3zceAx4NfYOKvPAhhj7sXGhG0Fvg38EfgW8BKsYCuLMaYbm47Dx6bMuAi4BPh50XIBdkDATOzAgruB00qs8zng5cBDwPfD9c4CTjHG3JD4Gxe4DZsu4wrsQIEzgA8ZY741jnUpijJNEGPiwi8URVEURVGUyUYdM0VRFEVRlDpBhZmiKIqiKEqdoMJMURRFURSlTlBhpiiKoiiKUieoMFMURVEURakTVJgpiqIoiqLUCSrMFEVRFEVR6gQVZoqiKIqiKHWCCjNFURRFUZQ64f8DnMv8dGIbwecAAAAASUVORK5CYII=", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "print(\"num_samples:\", NUM_SAMPLES)\n", "print(\"num_features:\", NUM_FEATURES)\n", "print(\"num_classes:\", NUM_CLASSES)\n", "print(\"maxiter:\", MAXITER)\n", "print(\"stepsize:\", STEPSIZE)\n", "print(\"linesearch (ignored if `stepsize` > 0):\", LINESEARCH)\n", "print()\n", "\n", "errors, step_times, compile_time = run()\n", "print('Average speed-up (ignoring compile):',\n", " round((step_times['without_pmap'] / step_times['with_pmap']).mean(), 2))" ] } ], "metadata": { "accelerator": "TPU", "colab": { "provenance": [] }, "gpuClass": "standard", "jupytext": { "formats": "ipynb,md:myst" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.1 (main, Dec 23 2022, 09:28:24) [Clang 14.0.0 (clang-1400.0.29.202)]" }, "vscode": { "interpreter": { "hash": "5c7b89af1651d0b8571dde13640ecdccf7d5a6204171d6ab33e7c296e100e08a" } } }, "nbformat": 4, "nbformat_minor": 0 }